Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions packages/modern-di/modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from modern_di import types
from modern_di.providers.abstract import AbstractProvider
from modern_di.scope import Scope
from modern_di.types_parser import parse_creator
from modern_di.types_parser import SignatureItem, parse_creator


if typing.TYPE_CHECKING:
Expand All @@ -25,17 +25,24 @@ def __post_init__(self) -> None:
class Factory(AbstractProvider[types.T_co]):
__slots__ = [*AbstractProvider.BASE_SLOTS, "_creator", "_kwargs", "_parsed_kwargs", "cache_settings"]

def __init__(
def __init__( # noqa: PLR0913
self,
*,
scope: Scope = Scope.APP,
creator: typing.Callable[..., types.T_co],
bound_type: type | None = types.UNSET, # type: ignore[assignment]
kwargs: dict[str, typing.Any] | None = None,
cache_settings: CacheSettings[types.T_co] | None = None,
skip_creator_parsing: bool = False,
) -> None:
dependency_type, self._parsed_kwargs = parse_creator(creator)
super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else dependency_type.arg_type)
if skip_creator_parsing:
parsed_type: type | None = None
parsed_kwargs: dict[str, SignatureItem] = {}
else:
return_sig, parsed_kwargs = parse_creator(creator)
parsed_type = return_sig.arg_type
self._parsed_kwargs = parsed_kwargs
super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else parsed_type)
self._creator = creator
self.cache_settings = cache_settings
self._kwargs = kwargs
Expand Down
24 changes: 18 additions & 6 deletions packages/modern-di/modern_di/types_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class SignatureItem:

@classmethod
def from_type(cls, type_: type, default: object = UNSET) -> "SignatureItem":
if type_ is types.NoneType:
return cls()

result: dict[str, typing.Any] = {"default": default}
if isinstance(type_, types.GenericAlias):
result["arg_type"] = type_.__origin__
Expand All @@ -39,21 +42,30 @@ def parse_creator(creator: typing.Callable[..., typing.Any]) -> tuple[SignatureI
except ValueError:
return SignatureItem.from_type(typing.cast(type, creator)), {}

is_class = isinstance(creator, type)
if is_class and hasattr(creator, "__init__"):
type_hints = typing.get_type_hints(creator.__init__)
else:
type_hints = typing.get_type_hints(creator)

param_hints = {}
for param_name, param in sig.parameters.items():
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
continue

default = UNSET
if param.default is not param.empty:
default = param.default
if param.annotation is not param.empty:
param_hints[param_name] = SignatureItem.from_type(param.annotation, default=default)

if param_name in type_hints:
param_hints[param_name] = SignatureItem.from_type(type_hints[param_name], default=default)
else:
param_hints[param_name] = SignatureItem(default=default)
if sig.return_annotation:
return_sig = SignatureItem.from_type(sig.return_annotation)
elif isinstance(creator, type):
return_sig = SignatureItem.from_type(creator)

if is_class:
return_sig = SignatureItem.from_type(typing.cast(type, creator))
elif "return" in type_hints:
return_sig = SignatureItem.from_type(type_hints["return"])
else:
return_sig = SignatureItem()

Expand Down
10 changes: 10 additions & 0 deletions packages/modern-di/tests_core/providers/test_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import re

import pytest
from modern_di import Container, Group, Scope, providers
Expand Down Expand Up @@ -27,6 +28,7 @@ def func_with_union(dep1: SimpleCreator | int) -> str:
class MyGroup(Group):
app_factory = providers.Factory(creator=SimpleCreator, kwargs={"dep1": "original"})
app_factory_unresolvable = providers.Factory(creator=SimpleCreator, bound_type=None)
app_factory_skip_creator_parsing = providers.Factory(creator=SimpleCreator, skip_creator_parsing=True)
func_with_union_factory = providers.Factory(creator=func_with_union, bound_type=None)
request_factory = providers.Factory(scope=Scope.REQUEST, creator=DependentCreator)
request_factory_with_di_container = providers.Factory(scope=Scope.REQUEST, creator=AnotherCreator)
Expand All @@ -45,6 +47,14 @@ def test_app_factory() -> None:
assert instance2 is not instance3


def test_app_factory_skip_creator_parsing() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(
TypeError, match=re.escape("SimpleCreator.__init__() missing 1 required keyword-only argument: 'dep1'")
):
app_container.resolve_provider(MyGroup.app_factory_skip_creator_parsing)


def test_app_factory_unresolvable() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match="Argument dep1 cannot be resolved, type=<class 'str'"):
Expand Down
41 changes: 41 additions & 0 deletions packages/modern-di/tests_core/test_types_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_signature_item_parser(type_: type, result: SignatureItem) -> None:
def simple_func(arg1: int, arg2: str | None = None) -> int: ... # type: ignore[empty-body]
def none_func(arg1: int, arg2: str | None = None) -> None: ...
def args_kwargs_func(*args: int, **kwargs: str) -> None: ...
def func_with_str_annotations(arg1: "list[int]", arg2: "str") -> None: ...
async def async_func(arg1: int = 1, arg2="str") -> int: ... # type: ignore[no-untyped-def,empty-body] # noqa: ANN001


Expand All @@ -43,6 +44,10 @@ class SomeRegularClass:
def __init__(self, arg1: str, arg2: int) -> None: ...


class ClassWithStringAnnotations:
def __init__(self, arg1: "str", arg2: "int") -> None: ...


@pytest.mark.parametrize(
("creator", "result"),
[
Expand Down Expand Up @@ -73,6 +78,16 @@ def __init__(self, arg1: str, arg2: int) -> None: ...
{},
),
),
(
func_with_str_annotations,
(
SignatureItem(),
{
"arg1": SignatureItem(arg_type=list, args=[int]),
"arg2": SignatureItem(arg_type=str),
},
),
),
(
async_func,
(
Expand Down Expand Up @@ -112,8 +127,34 @@ def __init__(self, arg1: str, arg2: int) -> None: ...
},
),
),
(
ClassWithStringAnnotations,
(
SignatureItem(arg_type=ClassWithStringAnnotations),
{
"arg1": SignatureItem(arg_type=str),
"arg2": SignatureItem(arg_type=int),
},
),
),
(int, (SignatureItem(arg_type=int), {})),
],
)
def test_parse_creator(creator: type, result: tuple[SignatureItem | None, dict[str, SignatureItem]]) -> None:
assert parse_creator(creator) == result


def func_with_wrong_annotations(arg1: "WrongType", arg2: "str") -> None: ... # type: ignore[name-defined] # noqa: F821


class ClassWithWrongAnnotations:
def __init__(self, arg1: "WrongType", arg2: "int") -> None: ... # type: ignore[name-defined] # noqa: F821


@pytest.mark.parametrize(
"creator",
[func_with_wrong_annotations, ClassWithWrongAnnotations],
)
def test_parse_creator_wrong_annotations(creator: type) -> None:
with pytest.raises(NameError, match="name 'WrongType' is not defined"):
assert parse_creator(creator)