diff --git a/freeadmin/contrib/adapters/tortoise/adapter.py b/freeadmin/contrib/adapters/tortoise/adapter.py index ef946c6..791bb6f 100644 --- a/freeadmin/contrib/adapters/tortoise/adapter.py +++ b/freeadmin/contrib/adapters/tortoise/adapter.py @@ -20,6 +20,7 @@ from tortoise import Tortoise, connections from tortoise import fields +from tortoise.fields.relational import ManyToManyRelation from tortoise.exceptions import ( ConfigurationError, DoesNotExist as TortoiseDoesNotExist, @@ -308,7 +309,11 @@ async def create( This coroutine must be awaited. """ - include_m2m = list(include_m2m or []) + include_m2m = set(include_m2m or []) + meta = getattr(model_cls, "_meta", None) + for fname in getattr(meta, "m2m_fields", set()): + include_m2m.add(fname) + m2m_values: dict[str, list[Any]] = {} for fname in include_m2m: @@ -327,27 +332,54 @@ async def create( obj = await model_cls.create(**data) for fname, values in m2m_values.items(): - manager = getattr(obj, fname) + cached_value = getattr(obj, "__dict__", {}).get(fname) + if cached_value is not None and not hasattr(cached_value, "remote_model"): + obj.__dict__.pop(fname, None) + manager_descriptor = getattr(type(obj), fname, None) + manager = ( + manager_descriptor.__get__(obj, type(obj)) + if hasattr(manager_descriptor, "__get__") + else getattr(obj, fname) + ) + if not hasattr(manager, "remote_model"): + field = getattr(obj._meta, "fields_map", {}).get(fname) + if field is not None: + manager = ManyToManyRelation(obj, field) + prefetched_map = getattr(obj, "_prefetched_map", None) + if isinstance(prefetched_map, dict): + prefetched_map[fname] = manager + else: + obj._prefetched_map = {fname: manager} remote_model = manager.remote_model pk_attr = self.get_pk_attr(remote_model) - normalized: list[Any] = [] + normalized_instances: list[Any] = [] + pending_pks: list[Any] = [] + for value in values: if value is None: continue if isinstance(value, remote_model): - normalized.append(value) - continue - if isinstance(value, dict) and pk_attr in value: - normalized.append(value[pk_attr]) + normalized_instances.append(value) continue + if isinstance(value, dict): + if pk_attr in value: + pending_pks.append(value[pk_attr]) + continue + if "id" in value: + pending_pks.append(value["id"]) + continue if isinstance(value, str) and value.isdigit(): - normalized.append(int(value)) + pending_pks.append(int(value)) continue - normalized.append(value) + pending_pks.append(value) + + if pending_pks: + fetched = await remote_model.filter(**{f"{pk_attr}__in": pending_pks}) + normalized_instances.extend(fetched) - if normalized: - await manager.add(*normalized) + if normalized_instances: + await manager.add(*normalized_instances) return obj @@ -440,7 +472,7 @@ async def delete(self, obj: Model) -> None: """ await obj.delete() - async def fetch_related(self, obj: Model, *fields: str) -> None: + async def fetch_related(self, obj: Model, *related_fields: str) -> None: """Populate related fields on an object. Args: @@ -449,7 +481,25 @@ async def fetch_related(self, obj: Model, *fields: str) -> None: This coroutine must be awaited. """ - await obj.fetch_related(*fields) + prefetched_map = getattr(obj, "_prefetched_map", None) + remaining_fields: list[str] = [] + + for fname in related_fields: + field = getattr(getattr(obj, "_meta", None), "fields_map", {}).get(fname) + if isinstance(field, fields.relational.ManyToManyFieldInstance): + relation = ManyToManyRelation(obj, field) + related = await relation.all() + relation._set_result_for_query(related, None) + if isinstance(prefetched_map, dict): + prefetched_map[fname] = relation + else: + prefetched_map = {fname: relation} + obj._prefetched_map = prefetched_map + continue + remaining_fields.append(fname) + + if remaining_fields: + await obj.fetch_related(*remaining_fields) async def m2m_clear(self, manager) -> None: """Clear all links from a many-to-many relation manager. diff --git a/tests/test_tortoise_adapter_m2m_creation.py b/tests/test_tortoise_adapter_m2m_creation.py new file mode 100644 index 0000000..d3c9959 --- /dev/null +++ b/tests/test_tortoise_adapter_m2m_creation.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +""" +Tortoise adapter many-to-many creation tests + +Validate that passing many-to-many payloads to the Tortoise adapter does not +clobber relation managers and attaches related objects after creation. + +Version:0.1.0 +Author: Timur Kady +Email: timurkady@yandex.com +""" + +from __future__ import annotations + +import asyncio +from typing import ClassVar + +import pytest +from tortoise import Tortoise, fields, models + +from freeadmin.contrib.adapters.tortoise.adapter import Adapter +from tests.system_models import system_models + + +class Listener(models.Model): + id = fields.IntField(pk=True) + name = fields.CharField(max_length=50) + + +class Event(models.Model): + id = fields.IntField(pk=True) + name = fields.CharField(max_length=50) + listeners: fields.ManyToManyRelation[Listener] = fields.ManyToManyField( + "models.Listener", related_name="events" + ) + + +class TestTortoiseAdapterM2MCreation: + """Ensure the adapter handles many-to-many payloads safely.""" + + adapter: ClassVar[Adapter] + + @classmethod + def setup_class(cls) -> None: + asyncio.run( + Tortoise.init( + db_url="sqlite://:memory:", + modules={ + "models": [__name__], + "admin": list(system_models.module_names()), + }, + ) + ) + asyncio.run(Tortoise.generate_schemas()) + cls.adapter = Adapter() + + @classmethod + def teardown_class(cls) -> None: + asyncio.run(Tortoise.close_connections()) + + @pytest.mark.asyncio + async def test_many_to_many_payload_does_not_replace_manager(self) -> None: + """Verify that the adapter keeps the relation manager intact.""" + + listener = await Listener.create(name="first") + payload = {"name": "event", "listeners": [{"id": listener.id}]} + + event = await self.adapter.create(Event, include_m2m=["listeners"], **payload) + + await self.adapter.fetch_related(event, "listeners") + + relation = fields.relational.ManyToManyRelation( + event, Event._meta.fields_map["listeners"] + ) + related = await relation.all() + assert len(related) == 1 + assert related[0].id == listener.id + + +# The End