From 62cd6f5a6b87a415d45b0297ba20e5e79cd20d57 Mon Sep 17 00:00:00 2001
From: Javier Buzzi
Date: Thu, 12 Dec 2024 13:48:11 +0100
Subject: [PATCH] Rebase

---
 factory/builder.py      |  34 ++++--
 factory/django.py       | 186 ++++++++++++++++++++++++++++-
 tests/djapp/models.py   |  45 +++++++
 tests/djapp/settings.py |   3 +-
 tests/test_django.py    | 256 +++++++++++++++++++++++++++++++++++++++-
 tests/test_using.py     |  13 ++
 6 files changed, 515 insertions(+), 22 deletions(-)

diff --git a/factory/builder.py b/factory/builder.py
index e76e7556..8af292f3 100644
--- a/factory/builder.py
+++ b/factory/builder.py
@@ -218,14 +218,18 @@ def chain(self):
             parent_chain = ()
         return (self.stub,) + parent_chain

-    def recurse(self, factory, declarations, force_sequence=None):
+    def recurse(self, factory, declarations, force_sequence=None, collect_instances=None):
         from . import base
         if not issubclass(factory, base.BaseFactory):
             raise errors.AssociatedClassError(
                 "%r: Attempting to recursing into a non-factory object %r"
                 % (self, factory))
         builder = self.builder.recurse(factory._meta, declarations)
-        return builder.build(parent_step=self, force_sequence=force_sequence)
+        return builder.build(
+            parent_step=self,
+            force_sequence=force_sequence,
+            collect_instances=collect_instances,
+        )

     def __repr__(self):
         return f"<BuildStep for {self.builder!r}>"
@@ -246,7 +250,7 @@ def __init__(self, factory_meta, extras, strategy):
         self.extras = extras
         self.force_init_sequence = extras.pop('__sequence', None)

-    def build(self, parent_step=None, force_sequence=None):
+    def build(self, parent_step=None, force_sequence=None, collect_instances=None):
         """Build a factory instance."""
         # TODO: Handle "batch build" natively
         pre, post = parse_declarations(
@@ -277,19 +281,23 @@ def build(self, parent_step=None, force_sequence=None):
             kwargs=kwargs,
         )

-        postgen_results = {}
-        for declaration_name in post.sorted():
-            declaration = post[declaration_name]
-            postgen_results[declaration_name] = declaration.declaration.evaluate_post(
+        if collect_instances is None:
+            postgen_results = {}
+            for declaration_name in post.sorted():
+                declaration = post[declaration_name]
+                postgen_results[declaration_name] = declaration.declaration.evaluate_post(
+                    instance=instance,
+                    step=step,
+                    overrides=declaration.context,
+                )
+            self.factory_meta.use_postgeneration_results(
                 instance=instance,
                 step=step,
-                overrides=declaration.context,
+                results=postgen_results,
             )
-        self.factory_meta.use_postgeneration_results(
-            instance=instance,
-            step=step,
-            results=postgen_results,
-        )
+        else:
+            collect_instances.append(instance)
+
         return instance

     def recurse(self, factory_meta, extras):
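For context, a minimal sketch of how the new `collect_instances` hook on `StepBuilder.build()` is driven (illustrative only, not part of the patch; `SomeDjangoFactory` is a placeholder name):

    collected = []
    step = builder.StepBuilder(SomeDjangoFactory._meta, {}, enums.BUILD_STRATEGY)
    root = step.build(collect_instances=collected)
    # When a list is passed in, build() skips the post-generation handling and
    # appends every built instance (the root plus all SubFactory instances) to
    # `collected`, leaving them unsaved so they can be bulk-inserted later.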
diff --git a/factory/django.py b/factory/django.py
index b53fd5b5..bc482a74 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -9,20 +9,20 @@
 import logging
 import os
 import warnings
+from collections import defaultdict
 from typing import Dict, TypeVar

 from django.contrib.auth.hashers import make_password
 from django.core import files as django_files
-from django.db import IntegrityError
+from django.db import IntegrityError, connections, models
+from django.db.models.sql import InsertQuery

-from . import base, declarations, errors
+from . import base, builder, declarations, enums, errors

 logger = logging.getLogger('factory.generate')

-
 DEFAULT_DB_ALIAS = 'default'  # Same as django.db.DEFAULT_DB_ALIAS

 T = TypeVar("T")

-
 _LAZY_LOADS: Dict[str, object] = {}
@@ -45,11 +45,29 @@ def _lazy_load_get_model():
     _LAZY_LOADS['get_model'] = django_apps.apps.get_model


+def connection_supports_bulk_insert(using):
+    """
+    Does the database support bulk insert?
+
+    There are 2 pieces to this puzzle:
+    * the database needs to support `bulk_create`,
+    * AND it also needs to be able to return the ids of the newly inserted rows.
+
+    If either of these is `False`, the database does NOT support bulk insert.
+    """
+    db_features = connections[using].features
+    return (
+        db_features.has_bulk_insert
+        and db_features.can_return_rows_from_bulk_insert
+    )
+
+
 class DjangoOptions(base.FactoryOptions):
     def _build_default_options(self):
         return super()._build_default_options() + [
             base.OptionDefault('django_get_or_create', (), inherit=True),
             base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True),
+            base.OptionDefault('use_bulk_create', False, inherit=True),
             base.OptionDefault('skip_postgeneration_save', False, inherit=True),
         ]

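A minimal usage sketch for the new Meta option (illustrative, not part of the patch; `AuthorFactory` and `Author` are placeholder names):

    class AuthorFactory(factory.django.DjangoModelFactory):
        class Meta:
            model = Author                  # hypothetical model
            use_bulk_create = True          # opt in to the bulk path
            skip_postgeneration_save = True

    # On a backend where connection_supports_bulk_insert('default') is True,
    # this issues one INSERT for all 10 rows instead of 10 separate INSERTs;
    # otherwise it silently falls back to the regular per-object create().
    authors = AuthorFactory.create_batch(10)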
@@ -165,6 +183,89 @@ def _get_or_create(cls, model_class, *args, **kwargs):
         return instance

+    @classmethod
+    def supports_bulk_insert(cls):
+        return (cls._meta.use_bulk_create
+                and connection_supports_bulk_insert(cls._meta.database))
+
+    @classmethod
+    def create(cls, **kwargs):
+        """Create an instance of the associated class, with overridden attrs."""
+        if not cls.supports_bulk_insert():
+            return super().create(**kwargs)
+
+        return cls._bulk_create(1, **kwargs)[0]
+
+    @classmethod
+    def create_batch(cls, size, **kwargs):
+        if not cls.supports_bulk_insert():
+            return super().create_batch(size, **kwargs)
+
+        return cls._bulk_create(size, **kwargs)
+
+    @classmethod
+    def _refresh_database_pks(cls, model_cls, objs):
+        # Imported here to avoid raising django.core.exceptions.AppRegistryNotReady throughout the tests.
+        # TODO: remove the `from . import django` from the `__init__.py`
+        from django.contrib.contenttypes.fields import GenericForeignKey
+
+        def get_field_value(instance, field):
+            if isinstance(field, GenericForeignKey) and field.is_cached(instance):
+                return field.get_cached_value(instance)
+            return getattr(instance, field.name)
+
+        # The current Django version's GenericForeignKey is not made to work with bulk inserts.
+        #
+        # The issue is that it caches the referenced object; once that object is
+        # saved and receives a pk, the cached pk no longer matches, even though
+        # it is the same object reference. Re-assigning the field below bypasses
+        # that stale pk check and refreshes the stored value.
+        fields_to_reset = (GenericForeignKey, models.OneToOneField)
+
+        fields = [f for f in model_cls._meta.get_fields() if isinstance(f, fields_to_reset)]
+        if not fields:
+            return
+
+        for obj in objs:
+            for field in fields:
+                setattr(obj, field.name, get_field_value(obj, field))
+
+    @classmethod
+    def _bulk_create(cls, size, **kwargs):
+        if cls._meta.abstract:
+            raise errors.FactoryError(
+                "Cannot generate instances of abstract factory %(f)s; "
+                "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
+                "is either not set or False."
+                % dict(f=cls.__name__))
+
+        models_to_return = []
+        instances = []
+        for _ in range(size):
+            step = builder.StepBuilder(cls._meta, kwargs, enums.BUILD_STRATEGY)
+            models_to_return.append(step.build(collect_instances=instances))
+
+        for model_cls, objs in dependency_insert_order(instances):
+            manager = cls._get_manager(model_cls)
+            cls._refresh_database_pks(model_cls, objs)
+
+            concrete_model = True
+            for parent in model_cls._meta.get_parent_list():
+                if parent._meta.concrete_model is not model_cls._meta.concrete_model:
+                    concrete_model = False
+
+            if concrete_model:
+                manager.bulk_create(objs)
+            else:
+                concrete_fields = model_cls._meta.local_fields
+                connection = connections[cls._meta.database]
+
+                # Avoids having to write the INSERT INTO SQL statement by hand
+                query = InsertQuery(model_cls)
+                query.insert_values(concrete_fields, objs)
+                query.get_compiler(connection=connection).execute_sql()
+
+        return models_to_return
+
     @classmethod
     def _create(cls, model_class, *args, **kwargs):
         """Create an instance of the model, and save it to the database."""
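Why `_refresh_database_pks` re-assigns those fields, as a small sketch (illustrative, using the test models added further below):

    p = models.P()               # unsaved, p.pk is None
    a = models.A(p_o=p, p_f=p)   # a.p_o_id is captured while p.pk is still None
    # ... the P rows are bulk_created, p.pk is now set ...
    a.p_o = a.p_o                # re-assigning refreshes a.p_o_id from p.pk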
@@ -272,6 +373,82 @@ def _make_data(self, params):
         return thumb_io.getvalue()


+def dependency_insert_order(data):
+    """Almost the same as django/core/serializers/__init__.py:sort_dependencies, with one change:
+    the `if hasattr(rel_model, 'natural_key') and rel_model != model:` check was relaxed, so we get the
+    REAL dependency order. The original implementation only ordered models that define a natural_key;
+    here every FK dependency is taken into account, regardless of natural keys.
+    """
+
+    lookup = []
+    model_cls_by_data = defaultdict(list)
+    for instance in data:
+        # Instance has been persisted in the database
+        if not instance._state.adding:
+            continue
+        # Instance already in the list
+        if id(instance) in lookup:
+            continue
+        lookup.append(id(instance))
+        model_cls_by_data[type(instance)].append(instance)
+
+    # Avoid data leaks
+    del lookup
+    del data
+
+    # Process the list of models, and get the list of dependencies
+    model_dependencies = []
+    models = list(model_cls_by_data.keys())
+
+    for model in models:
+        deps = set()
+
+        # Add a dependency for any FK relation pointing at another model
+        # (natural keys are irrelevant here, unlike in Django's original)
+        for field in model._meta.fields:
+            rel_model = field.related_model
+            if rel_model and rel_model != model:
+                deps.add(rel_model)
+
+        model_dependencies.append((model, deps))
+
+    model_dependencies.reverse()
+    # Now sort the models to ensure that dependencies are met. This
+    # is done by repeatedly iterating over the input list of models.
+    # If all the dependencies of a given model are in the final list,
+    # that model is promoted to the end of the final list. This process
+    # continues until the input list is empty, or we do a full iteration
+    # over the input models without promoting a model to the final list.
+    # If we do a full iteration without a promotion, that means there are
+    # circular dependencies in the list.
+    model_list = []
+    while model_dependencies:
+        skipped = []
+        changed = False
+        while model_dependencies:
+            model, deps = model_dependencies.pop()
+
+            # If all of the models in the dependency list are either already
+            # on the final model list, or not on the original serialization list,
+            # then we've found another model with all its dependencies satisfied.
+            found = True
+            for candidate in ((d not in models or d in model_list) for d in deps):
+                if not candidate:
+                    found = False
+            if found:
+                model_list.append(model)
+                changed = True
+            else:
+                skipped.append((model, deps))
+        if not changed:
+            unresolved_models = (f'{model._meta.app_label}.{model._meta.object_name}'
+                                 for model, _ in sorted(skipped, key=lambda obj: obj[0].__name__))
+            message = f"Can't resolve dependencies for {', '.join(unresolved_models)}."
+            raise RuntimeError(message)
+        model_dependencies = skipped
+
+    return [(model_cls, model_cls_by_data[model_cls]) for model_cls in model_list]
+
+
 class mute_signals:
     """Temporarily disables and then restores any django signals.

@@ -327,6 +504,7 @@ def __call__(self, callable_obj):
         if isinstance(callable_obj, base.FactoryMetaClass):
             # Retrieve __func__, the *actual* callable object.
             callable_obj._create = self.wrap_method(callable_obj._create.__func__)
+            callable_obj._bulk_create = self.wrap_method(callable_obj._bulk_create.__func__)
             callable_obj._generate = self.wrap_method(callable_obj._generate.__func__)
             callable_obj._after_postgeneration = self.wrap_method(
                 callable_obj._after_postgeneration.__func__
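A small sketch of what `dependency_insert_order` returns (illustrative, using the test models added below):

    p = models.P()        # unsaved
    r = models.R(p=p)     # unsaved, references p
    s = models.S(r=r)     # unsaved, references r
    # Parents come before the rows that reference them:
    # dependency_insert_order([s, r, p]) == [(models.P, [p]), (models.R, [r]), (models.S, [s])]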
diff --git a/tests/djapp/models.py b/tests/djapp/models.py
index b7aa8794..516efdf1 100644
--- a/tests/djapp/models.py
+++ b/tests/djapp/models.py
@@ -5,6 +5,8 @@
 import os.path

 from django.conf import settings
+from django.contrib.contenttypes.fields import GenericForeignKey
+from django.contrib.contenttypes.models import ContentType
 from django.db import models
 from django.db.models import signals

@@ -137,3 +139,46 @@ class FromAbstractWithCustomManager(AbstractWithCustomManager):

 class HasMultifieldModel(models.Model):
     multifield = models.ForeignKey(to=MultifieldModel, on_delete=models.CASCADE)
+
+
+class P(models.Model):
+    pass
+
+
+class R(models.Model):
+    is_default = models.BooleanField(default=False)
+    p = models.ForeignKey(P, models.CASCADE, null=True)
+
+
+class S(models.Model):
+    r = models.ForeignKey(R, models.CASCADE)
+
+
+class T(models.Model):
+    s = models.ForeignKey(S, models.CASCADE)
+
+
+class U(models.Model):
+    t = models.ForeignKey(T, models.CASCADE)
+
+
+class RChild(R):
+    text = models.CharField(max_length=10)
+
+
+class A(models.Model):
+    p_o = models.OneToOneField('P', models.CASCADE, related_name="+")
+    p_f = models.ForeignKey('P', models.CASCADE, related_name="+")
+    p_m = models.ManyToManyField('P')
+
+
+class AA(models.Model):
+    a = models.OneToOneField(A, models.CASCADE)
+    u = models.OneToOneField(U, models.CASCADE)
+    p = models.OneToOneField(P, models.CASCADE)
+
+
+class GenericModel(models.Model):
+    content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
+    object_id = models.PositiveIntegerField()
+    generic_obj = GenericForeignKey("content_type", "object_id")
diff --git a/tests/djapp/settings.py b/tests/djapp/settings.py
index 13f7d366..076c393d 100644
--- a/tests/djapp/settings.py
+++ b/tests/djapp/settings.py
@@ -23,7 +23,8 @@

 INSTALLED_APPS = [
-    'tests.djapp'
+    'django.contrib.contenttypes',
+    'tests.djapp',
 ]

 MIDDLEWARE_CLASSES = ()
diff --git a/tests/test_django.py b/tests/test_django.py
index 066d7920..a9c891c8 100644
--- a/tests/test_django.py
+++ b/tests/test_django.py
@@ -2,9 +2,12 @@

 """Tests for factory_boy/Django interactions."""

+import inspect
 import io
 import os
+import tempfile
 import unittest
+from contextlib import ExitStack
 from unittest import mock

 try:
@@ -13,9 +16,11 @@
     raise unittest.SkipTest("django tests disabled.")

 from django import test as django_test
+from django.apps import apps
 from django.conf import settings
 from django.contrib.auth.hashers import check_password
-from django.core.management import color
+from django.core.management import call_command, color
+from django.core.management.commands.migrate import Command as MigrateCommand
 from django.db import IntegrityError, connections
 from django.db.models import signals
 from django.test import utils as django_test_utils
@@ -37,17 +42,41 @@
 from .djapp import models  # noqa:E402 isort:skip

 test_state = {}
+cleanup = ExitStack()


 def setUpModule():
-    django_test_utils.setup_test_environment()
-    runner_state = django_test_utils.setup_databases(verbosity=0, interactive=False)
-    test_state['runner_state'] = runner_state
+    project_path = os.path.abspath(os.curdir)
+
+    for app_config in apps.get_app_configs():
+        module = app_config.module
+        app_path = os.path.dirname(os.path.abspath(inspect.getsourcefile(module)))
+
+        if app_path.startswith(project_path):
+            temp_dir = cleanup.enter_context(tempfile.TemporaryDirectory(prefix='migrations_', dir=app_path))
+            # Make this directory a proper python package, otherwise django will refuse to recognize it
+            open(os.path.join(temp_dir, '__init__.py'), 'a').close()
+            settings.MIGRATION_MODULES[app_config.label] = '%s.%s' % (app_config.module.__name__,
+                                                                      os.path.basename(temp_dir))
+
+    def WrappedMigrateCommand(*args, **kwargs):
+        """
+        Because django's `contenttypes` app is installed, migrations for the test apps have to be generated on the fly before `migrate` can run properly.
+        """
+        for app in settings.MIGRATION_MODULES:
+            call_command('makemigrations', name=app, verbosity=0)
+        return MigrateCommand(*args, **kwargs)
+
+    with mock.patch('django.core.management.commands.migrate.Command', wraps=WrappedMigrateCommand):
+        django_test_utils.setup_test_environment()
+        runner_state = django_test_utils.setup_databases(verbosity=0, interactive=False)
+        test_state['runner_state'] = runner_state


 def tearDownModule():
     django_test_utils.teardown_databases(test_state['runner_state'], verbosity=0)
     django_test_utils.teardown_test_environment()
+    cleanup.close()


 class StandardFactory(factory.django.DjangoModelFactory):
@@ -150,6 +179,225 @@ class Meta:
     text = factory.Sequence(lambda n: "text%s" % n)


+class PFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.P
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+
+class RFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.R
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+    is_default = True
+    p = factory.SubFactory(PFactory)
+
+
+class RChildFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.RChild
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+    text = 'test'
+    r_ptr = factory.SubFactory(RFactory)
+
+
+class SFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.S
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+    r = factory.SubFactory(RFactory)
+
+
+class TFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.T
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+    s = factory.SubFactory(SFactory)
+
+
+class UFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.U
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+    t = factory.SubFactory(TFactory)
+
+
+class APFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.A.p_m.through
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+    a = factory.SubFactory('tests.test_django.AFactory')
+    p = factory.SubFactory(PFactory)
+
+
+class AFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.A
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+    p_o = factory.SubFactory(PFactory)
+    p_f = factory.SubFactory(PFactory)
+
+
+class AWithMFactory(AFactory):
+    p_m = factory.RelatedFactoryList(APFactory, factory_related_name='a')
+
+
+class AAFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = models.AA
+        use_bulk_create = True
+        skip_postgeneration_save = True
+
+    a = factory.SubFactory(AWithMFactory)
+    u = factory.SubFactory(UFactory)
+    p = factory.SubFactory(PFactory)
+
+
+def lazy_content_type(o):
+    from django.contrib.contenttypes.models import ContentType
+    return ContentType.objects.get_for_model(o.generic_obj)
+
+
+class GenericModelFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        # exclude = ['generic_obj']
+        abstract = True
+        skip_postgeneration_save = True
+
+    object_id = factory.SelfAttribute('generic_obj.id')
+    content_type = factory.LazyAttribute(lazy_content_type)
+
+
+class GenericPFactory(GenericModelFactory):
+    generic_obj = factory.SubFactory(PFactory)
+
+    class Meta:
+        use_bulk_create = True
+        model = models.GenericModel
+        skip_postgeneration_save = True
+
+
+class DependencyInsertOrderTest(django_test.TestCase):
+
+    def test_empty(self):
+        actual = factory.django.dependency_insert_order([])
+        self.assertEqual(actual, [])
+
+    def test_sub_create(self):
+        p1 = models.P()
+        p2 = models.P()
+        r1 = models.R(p=p1)
+        r2 = models.R(p=p2)
+        actual = factory.django.dependency_insert_order([r1, r2, p1, p2])
+        self.assertEqual(actual, [(models.P, [p1, p2]), (models.R, [r1, r2])])
+
+    def test_sub_already_created(self):
+        p1 = PFactory()
+        p2 = models.P()
+        r1 = models.R(p=p1)
+        r2 = models.R(p=p2)
+        r3 = RFactory()
+        actual = factory.django.dependency_insert_order([p1, p2, r1, r2, r3])
+
+        # Note that `p1` is ignored completely since it was created already
+        # Note that `r3` along with `r3.p` is ignored completely since it was created already
+        self.assertEqual(actual, [(models.P, [p2]), (models.R, [r1, r2])])
+
+    def test_new_m2m(self):
+        step = factory.builder.StepBuilder(AWithMFactory._meta, {}, factory.enums.BUILD_STRATEGY)
+        created_instances = []
+        a1 = step.build(collect_instances=created_instances)
+        p1 = a1.p_o
+        p2 = a1.p_f
+        p_m1, p_m2 = [x for x in created_instances if isinstance(x, models.A.p_m.through)]
+        p3 = p_m1.p
+        p4 = p_m2.p
+        actual = factory.django.dependency_insert_order(created_instances)
+        self.assertEqual(actual, [(models.P, [p1, p2, p3, p4]),
+                                  (models.A, [a1]),
+                                  (models.A.p_m.through, [p_m1, p_m2])])
+
+
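The expected query counts asserted in the DjangoBulkInsertTest cases below follow from simple arithmetic; for example (illustrative, not part of the patch):

    # RFactory.create_batch(10)
    #   bulk path:     1 bulk INSERT for the 10 P rows + 1 bulk INSERT for the 10 R rows =  2 queries
    #   fallback path: 10 x (1 INSERT for the P row + 1 INSERT for the R row)            = 20 queries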
+class DjangoBulkInsertTest(django_test.TestCase):
+    SUPPORTS_BULK_INSERT = factory.django.connection_supports_bulk_insert(
+        factory.django.DEFAULT_DB_ALIAS
+    )
+
+    def test_single_object_create(self):
+        EXPECTED_QUERIES = 1 if self.SUPPORTS_BULK_INSERT else 1
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            PFactory()
+
+    def test_single_object_create_batch(self):
+        EXPECTED_QUERIES = 1 if self.SUPPORTS_BULK_INSERT else 10
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            PFactory.create_batch(10)
+
+    def test_one_level_nested_single_object_create(self):
+        EXPECTED_QUERIES = 2 if self.SUPPORTS_BULK_INSERT else 2
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            RFactory()
+
+        existing_p = PFactory()
+        EXPECTED_QUERIES = 1 if self.SUPPORTS_BULK_INSERT else 1
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            RFactory(p=existing_p)
+
+    def test_one_level_nested_single_object_create_batch(self):
+        EXPECTED_QUERIES = 2 if self.SUPPORTS_BULK_INSERT else 20
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            RFactory.create_batch(10)
+
+        existing_p = PFactory()
+        EXPECTED_QUERIES = 1 if self.SUPPORTS_BULK_INSERT else 10
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            RFactory.create_batch(10, p=existing_p)
+
+    def test_one_level_nested_m2m_create_batch(self):
+        EXPECTED_QUERIES = 3 if self.SUPPORTS_BULK_INSERT else 70
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            AWithMFactory.create_batch(10)
+
+        existing_p = PFactory()
+        EXPECTED_QUERIES = 3 if self.SUPPORTS_BULK_INSERT else 60
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            AWithMFactory.create_batch(10, p_f=existing_p)
+
+    def test_multi_level_nested_m2m_create_batch(self):
+        EXPECTED_QUERIES = 8 if self.SUPPORTS_BULK_INSERT else 140
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            AAFactory.create_batch(10)
+
+    def test_single_generic(self):
+        EXPECTED_QUERIES = 3 if self.SUPPORTS_BULK_INSERT else 3
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            GenericPFactory()
+
+    def test_multi_table_inherited_model(self):
+        EXPECTED_QUERIES = 3 if self.SUPPORTS_BULK_INSERT else 4
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            RChildFactory()
+
+        EXPECTED_QUERIES = 3 if self.SUPPORTS_BULK_INSERT else 40
+        with self.assertNumQueries(EXPECTED_QUERIES):
+            RChildFactory.create_batch(10)
+
+
 class ModelTests(django_test.TestCase):
     databases = {'default', 'replica'}
diff --git a/tests/test_using.py b/tests/test_using.py
index 5b2200a6..23b31f34 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -78,6 +78,9 @@ def create(self, **kwargs):
             instance._defaults = None
         return instance

+    def bulk_create(self, objs, **kwargs):
+        return objs
+
     def values_list(self, *args, **kwargs):
         return self

@@ -87,6 +90,16 @@ def order_by(self, *args, **kwargs):
         return self

     def using(self, db):
         return self

+    class _meta:
+        concrete_model = None
+
+        @staticmethod
+        def get_fields(*args, **kwargs):
+            return []
+
+    class _state:
+        adding = True
+
     objects = FakeModelManager()

     def __init__(self, **kwargs):