From 943a3b21645d414d6d8bcd3a4bf97646dac17b4b Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Fri, 16 Jan 2026 08:56:06 +0300 Subject: [PATCH] allow skipping annotations parsing and fix string annotations parsing --- .../modern-di/modern_di/providers/factory.py | 15 +++++-- packages/modern-di/modern_di/types_parser.py | 24 ++++++++--- .../tests_core/providers/test_factory.py | 10 +++++ .../modern-di/tests_core/test_types_parser.py | 41 +++++++++++++++++++ 4 files changed, 80 insertions(+), 10 deletions(-) diff --git a/packages/modern-di/modern_di/providers/factory.py b/packages/modern-di/modern_di/providers/factory.py index 44991d6..11a88ad 100644 --- a/packages/modern-di/modern_di/providers/factory.py +++ b/packages/modern-di/modern_di/providers/factory.py @@ -5,7 +5,7 @@ from modern_di import types from modern_di.providers.abstract import AbstractProvider from modern_di.scope import Scope -from modern_di.types_parser import parse_creator +from modern_di.types_parser import SignatureItem, parse_creator if typing.TYPE_CHECKING: @@ -25,7 +25,7 @@ def __post_init__(self) -> None: class Factory(AbstractProvider[types.T_co]): __slots__ = [*AbstractProvider.BASE_SLOTS, "_creator", "_kwargs", "_parsed_kwargs", "cache_settings"] - def __init__( + def __init__( # noqa: PLR0913 self, *, scope: Scope = Scope.APP, @@ -33,9 +33,16 @@ def __init__( bound_type: type | None = types.UNSET, # type: ignore[assignment] kwargs: dict[str, typing.Any] | None = None, cache_settings: CacheSettings[types.T_co] | None = None, + skip_creator_parsing: bool = False, ) -> None: - dependency_type, self._parsed_kwargs = parse_creator(creator) - super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else dependency_type.arg_type) + if skip_creator_parsing: + parsed_type: type | None = None + parsed_kwargs: dict[str, SignatureItem] = {} + else: + return_sig, parsed_kwargs = parse_creator(creator) + parsed_type = return_sig.arg_type + self._parsed_kwargs = parsed_kwargs + super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else parsed_type) self._creator = creator self.cache_settings = cache_settings self._kwargs = kwargs diff --git a/packages/modern-di/modern_di/types_parser.py b/packages/modern-di/modern_di/types_parser.py index 92ccef1..aa2d069 100644 --- a/packages/modern-di/modern_di/types_parser.py +++ b/packages/modern-di/modern_di/types_parser.py @@ -15,6 +15,9 @@ class SignatureItem: @classmethod def from_type(cls, type_: type, default: object = UNSET) -> "SignatureItem": + if type_ is types.NoneType: + return cls() + result: dict[str, typing.Any] = {"default": default} if isinstance(type_, types.GenericAlias): result["arg_type"] = type_.__origin__ @@ -39,21 +42,30 @@ def parse_creator(creator: typing.Callable[..., typing.Any]) -> tuple[SignatureI except ValueError: return SignatureItem.from_type(typing.cast(type, creator)), {} + is_class = isinstance(creator, type) + if is_class and hasattr(creator, "__init__"): + type_hints = typing.get_type_hints(creator.__init__) + else: + type_hints = typing.get_type_hints(creator) + param_hints = {} for param_name, param in sig.parameters.items(): if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): continue + default = UNSET if param.default is not param.empty: default = param.default - if param.annotation is not param.empty: - param_hints[param_name] = SignatureItem.from_type(param.annotation, default=default) + + if param_name in type_hints: + param_hints[param_name] = SignatureItem.from_type(type_hints[param_name], default=default) else: param_hints[param_name] = SignatureItem(default=default) - if sig.return_annotation: - return_sig = SignatureItem.from_type(sig.return_annotation) - elif isinstance(creator, type): - return_sig = SignatureItem.from_type(creator) + + if is_class: + return_sig = SignatureItem.from_type(typing.cast(type, creator)) + elif "return" in type_hints: + return_sig = SignatureItem.from_type(type_hints["return"]) else: return_sig = SignatureItem() diff --git a/packages/modern-di/tests_core/providers/test_factory.py b/packages/modern-di/tests_core/providers/test_factory.py index c617932..e9e287f 100644 --- a/packages/modern-di/tests_core/providers/test_factory.py +++ b/packages/modern-di/tests_core/providers/test_factory.py @@ -1,4 +1,5 @@ import dataclasses +import re import pytest from modern_di import Container, Group, Scope, providers @@ -27,6 +28,7 @@ def func_with_union(dep1: SimpleCreator | int) -> str: class MyGroup(Group): app_factory = providers.Factory(creator=SimpleCreator, kwargs={"dep1": "original"}) app_factory_unresolvable = providers.Factory(creator=SimpleCreator, bound_type=None) + app_factory_skip_creator_parsing = providers.Factory(creator=SimpleCreator, skip_creator_parsing=True) func_with_union_factory = providers.Factory(creator=func_with_union, bound_type=None) request_factory = providers.Factory(scope=Scope.REQUEST, creator=DependentCreator) request_factory_with_di_container = providers.Factory(scope=Scope.REQUEST, creator=AnotherCreator) @@ -45,6 +47,14 @@ def test_app_factory() -> None: assert instance2 is not instance3 +def test_app_factory_skip_creator_parsing() -> None: + app_container = Container(groups=[MyGroup]) + with pytest.raises( + TypeError, match=re.escape("SimpleCreator.__init__() missing 1 required keyword-only argument: 'dep1'") + ): + app_container.resolve_provider(MyGroup.app_factory_skip_creator_parsing) + + def test_app_factory_unresolvable() -> None: app_container = Container(groups=[MyGroup]) with pytest.raises(RuntimeError, match="Argument dep1 cannot be resolved, type= None: def simple_func(arg1: int, arg2: str | None = None) -> int: ... # type: ignore[empty-body] def none_func(arg1: int, arg2: str | None = None) -> None: ... def args_kwargs_func(*args: int, **kwargs: str) -> None: ... +def func_with_str_annotations(arg1: "list[int]", arg2: "str") -> None: ... async def async_func(arg1: int = 1, arg2="str") -> int: ... # type: ignore[no-untyped-def,empty-body] # noqa: ANN001 @@ -43,6 +44,10 @@ class SomeRegularClass: def __init__(self, arg1: str, arg2: int) -> None: ... +class ClassWithStringAnnotations: + def __init__(self, arg1: "str", arg2: "int") -> None: ... + + @pytest.mark.parametrize( ("creator", "result"), [ @@ -73,6 +78,16 @@ def __init__(self, arg1: str, arg2: int) -> None: ... {}, ), ), + ( + func_with_str_annotations, + ( + SignatureItem(), + { + "arg1": SignatureItem(arg_type=list, args=[int]), + "arg2": SignatureItem(arg_type=str), + }, + ), + ), ( async_func, ( @@ -112,8 +127,34 @@ def __init__(self, arg1: str, arg2: int) -> None: ... }, ), ), + ( + ClassWithStringAnnotations, + ( + SignatureItem(arg_type=ClassWithStringAnnotations), + { + "arg1": SignatureItem(arg_type=str), + "arg2": SignatureItem(arg_type=int), + }, + ), + ), (int, (SignatureItem(arg_type=int), {})), ], ) def test_parse_creator(creator: type, result: tuple[SignatureItem | None, dict[str, SignatureItem]]) -> None: assert parse_creator(creator) == result + + +def func_with_wrong_annotations(arg1: "WrongType", arg2: "str") -> None: ... # type: ignore[name-defined] # noqa: F821 + + +class ClassWithWrongAnnotations: + def __init__(self, arg1: "WrongType", arg2: "int") -> None: ... # type: ignore[name-defined] # noqa: F821 + + +@pytest.mark.parametrize( + "creator", + [func_with_wrong_annotations, ClassWithWrongAnnotations], +) +def test_parse_creator_wrong_annotations(creator: type) -> None: + with pytest.raises(NameError, match="name 'WrongType' is not defined"): + assert parse_creator(creator)