diff --git a/src/kat_transform/__init__.py b/src/kat_transform/__init__.py index ea05d1b..a8ffe7e 100644 --- a/src/kat_transform/__init__.py +++ b/src/kat_transform/__init__.py @@ -46,6 +46,32 @@ def schema( return SchemaSpec(name, fields, meta) +def transform_value(value: typing.Any, spec: FieldSpec[typing.Any, typing.Any]) -> typing.Any: + if spec.transform is not None: + value = spec.transform(value) + + if isinstance(spec.output_type, SchemaSpec): + return transform(value) + + if not isinstance(spec.item_type, SchemaSpec): + return value + + if spec.is_typed_mutable_sequence: + sequence = typing.cast(tuple[frozenset[FieldValue]], value) + return [transform(subraw) for subraw in sequence] + + elif spec.is_typed_sequence: + sequence = typing.cast(tuple[frozenset[FieldValue]], value) + return tuple(transform(subraw) for subraw in sequence) + + elif spec.is_typed_mapping: + mapping = typing.cast(tuple[tuple[str, frozenset[FieldValue]]], value) + return {k: transform(vraw) for k, vraw in mapping} + + else: + raise RuntimeError("Unexpected behavior") + + def transform(raw: collections.abc.Set[FieldValue]) -> collections.abc.Mapping[str, typing.Any]: """ Transform input values of fields into final values using field's transformers @@ -58,15 +84,8 @@ def transform(raw: collections.abc.Set[FieldValue]) -> collections.abc.Mapping[s "They should be resolved using dependency injection" ) - value = field_value.value - spec = field_value.field_spec - - if isinstance(spec.output_type, SchemaSpec): - value = transform(value) - - elif spec.transform is not None: - value = spec.transform(field_value.value) - - transformed[spec.name] = value + transformed[field_value.field_spec.name] = transform_value( + field_value.value, field_value.field_spec + ) return transformed diff --git a/src/kat_transform/field.py b/src/kat_transform/field.py index 7a63e3b..0cb1fab 100644 --- a/src/kat_transform/field.py +++ b/src/kat_transform/field.py @@ -1,5 +1,6 @@ import typing import collections.abc +from functools import cache from dataclasses import dataclass @@ -38,6 +39,43 @@ class FieldSpec(typing.Generic[I, O]): Define metadata for this field """ + @property + @cache + def _origin(self): + return typing.get_origin(self.output_type) + + @property + @cache + def is_typed_mutable_sequence(self): + if self._origin is None: + return False + + return issubclass(self._origin, collections.abc.MutableSequence) + + @property + @cache + def is_typed_sequence(self): + if self._origin is None: + return False + + return issubclass(self._origin, collections.abc.Sequence) + + @property + @cache + def is_typed_mapping(self): + if self._origin is None: + return False + + return issubclass(self._origin, collections.abc.Mapping) + + @property + @cache + def item_type(self): + if self.is_typed_sequence: + return typing.get_args(self.output_type)[0] + elif self.is_typed_mapping: + return typing.get_args(self.output_type)[1] + def get(self, from_: typing.Any) -> I | O | ValueGetter: """ Get field input value from object diff --git a/src/kat_transform/schema.py b/src/kat_transform/schema.py index 58d3547..14b7812 100644 --- a/src/kat_transform/schema.py +++ b/src/kat_transform/schema.py @@ -7,6 +7,28 @@ from .metadata import SchemaMetadata +def get_by_item( + item_type: "SchemaSpec", + value: collections.abc.Mapping[str, typing.Any] | collections.abc.Sequence[typing.Any], + spec: FieldSpec[typing.Any, typing.Any], +) -> FieldValue: + if spec.is_typed_mapping: + mapping = typing.cast(collections.abc.Mapping[str, typing.Any], value) + + mapping_schema_fields = tuple( + (key, frozenset(item_type.get(item))) for key, item in mapping.items() + ) + return FieldValue(spec, mapping_schema_fields) + + elif spec.is_typed_sequence: + array = typing.cast(collections.abc.Sequence[typing.Any], value) + + array_schema_fields = [frozenset(item_type.get(item)) for item in array] + return FieldValue(spec, array_schema_fields) + else: + raise RuntimeError("Unexpected behavior") + + @dataclass(frozen=True) class SchemaSpec: """ @@ -31,6 +53,10 @@ def get(self, from_: typing.Any) -> set[FieldValue]: fields.add(FieldValue(spec, frozenset(sub_schema_fields))) continue + elif isinstance(spec.item_type, SchemaSpec): + fields.add(get_by_item(spec.item_type, typing.cast(typing.Any, field_value), spec)) + continue + fields.add(FieldValue(spec, field_value)) return fields diff --git a/tests/test_transform.py b/tests/test_transform.py index 91db6a7..58ac619 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -13,6 +13,54 @@ def test_transform(): assert transformed == {"field": "uppercase message"} +def test_transform_subschema(): + sub = schema("Sub", field(str, "name", transform=lambda x: x.upper())) + + spec = schema("Schema", field(sub, "sub")) + + raw = spec.get({"sub": {"name": "name"}}) + + transformed = transform(raw) + + assert transformed == {"sub": {"name": "NAME"}} + + +def test_transform_subschema_in_mutable_sequence(): + sub = schema("Sub", field(str, "name", transform=lambda x: x.upper())) + + spec = schema("Schema", field(list[sub], "sub")) + + raw = spec.get({"sub": [{"name": "name"}]}) + + transformed = transform(raw) + + assert transformed == {"sub": [{"name": "NAME"}]} + + +def test_transform_subschema_in_immutable_sequence(): + sub = schema("Sub", field(str, "name", transform=lambda x: x.upper())) + + spec = schema("Schema", field(tuple[sub], "sub")) + + raw = spec.get({"sub": [{"name": "name"}]}) + + transformed = transform(raw) + + assert transformed == {"sub": ({"name": "NAME"},)} + + +def test_transform_subschema_in_mapping(): + sub = schema("Sub", field(str, "name", transform=lambda x: x.upper())) + + spec = schema("Schema", field(dict[str, sub], "sub")) + + raw = spec.get({"sub": {"a": {"name": "name"}}}) + + transformed = transform(raw) + + assert transformed == {"sub": {"a": {"name": "NAME"}}} + + def test_transform_with_getter(): spec = schema("Schema", field(str, "field", getter=lambda: "value"))