diff --git a/src/kat_transform/__init__.py b/src/kat_transform/__init__.py index a8ffe7e..ce3faad 100644 --- a/src/kat_transform/__init__.py +++ b/src/kat_transform/__init__.py @@ -1,12 +1,14 @@ import typing import collections.abc + from .schema import SchemaSpec from .field import FieldSpec, I, O from .exceptions import FieldResolveError from .markers import ValueGetter, FieldValue from .metadata import SchemaMetadata, FieldMetadata from .resolve_fields import resolve_fields, resolve_getter +from .util import get_item_type, is_typed_mapping, is_typed_sequence __all__ = [ "field", @@ -53,20 +55,23 @@ def transform_value(value: typing.Any, spec: FieldSpec[typing.Any, typing.Any]) if isinstance(spec.output_type, SchemaSpec): return transform(value) - if not isinstance(spec.item_type, SchemaSpec): + item_type = get_item_type(spec.output_type) + if not isinstance(item_type, SchemaSpec): return value - if spec.is_typed_mutable_sequence: + if is_typed_sequence(spec.output_type): + assert isinstance( + value, collections.abc.Sequence + ), f"Expected sequence value, but got {type(value)}" sequence = typing.cast(tuple[frozenset[FieldValue]], value) return [transform(subraw) for subraw in sequence] - elif spec.is_typed_sequence: - sequence = typing.cast(tuple[frozenset[FieldValue]], value) - return tuple(transform(subraw) for subraw in sequence) - - elif spec.is_typed_mapping: - mapping = typing.cast(tuple[tuple[str, frozenset[FieldValue]]], value) - return {k: transform(vraw) for k, vraw in mapping} + elif is_typed_mapping(spec.output_type): + assert isinstance( + value, collections.abc.Mapping + ), f"Expected mapping value, but got {type(value)}" + mapping = typing.cast(collections.abc.Mapping[str, frozenset[FieldValue]], value) + return {k: transform(vraw) for k, vraw in mapping.items()} else: raise RuntimeError("Unexpected behavior") diff --git a/src/kat_transform/field.py b/src/kat_transform/field.py index 0cb1fab..512b586 100644 --- a/src/kat_transform/field.py +++ b/src/kat_transform/field.py @@ -1,7 +1,7 @@ import typing import collections.abc from functools import cache -from dataclasses import dataclass +from dataclasses import dataclass, field from .util import get_by_name @@ -39,43 +39,6 @@ class FieldSpec(typing.Generic[I, O]): Define metadata for this field """ - @property - @cache - def _origin(self): - return typing.get_origin(self.output_type) - - @property - @cache - def is_typed_mutable_sequence(self): - if self._origin is None: - return False - - return issubclass(self._origin, collections.abc.MutableSequence) - - @property - @cache - def is_typed_sequence(self): - if self._origin is None: - return False - - return issubclass(self._origin, collections.abc.Sequence) - - @property - @cache - def is_typed_mapping(self): - if self._origin is None: - return False - - return issubclass(self._origin, collections.abc.Mapping) - - @property - @cache - def item_type(self): - if self.is_typed_sequence: - return typing.get_args(self.output_type)[0] - elif self.is_typed_mapping: - return typing.get_args(self.output_type)[1] - def get(self, from_: typing.Any) -> I | O | ValueGetter: """ Get field input value from object diff --git a/src/kat_transform/schema.py b/src/kat_transform/schema.py index 14b7812..b5803af 100644 --- a/src/kat_transform/schema.py +++ b/src/kat_transform/schema.py @@ -2,9 +2,11 @@ import collections.abc from dataclasses import dataclass + from .field import FieldSpec from .markers import FieldValue from .metadata import SchemaMetadata +from .util import get_item_type, is_typed_mapping, is_typed_sequence def get_by_item( @@ -12,18 +14,16 @@ def get_by_item( value: collections.abc.Mapping[str, typing.Any] | collections.abc.Sequence[typing.Any], spec: FieldSpec[typing.Any, typing.Any], ) -> FieldValue: - if spec.is_typed_mapping: + if is_typed_mapping(spec.output_type): mapping = typing.cast(collections.abc.Mapping[str, typing.Any], value) - mapping_schema_fields = tuple( - (key, frozenset(item_type.get(item))) for key, item in mapping.items() - ) + mapping_schema_fields = {key: item_type.get(item) for key, item in mapping.items()} return FieldValue(spec, mapping_schema_fields) - elif spec.is_typed_sequence: + elif is_typed_sequence(spec.output_type): array = typing.cast(collections.abc.Sequence[typing.Any], value) - array_schema_fields = [frozenset(item_type.get(item)) for item in array] + array_schema_fields = [item_type.get(item) for item in array] return FieldValue(spec, array_schema_fields) else: raise RuntimeError("Unexpected behavior") @@ -39,7 +39,7 @@ class SchemaSpec: fields: collections.abc.Sequence[FieldSpec[typing.Any, typing.Any]] metadata: SchemaMetadata | None = None - def get(self, from_: typing.Any) -> set[FieldValue]: + def get(self, from_: typing.Any) -> frozenset[FieldValue]: """ Get input values of fields """ @@ -50,16 +50,16 @@ def get(self, from_: typing.Any) -> set[FieldValue]: if isinstance(spec.output_type, SchemaSpec): sub_schema_fields = spec.output_type.get(field_value) - fields.add(FieldValue(spec, frozenset(sub_schema_fields))) + fields.add(FieldValue(spec, sub_schema_fields)) continue - elif isinstance(spec.item_type, SchemaSpec): - fields.add(get_by_item(spec.item_type, typing.cast(typing.Any, field_value), spec)) + elif isinstance(item_type := get_item_type(spec.output_type), SchemaSpec): + fields.add(get_by_item(item_type, typing.cast(typing.Any, field_value), spec)) continue fields.add(FieldValue(spec, field_value)) - return fields + return frozenset(fields) def __hash__(self) -> int: return hash((self.name,) + tuple(self.fields)) diff --git a/src/kat_transform/util.py b/src/kat_transform/util.py index edf24a5..bbab813 100644 --- a/src/kat_transform/util.py +++ b/src/kat_transform/util.py @@ -1,3 +1,4 @@ +from functools import cache import typing import collections.abc @@ -18,3 +19,25 @@ def get_by_name(name: str, from_: typing.Any) -> typing.Any: assert hasattr(from_, name), f'{from_!r} has no attribute "{name}"' return getattr(from_, name) + + +@cache +def is_typed_sequence(annotation: typing.Any) -> bool: + origin = typing.get_origin(annotation) + return origin and issubclass(origin, collections.abc.Sequence) + + +@cache +def is_typed_mapping(annotation: typing.Any) -> bool: + origin = typing.get_origin(annotation) + return origin and issubclass(origin, collections.abc.Mapping) + + +@cache +def get_item_type(annotation: typing.Any) -> typing.Any: + args = typing.get_args(annotation) + if is_typed_sequence(annotation): + return args[0] + + if is_typed_mapping(annotation): + return args[1] diff --git a/tests/test_transform.py b/tests/test_transform.py index 58ac619..52117be 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -37,18 +37,6 @@ def test_transform_subschema_in_mutable_sequence(): assert transformed == {"sub": [{"name": "NAME"}]} -def test_transform_subschema_in_immutable_sequence(): - sub = schema("Sub", field(str, "name", transform=lambda x: x.upper())) - - spec = schema("Schema", field(tuple[sub], "sub")) - - raw = spec.get({"sub": [{"name": "name"}]}) - - transformed = transform(raw) - - assert transformed == {"sub": ({"name": "NAME"},)} - - def test_transform_subschema_in_mapping(): sub = schema("Sub", field(str, "name", transform=lambda x: x.upper()))