diff --git a/packages/modern-di/modern_di/container.py b/packages/modern-di/modern_di/container.py index 951a3d4..64a5703 100644 --- a/packages/modern-di/modern_di/container.py +++ b/packages/modern-di/modern_di/container.py @@ -111,7 +111,7 @@ def override( mock: typing.Any, # noqa: ANN401 ) -> None: self.cache_registry.clear_kwargs() - new_provider = mock if isinstance(mock, AbstractProvider) else Object(obj=mock) + new_provider = mock if isinstance(mock, AbstractProvider) else Object(obj=mock, bound_type=dependency_type) return self.providers_registry.override_provider( dependency_name=dependency_name, dependency_type=dependency_type, new_provider=new_provider ) diff --git a/packages/modern-di/modern_di/types_parser.py b/packages/modern-di/modern_di/types_parser.py index aa2d069..c3c9d60 100644 --- a/packages/modern-di/modern_di/types_parser.py +++ b/packages/modern-di/modern_di/types_parser.py @@ -18,6 +18,9 @@ def from_type(cls, type_: type, default: object = UNSET) -> "SignatureItem": if type_ is types.NoneType: return cls() + if isinstance(type_, typing._AnnotatedAlias): # type: ignore[attr-defined] # noqa: SLF001 + type_ = type_.__args__[0] + result: dict[str, typing.Any] = {"default": default} if isinstance(type_, types.GenericAlias): result["arg_type"] = type_.__origin__ diff --git a/packages/modern-di/tests_core/test_types_parser.py b/packages/modern-di/tests_core/test_types_parser.py index 3fb9b6e..5c3d56e 100644 --- a/packages/modern-di/tests_core/test_types_parser.py +++ b/packages/modern-di/tests_core/test_types_parser.py @@ -9,11 +9,13 @@ ("type_", "result"), [ (int, SignatureItem(arg_type=int)), + (typing.Annotated[int, None], SignatureItem(arg_type=int)), (list[int], SignatureItem(arg_type=list, args=[int])), (dict[str, typing.Any], SignatureItem(arg_type=dict, args=[str, typing.Any])), (typing.Optional[str], SignatureItem(arg_type=str, is_nullable=True)), # noqa: UP045 (str | None, SignatureItem(arg_type=str, is_nullable=True)), (str | int, SignatureItem(args=[str, int])), + (typing.Union[str | int], SignatureItem(args=[str, int])), # noqa: UP007 (list[str] | None, SignatureItem(arg_type=list, is_nullable=True)), ], ) @@ -22,7 +24,7 @@ def test_signature_item_parser(type_: type, result: SignatureItem) -> None: def simple_func(arg1: int, arg2: str | None = None) -> int: ... # type: ignore[empty-body] -def none_func(arg1: int, arg2: str | None = None) -> None: ... +def none_func(arg1: typing.Annotated[int, None], arg2: str | None = None) -> None: ... def args_kwargs_func(*args: int, **kwargs: str) -> None: ... def func_with_str_annotations(arg1: "list[int]", arg2: "str") -> None: ... async def async_func(arg1: int = 1, arg2="str") -> int: ... # type: ignore[no-untyped-def,empty-body] # noqa: ANN001