From 26f3ca1efbd92d9f7df90c20470e7ac5dc0afe12 Mon Sep 17 00:00:00 2001 From: Timur Kady Date: Thu, 27 Nov 2025 22:04:39 +0300 Subject: [PATCH] Handle many-to-many adapters with reserved attribute names --- .../contrib/adapters/tortoise/adapter.py | 123 +++++++++++++++++- tests/test_tortoise_adapter_m2m_creation.py | 80 ++++++++++++ 2 files changed, 198 insertions(+), 5 deletions(-) create mode 100644 tests/test_tortoise_adapter_m2m_creation.py diff --git a/freeadmin/contrib/adapters/tortoise/adapter.py b/freeadmin/contrib/adapters/tortoise/adapter.py index f7f1c0d..791bb6f 100644 --- a/freeadmin/contrib/adapters/tortoise/adapter.py +++ b/freeadmin/contrib/adapters/tortoise/adapter.py @@ -20,6 +20,7 @@ from tortoise import Tortoise, connections from tortoise import fields +from tortoise.fields.relational import ManyToManyRelation from tortoise.exceptions import ( ConfigurationError, DoesNotExist as TortoiseDoesNotExist, @@ -103,6 +104,19 @@ def _register_admin_models(self) -> None: self.system_setting_model = SystemSetting self.setting_value_type = SettingValueType + def _normalize_relation_value(self, field: Any, value: Any) -> Any: + """Return a primary-key friendly representation for relation values.""" + related_model = getattr(field, "related_model", None) + pk_attr = self.get_pk_attr(related_model) if related_model else "id" + if isinstance(value, dict): + if pk_attr in value: + return value[pk_attr] + if "id" in value: + return value["id"] + if isinstance(value, str) and value.isdigit(): + return int(value) + return value + def normalize_import_data(self, model: type[Model], data: dict[str, Any]) -> dict[str, Any]: """Convert raw import values into ORM-friendly types.""" meta = getattr(model, "_meta", None) @@ -121,10 +135,11 @@ def normalize_import_data(self, model: type[Model], data: dict[str, Any]) -> dic fields.relational.OneToOneFieldInstance, ), ): + normalized_value = self._normalize_relation_value(field, value) if getattr(value, "_saved_in_db", False): cleaned[name] = value else: - cleaned[f"{name}_id"] = value + cleaned[f"{name}_id"] = normalized_value continue if getattr(field, "enum_type", None) and isinstance(value, str): if value.isdigit(): @@ -273,11 +288,20 @@ def in_transaction(self): conn_name = self._resolve_connection_name() return in_transaction(conn_name) - async def create(self, model_cls: type[Model], **data: Any) -> Model: + async def create( + self, + model_cls: type[Model], + *, + include_m2m: Iterable[str] | None = None, + **data: Any, + ) -> Model: """Create and persist a model instance. Args: model_cls: Model class to instantiate. + include_m2m: Iterable of many-to-many field names whose values are + provided in ``data`` and should be assigned after instance + creation. **data: Field values for the new record. Returns: @@ -285,8 +309,79 @@ async def create(self, model_cls: type[Model], **data: Any) -> Model: This coroutine must be awaited. """ + include_m2m = set(include_m2m or []) + meta = getattr(model_cls, "_meta", None) + for fname in getattr(meta, "m2m_fields", set()): + include_m2m.add(fname) + + m2m_values: dict[str, list[Any]] = {} + + for fname in include_m2m: + if fname not in data: + continue + value = data.pop(fname) + if value is None: + m2m_values[fname] = [] + continue + if isinstance(value, (list, tuple, set)): + m2m_values[fname] = list(value) + else: + m2m_values[fname] = [value] + data = self.normalize_import_data(model_cls, data) - return await model_cls.create(**data) + obj = await model_cls.create(**data) + + for fname, values in m2m_values.items(): + cached_value = getattr(obj, "__dict__", {}).get(fname) + if cached_value is not None and not hasattr(cached_value, "remote_model"): + obj.__dict__.pop(fname, None) + manager_descriptor = getattr(type(obj), fname, None) + manager = ( + manager_descriptor.__get__(obj, type(obj)) + if hasattr(manager_descriptor, "__get__") + else getattr(obj, fname) + ) + if not hasattr(manager, "remote_model"): + field = getattr(obj._meta, "fields_map", {}).get(fname) + if field is not None: + manager = ManyToManyRelation(obj, field) + prefetched_map = getattr(obj, "_prefetched_map", None) + if isinstance(prefetched_map, dict): + prefetched_map[fname] = manager + else: + obj._prefetched_map = {fname: manager} + remote_model = manager.remote_model + pk_attr = self.get_pk_attr(remote_model) + + normalized_instances: list[Any] = [] + pending_pks: list[Any] = [] + + for value in values: + if value is None: + continue + if isinstance(value, remote_model): + normalized_instances.append(value) + continue + if isinstance(value, dict): + if pk_attr in value: + pending_pks.append(value[pk_attr]) + continue + if "id" in value: + pending_pks.append(value["id"]) + continue + if isinstance(value, str) and value.isdigit(): + pending_pks.append(int(value)) + continue + pending_pks.append(value) + + if pending_pks: + fetched = await remote_model.filter(**{f"{pk_attr}__in": pending_pks}) + normalized_instances.extend(fetched) + + if normalized_instances: + await manager.add(*normalized_instances) + + return obj async def get( self, @@ -377,7 +472,7 @@ async def delete(self, obj: Model) -> None: """ await obj.delete() - async def fetch_related(self, obj: Model, *fields: str) -> None: + async def fetch_related(self, obj: Model, *related_fields: str) -> None: """Populate related fields on an object. Args: @@ -386,7 +481,25 @@ async def fetch_related(self, obj: Model, *fields: str) -> None: This coroutine must be awaited. """ - await obj.fetch_related(*fields) + prefetched_map = getattr(obj, "_prefetched_map", None) + remaining_fields: list[str] = [] + + for fname in related_fields: + field = getattr(getattr(obj, "_meta", None), "fields_map", {}).get(fname) + if isinstance(field, fields.relational.ManyToManyFieldInstance): + relation = ManyToManyRelation(obj, field) + related = await relation.all() + relation._set_result_for_query(related, None) + if isinstance(prefetched_map, dict): + prefetched_map[fname] = relation + else: + prefetched_map = {fname: relation} + obj._prefetched_map = prefetched_map + continue + remaining_fields.append(fname) + + if remaining_fields: + await obj.fetch_related(*remaining_fields) async def m2m_clear(self, manager) -> None: """Clear all links from a many-to-many relation manager. diff --git a/tests/test_tortoise_adapter_m2m_creation.py b/tests/test_tortoise_adapter_m2m_creation.py new file mode 100644 index 0000000..d3c9959 --- /dev/null +++ b/tests/test_tortoise_adapter_m2m_creation.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +""" +Tortoise adapter many-to-many creation tests + +Validate that passing many-to-many payloads to the Tortoise adapter does not +clobber relation managers and attaches related objects after creation. + +Version:0.1.0 +Author: Timur Kady +Email: timurkady@yandex.com +""" + +from __future__ import annotations + +import asyncio +from typing import ClassVar + +import pytest +from tortoise import Tortoise, fields, models + +from freeadmin.contrib.adapters.tortoise.adapter import Adapter +from tests.system_models import system_models + + +class Listener(models.Model): + id = fields.IntField(pk=True) + name = fields.CharField(max_length=50) + + +class Event(models.Model): + id = fields.IntField(pk=True) + name = fields.CharField(max_length=50) + listeners: fields.ManyToManyRelation[Listener] = fields.ManyToManyField( + "models.Listener", related_name="events" + ) + + +class TestTortoiseAdapterM2MCreation: + """Ensure the adapter handles many-to-many payloads safely.""" + + adapter: ClassVar[Adapter] + + @classmethod + def setup_class(cls) -> None: + asyncio.run( + Tortoise.init( + db_url="sqlite://:memory:", + modules={ + "models": [__name__], + "admin": list(system_models.module_names()), + }, + ) + ) + asyncio.run(Tortoise.generate_schemas()) + cls.adapter = Adapter() + + @classmethod + def teardown_class(cls) -> None: + asyncio.run(Tortoise.close_connections()) + + @pytest.mark.asyncio + async def test_many_to_many_payload_does_not_replace_manager(self) -> None: + """Verify that the adapter keeps the relation manager intact.""" + + listener = await Listener.create(name="first") + payload = {"name": "event", "listeners": [{"id": listener.id}]} + + event = await self.adapter.create(Event, include_m2m=["listeners"], **payload) + + await self.adapter.fetch_related(event, "listeners") + + relation = fields.relational.ManyToManyRelation( + event, Event._meta.fields_map["listeners"] + ) + related = await relation.all() + assert len(related) == 1 + assert related[0].id == listener.id + + +# The End