Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 63 additions & 13 deletions freeadmin/contrib/adapters/tortoise/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from tortoise import Tortoise, connections
from tortoise import fields
from tortoise.fields.relational import ManyToManyRelation
from tortoise.exceptions import (
ConfigurationError,
DoesNotExist as TortoiseDoesNotExist,
Expand Down Expand Up @@ -308,7 +309,11 @@ async def create(

This coroutine must be awaited.
"""
include_m2m = list(include_m2m or [])
include_m2m = set(include_m2m or [])
meta = getattr(model_cls, "_meta", None)
for fname in getattr(meta, "m2m_fields", set()):
include_m2m.add(fname)

m2m_values: dict[str, list[Any]] = {}

for fname in include_m2m:
Expand All @@ -327,27 +332,54 @@ async def create(
obj = await model_cls.create(**data)

for fname, values in m2m_values.items():
manager = getattr(obj, fname)
cached_value = getattr(obj, "__dict__", {}).get(fname)
if cached_value is not None and not hasattr(cached_value, "remote_model"):
obj.__dict__.pop(fname, None)
manager_descriptor = getattr(type(obj), fname, None)
manager = (
manager_descriptor.__get__(obj, type(obj))
if hasattr(manager_descriptor, "__get__")
else getattr(obj, fname)
)
if not hasattr(manager, "remote_model"):
field = getattr(obj._meta, "fields_map", {}).get(fname)
if field is not None:
manager = ManyToManyRelation(obj, field)
prefetched_map = getattr(obj, "_prefetched_map", None)
if isinstance(prefetched_map, dict):
prefetched_map[fname] = manager
else:
obj._prefetched_map = {fname: manager}
remote_model = manager.remote_model
pk_attr = self.get_pk_attr(remote_model)

normalized: list[Any] = []
normalized_instances: list[Any] = []
pending_pks: list[Any] = []

for value in values:
if value is None:
continue
if isinstance(value, remote_model):
normalized.append(value)
continue
if isinstance(value, dict) and pk_attr in value:
normalized.append(value[pk_attr])
normalized_instances.append(value)
continue
if isinstance(value, dict):
if pk_attr in value:
pending_pks.append(value[pk_attr])
continue
if "id" in value:
pending_pks.append(value["id"])
continue
if isinstance(value, str) and value.isdigit():
normalized.append(int(value))
pending_pks.append(int(value))
continue
normalized.append(value)
pending_pks.append(value)

if pending_pks:
fetched = await remote_model.filter(**{f"{pk_attr}__in": pending_pks})
normalized_instances.extend(fetched)

if normalized:
await manager.add(*normalized)
if normalized_instances:
await manager.add(*normalized_instances)

return obj

Expand Down Expand Up @@ -440,7 +472,7 @@ async def delete(self, obj: Model) -> None:
"""
await obj.delete()

async def fetch_related(self, obj: Model, *fields: str) -> None:
async def fetch_related(self, obj: Model, *related_fields: str) -> None:
"""Populate related fields on an object.

Args:
Expand All @@ -449,7 +481,25 @@ async def fetch_related(self, obj: Model, *fields: str) -> None:

This coroutine must be awaited.
"""
await obj.fetch_related(*fields)
prefetched_map = getattr(obj, "_prefetched_map", None)
remaining_fields: list[str] = []

for fname in related_fields:
field = getattr(getattr(obj, "_meta", None), "fields_map", {}).get(fname)
if isinstance(field, fields.relational.ManyToManyFieldInstance):
relation = ManyToManyRelation(obj, field)
related = await relation.all()
relation._set_result_for_query(related, None)
if isinstance(prefetched_map, dict):
prefetched_map[fname] = relation
else:
prefetched_map = {fname: relation}
obj._prefetched_map = prefetched_map
continue
remaining_fields.append(fname)

if remaining_fields:
await obj.fetch_related(*remaining_fields)

async def m2m_clear(self, manager) -> None:
"""Clear all links from a many-to-many relation manager.
Expand Down
80 changes: 80 additions & 0 deletions tests/test_tortoise_adapter_m2m_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
"""
Tortoise adapter many-to-many creation tests

Validate that passing many-to-many payloads to the Tortoise adapter does not
clobber relation managers and attaches related objects after creation.

Version:0.1.0
Author: Timur Kady
Email: timurkady@yandex.com
"""

from __future__ import annotations

import asyncio
from typing import ClassVar

import pytest
from tortoise import Tortoise, fields, models

from freeadmin.contrib.adapters.tortoise.adapter import Adapter
from tests.system_models import system_models


class Listener(models.Model):
id = fields.IntField(pk=True)
name = fields.CharField(max_length=50)


class Event(models.Model):
id = fields.IntField(pk=True)
name = fields.CharField(max_length=50)
listeners: fields.ManyToManyRelation[Listener] = fields.ManyToManyField(
"models.Listener", related_name="events"
)


class TestTortoiseAdapterM2MCreation:
"""Ensure the adapter handles many-to-many payloads safely."""

adapter: ClassVar[Adapter]

@classmethod
def setup_class(cls) -> None:
asyncio.run(
Tortoise.init(
db_url="sqlite://:memory:",
modules={
"models": [__name__],
"admin": list(system_models.module_names()),
},
)
)
asyncio.run(Tortoise.generate_schemas())
cls.adapter = Adapter()

@classmethod
def teardown_class(cls) -> None:
asyncio.run(Tortoise.close_connections())

@pytest.mark.asyncio
async def test_many_to_many_payload_does_not_replace_manager(self) -> None:
"""Verify that the adapter keeps the relation manager intact."""

listener = await Listener.create(name="first")
payload = {"name": "event", "listeners": [{"id": listener.id}]}

event = await self.adapter.create(Event, include_m2m=["listeners"], **payload)

await self.adapter.fetch_related(event, "listeners")

relation = fields.relational.ManyToManyRelation(
event, Event._meta.fields_map["listeners"]
)
related = await relation.all()
assert len(related) == 1
assert related[0].id == listener.id


# The End