Skip to content

Commit

Permalink
Fix schema class generation (#7522)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 9c20dd5e635ce570d4cf4f09cadf288c18115551
  • Loading branch information
voodoo11 authored and Manul from Pathway committed Jan 13, 2025
1 parent cafd025 commit d924453
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
### Changed
- **BREAKING**: `pw.io.deltalake.read` now requires explicit specification of primary key fields.

### Fixed
- `generate_class` method in `Schema` now correctly renders columns of `UnionType` and `None` types.

## [0.16.4] - 2025-01-09

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion python/pathway/internals/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def is_value_compatible(self, arg):
return self.wrapped.is_value_compatible(arg)

@cached_property
def typehint(self) -> type[UnionType]:
def typehint(self) -> UnionType:
return self.wrapped.typehint | None

def max_size(self) -> float:
Expand Down
46 changes: 27 additions & 19 deletions python/pathway/internals/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclasses import dataclass
from os import PathLike
from pydoc import locate
from types import MappingProxyType
from types import MappingProxyType, UnionType
from typing import TYPE_CHECKING, Any, NoReturn, get_type_hints
from warnings import warn

Expand Down Expand Up @@ -468,26 +468,34 @@ def generate_class(
run code (default is False)
"""

def get_type_definition_and_modules(type: object) -> tuple[str, list[str]]:
if type.__module__ != "builtins":
modules = [type.__module__]
type_definition = (
type.__module__
+ "."
+ type.__qualname__ # type:ignore[attr-defined]
)
else:
modules = []
type_definition = type.__qualname__ # type:ignore[attr-defined]
if not hasattr(type, "__origin__"):
return (type_definition, modules)
def get_type_definition_and_modules(_type: object) -> tuple[str, list[str]]:
modules = []
if _type in {type(None), None}:
type_repr = "None"
elif not hasattr(_type, "__qualname__"):
type_repr = repr(_type)
elif _type.__module__ != "builtins":
type_repr = f"{_type.__module__}.{_type.__qualname__}"
modules = [_type.__module__]
else:
args_definitions = []
for arg in type.__args__: # type:ignore[attr-defined]
definition, arg_modules = get_type_definition_and_modules(arg)
args_definitions.append(definition)
type_repr = _type.__qualname__

if hasattr(_type, "__args__"):
args = []
for arg in _type.__args__:
arg_repr, arg_modules = get_type_definition_and_modules(arg)
args.append(arg_repr)
modules += arg_modules
return (f"{type_definition}[{', '.join(args_definitions)}]", modules)

if isinstance(_type, UnionType):
return (" | ".join(args), modules)
else:
return (
f"{type_repr}[{', '.join(args)}]",
modules,
)
else:
return type_repr, modules

required_modules: StableSet[str] = StableSet()

Expand Down
33 changes: 33 additions & 0 deletions python/pathway/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def test_schema_class_generation(tmp_path: pathlib.Path):
"f": pw.column_definition(dtype=tuple[int, Any]),
"g": pw.column_definition(dtype=pw.DateTimeUtc),
"h": pw.column_definition(dtype=tuple[int, ...]),
"i": pw.column_definition(dtype=str | None),
"j": pw.column_definition(dtype=None),
},
name="Foo",
)
Expand All @@ -160,6 +162,37 @@ def test_schema_class_generation(tmp_path: pathlib.Path):
del sys.modules[module_name]


def test_schema_class_generation_from_auto_schema(tmp_path: pathlib.Path):
    """Round-trip a table-derived schema through ``generate_class_to_file``.

    Builds an empty table whose columns cover plain, generic, union
    (``str | None``) and ``None`` types, writes the generated schema class to
    a file under ``tmp_path``, imports that file as a module and checks that
    the imported ``Foo`` class matches the original schema.
    """
    a = pw.Table.empty(
        a=int,
        b=str,
        c=Any,
        d=float,
        e=tuple[int, Any],
        f=pw.DateTimeUtc,
        g=tuple[int, ...],
        h=str | None,
        i=None,
    )

    schema = a.schema
    schema.__name__ = "Foo"

    path = tmp_path / "foo.py"

    module_name = "pathway_schema_test"

    try:
        schema.generate_class_to_file(path, class_name="Foo", generate_imports=True)
        spec = importlib.util.spec_from_file_location(module_name, path)
        # spec_from_file_location may return None (and a spec may lack a
        # loader); fail with a clear assertion instead of an AttributeError.
        assert spec is not None and spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        assert_same_schema(schema, module.Foo)
    finally:
        # Use pop() rather than del: if a step above failed before the module
        # was registered, del would raise KeyError and hide the real error.
        sys.modules.pop(module_name, None)


def test_schema_from_dict():
schema_definition = {
"col1": Any,
Expand Down

0 comments on commit d924453

Please sign in to comment.