diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/filter/factory.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/filter/factory.py index 25672f674..b391d6aae 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/filter/factory.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/filter/factory.py @@ -45,7 +45,7 @@ def __replace(leaf: ConditionTree) -> ConditionTree: leaf = cast(ConditionTreeLeaf, leaf) time_transform = time_transforms(1) if leaf.operator not in SHIFTED_OPERATORS: - raise FilterFactoryException(f"'{leaf.operator}' is not shiftable ") + return leaf alternative: Alternative = time_transform[leaf.operator][0] leaf = alternative["replacer"](leaf, tz) diff --git a/src/datasource_toolkit/tests/interfaces/query/filter/test_factory.py b/src/datasource_toolkit/tests/interfaces/query/filter/test_factory.py index 08598395b..5593d2166 100644 --- a/src/datasource_toolkit/tests/interfaces/query/filter/test_factory.py +++ b/src/datasource_toolkit/tests/interfaces/query/filter/test_factory.py @@ -10,11 +10,25 @@ import pytest from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.collections import Collection -from forestadmin.datasource_toolkit.interfaces.fields import FieldType, ManyToMany, Operator -from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import Aggregator, ConditionTreeBranch -from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf -from forestadmin.datasource_toolkit.interfaces.query.filter.factory import FilterFactory, FilterFactoryException -from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter +from forestadmin.datasource_toolkit.interfaces.fields import ( + FieldType, + ManyToMany, + Operator, +) +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ( + Aggregator, + ConditionTreeBranch, +) +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ( + ConditionTreeLeaf, +) +from forestadmin.datasource_toolkit.interfaces.query.filter.factory import ( + FilterFactory, + FilterFactoryException, +) +from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import ( + PaginatedFilter, +) from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter from forestadmin.datasource_toolkit.interfaces.query.projections import Projection @@ -38,15 +52,75 @@ def test_shift_period_filter(mock_time_transform: mock.MagicMock): mock_replacer = mock.MagicMock(return_value="fake_replacer") mock_time_transform.return_value = {Operator.PREVIOUS_YEAR: [{"replacer": mock_replacer}]} with mock.patch( - "forestadmin.datasource_toolkit.interfaces.query.filter.factory.SHIFTED_OPERATORS", {Operator.PREVIOUS_YEAR} + "forestadmin.datasource_toolkit.interfaces.query.filter.factory.SHIFTED_OPERATORS", + {Operator.PREVIOUS_YEAR}, ): assert shift_period_filter_replacer(leaf) == "fake_replacer" mock_time_transform.assert_called_once_with(1) mock_replacer.assert_called_once_with(leaf, "UTC") - with mock.patch("forestadmin.datasource_toolkit.interfaces.query.filter.factory.SHIFTED_OPERATORS", {}): - with pytest.raises(FilterFactoryException): - shift_period_filter_replacer(leaf) + with mock.patch( + "forestadmin.datasource_toolkit.interfaces.query.filter.factory.SHIFTED_OPERATORS", + {}, + ): + assert shift_period_filter_replacer(leaf) == leaf + + +@mock.patch("forestadmin.datasource_toolkit.interfaces.query.filter.factory.time_transforms") +def test_shift_period_filter_with_complex_condition_tree( + mock_time_transform: mock.MagicMock, +): + tz = zoneinfo.ZoneInfo("UTC") + shift_period_filter_replacer = FilterFactory._shift_period_filter(tz) + + leaf_previous_year = ConditionTreeLeaf(field="date_field", operator=Operator.PREVIOUS_YEAR) + leaf_equal = ConditionTreeLeaf(field="name", operator=Operator.EQUAL, value="test") + leaf_greater_than = ConditionTreeLeaf(field="age", operator=Operator.GREATER_THAN, value=18) + leaf_previous_month = ConditionTreeLeaf(field="created_at", operator=Operator.PREVIOUS_MONTH) + + mock_replacer_year = mock.MagicMock( + return_value=ConditionTreeLeaf(field="date_field", operator=Operator.EQUAL, value="replaced_year") + ) + mock_replacer_month = mock.MagicMock( + return_value=ConditionTreeLeaf(field="created_at", operator=Operator.EQUAL, value="replaced_month") + ) + + mock_time_transform.return_value = { + Operator.PREVIOUS_YEAR: [{"replacer": mock_replacer_year}], + Operator.PREVIOUS_MONTH: [{"replacer": mock_replacer_month}], + } + + with mock.patch( + "forestadmin.datasource_toolkit.interfaces.query.filter.factory.SHIFTED_OPERATORS", + {Operator.PREVIOUS_YEAR, Operator.PREVIOUS_MONTH}, + ): + complex_tree = ConditionTreeBranch( + aggregator=Aggregator.AND, + conditions=[ + leaf_previous_year, + leaf_equal, + leaf_greater_than, + leaf_previous_month, + ], + ) + + result = complex_tree.replace(shift_period_filter_replacer) + + mock_replacer_year.assert_called_once_with(leaf_previous_year, tz) + mock_replacer_month.assert_called_once_with(leaf_previous_month, tz) + + assert isinstance(result, ConditionTreeBranch) + assert result.aggregator == Aggregator.AND + assert len(result.conditions) == 4 + + assert result.conditions[0] == ConditionTreeLeaf( + field="date_field", operator=Operator.EQUAL, value="replaced_year" + ) + assert result.conditions[1] == leaf_equal # EQUAL should remain unchanged + assert result.conditions[2] == leaf_greater_than # GREATER_THAN should remain unchanged + assert result.conditions[3] == ConditionTreeLeaf( + field="created_at", operator=Operator.EQUAL, value="replaced_month" + ) @mock.patch("forestadmin.datasource_toolkit.interfaces.query.filter.factory.FilterFactory._shift_period_filter") @@ -106,7 +180,11 @@ async def test_make_through_filter(): mock_collection = mock.MagicMock() mock_collection.list = mock.AsyncMock(return_value=[{"id": 1}]) - with mock.patch.object(collection.datasource, "get_collection", return_value=mock_collection): + with mock.patch.object( + collection.datasource, + "get_collection", + return_value=mock_collection, + ): res = await FilterFactory.make_through_filter( mocked_caller, collection, @@ -182,7 +260,11 @@ async def test_make_through_filter(): Aggregator.AND, conditions=[ ConditionTreeLeaf("child_id", Operator.EQUAL, "fake_value"), - ConditionTreeLeaf("parent_id", Operator.IN, ["fake_record_1", "fake_record_2"]), + ConditionTreeLeaf( + "parent_id", + Operator.IN, + ["fake_record_1", "fake_record_2"], + ), ], ) }