Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions src/kat_transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import typing
import collections.abc


from .schema import SchemaSpec
from .field import FieldSpec, I, O
from .exceptions import FieldResolveError
from .markers import ValueGetter, FieldValue
from .metadata import SchemaMetadata, FieldMetadata
from .resolve_fields import resolve_fields, resolve_getter
from .util import get_item_type, is_typed_mapping, is_typed_sequence

__all__ = [
"field",
Expand Down Expand Up @@ -53,20 +55,23 @@ def transform_value(value: typing.Any, spec: FieldSpec[typing.Any, typing.Any])
if isinstance(spec.output_type, SchemaSpec):
return transform(value)

if not isinstance(spec.item_type, SchemaSpec):
item_type = get_item_type(spec.output_type)
if not isinstance(item_type, SchemaSpec):
return value

if spec.is_typed_mutable_sequence:
if is_typed_sequence(spec.output_type):
assert isinstance(
value, collections.abc.Sequence
), f"Expected sequence value, but got {type(value)}"
sequence = typing.cast(tuple[frozenset[FieldValue]], value)
return [transform(subraw) for subraw in sequence]

elif spec.is_typed_sequence:
sequence = typing.cast(tuple[frozenset[FieldValue]], value)
return tuple(transform(subraw) for subraw in sequence)

elif spec.is_typed_mapping:
mapping = typing.cast(tuple[tuple[str, frozenset[FieldValue]]], value)
return {k: transform(vraw) for k, vraw in mapping}
elif is_typed_mapping(spec.output_type):
assert isinstance(
value, collections.abc.Mapping
), f"Expected mapping value, but got {type(value)}"
mapping = typing.cast(collections.abc.Mapping[str, frozenset[FieldValue]], value)
return {k: transform(vraw) for k, vraw in mapping.items()}

else:
raise RuntimeError("Unexpected behavior")
Expand Down
39 changes: 1 addition & 38 deletions src/kat_transform/field.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing
import collections.abc
from functools import cache
from dataclasses import dataclass
from dataclasses import dataclass, field


from .util import get_by_name
Expand Down Expand Up @@ -39,43 +39,6 @@ class FieldSpec(typing.Generic[I, O]):
Define metadata for this field
"""

@property
@cache
def _origin(self):
return typing.get_origin(self.output_type)

@property
@cache
def is_typed_mutable_sequence(self):
if self._origin is None:
return False

return issubclass(self._origin, collections.abc.MutableSequence)

@property
@cache
def is_typed_sequence(self):
if self._origin is None:
return False

return issubclass(self._origin, collections.abc.Sequence)

@property
@cache
def is_typed_mapping(self):
if self._origin is None:
return False

return issubclass(self._origin, collections.abc.Mapping)

@property
@cache
def item_type(self):
if self.is_typed_sequence:
return typing.get_args(self.output_type)[0]
elif self.is_typed_mapping:
return typing.get_args(self.output_type)[1]

def get(self, from_: typing.Any) -> I | O | ValueGetter:
"""
Get field input value from object
Expand Down
22 changes: 11 additions & 11 deletions src/kat_transform/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,28 @@
import collections.abc
from dataclasses import dataclass


from .field import FieldSpec
from .markers import FieldValue
from .metadata import SchemaMetadata
from .util import get_item_type, is_typed_mapping, is_typed_sequence


def get_by_item(
item_type: "SchemaSpec",
value: collections.abc.Mapping[str, typing.Any] | collections.abc.Sequence[typing.Any],
spec: FieldSpec[typing.Any, typing.Any],
) -> FieldValue:
if spec.is_typed_mapping:
if is_typed_mapping(spec.output_type):
mapping = typing.cast(collections.abc.Mapping[str, typing.Any], value)

mapping_schema_fields = tuple(
(key, frozenset(item_type.get(item))) for key, item in mapping.items()
)
mapping_schema_fields = {key: item_type.get(item) for key, item in mapping.items()}
return FieldValue(spec, mapping_schema_fields)

elif spec.is_typed_sequence:
elif is_typed_sequence(spec.output_type):
array = typing.cast(collections.abc.Sequence[typing.Any], value)

array_schema_fields = [frozenset(item_type.get(item)) for item in array]
array_schema_fields = [item_type.get(item) for item in array]
return FieldValue(spec, array_schema_fields)
else:
raise RuntimeError("Unexpected behavior")
Expand All @@ -39,7 +39,7 @@ class SchemaSpec:
fields: collections.abc.Sequence[FieldSpec[typing.Any, typing.Any]]
metadata: SchemaMetadata | None = None

def get(self, from_: typing.Any) -> set[FieldValue]:
def get(self, from_: typing.Any) -> frozenset[FieldValue]:
"""
Get input values of fields
"""
Expand All @@ -50,16 +50,16 @@ def get(self, from_: typing.Any) -> set[FieldValue]:

if isinstance(spec.output_type, SchemaSpec):
sub_schema_fields = spec.output_type.get(field_value)
fields.add(FieldValue(spec, frozenset(sub_schema_fields)))
fields.add(FieldValue(spec, sub_schema_fields))
continue

elif isinstance(spec.item_type, SchemaSpec):
fields.add(get_by_item(spec.item_type, typing.cast(typing.Any, field_value), spec))
elif isinstance(item_type := get_item_type(spec.output_type), SchemaSpec):
fields.add(get_by_item(item_type, typing.cast(typing.Any, field_value), spec))
continue

fields.add(FieldValue(spec, field_value))

return fields
return frozenset(fields)

def __hash__(self) -> int:
return hash((self.name,) + tuple(self.fields))
23 changes: 23 additions & 0 deletions src/kat_transform/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import cache
import typing
import collections.abc

Expand All @@ -18,3 +19,25 @@ def get_by_name(name: str, from_: typing.Any) -> typing.Any:
assert hasattr(from_, name), f'{from_!r} has no attribute "{name}"'

return getattr(from_, name)


@cache
def is_typed_sequence(annotation: typing.Any) -> bool:
origin = typing.get_origin(annotation)
return origin and issubclass(origin, collections.abc.Sequence)


@cache
def is_typed_mapping(annotation: typing.Any) -> bool:
origin = typing.get_origin(annotation)
return origin and issubclass(origin, collections.abc.Mapping)


@cache
def get_item_type(annotation: typing.Any) -> typing.Any:
args = typing.get_args(annotation)
if is_typed_sequence(annotation):
return args[0]

if is_typed_mapping(annotation):
return args[1]
12 changes: 0 additions & 12 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,6 @@ def test_transform_subschema_in_mutable_sequence():
assert transformed == {"sub": [{"name": "NAME"}]}


def test_transform_subschema_in_immutable_sequence():
sub = schema("Sub", field(str, "name", transform=lambda x: x.upper()))

spec = schema("Schema", field(tuple[sub], "sub"))

raw = spec.get({"sub": [{"name": "name"}]})

transformed = transform(raw)

assert transformed == {"sub": ({"name": "NAME"},)}


def test_transform_subschema_in_mapping():
sub = schema("Sub", field(str, "name", transform=lambda x: x.upper()))

Expand Down