From a1ccad41cb45ab69df3b984e3f8812cdb839cc3d Mon Sep 17 00:00:00 2001 From: "rodrigo.nogueira" Date: Fri, 26 Dec 2025 15:45:52 -0300 Subject: [PATCH] feat: Add Excluding declaration and Faker unique parameter - Add Excluding declaration that wraps base declarations and excludes specified values with retry logic - Add unique parameter to Faker for generating unique values - Supports static exclusion lists and dynamic exclusions via SelfAttribute - Implements recursive _make_hashable for nested structures - Plain value exclusion raises immediately without retrying - Comprehensive test coverage with 15 new tests --- factory/__init__.py | 3 +- factory/declarations.py | 235 +++++++++++++++++++++++++++++++++------- factory/faker.py | 27 +++-- tests/test_excluding.py | 200 ++++++++++++++++++++++++++++++++++ tests/test_faker.py | 166 ++++++++++++++++++++++++---- 5 files changed, 562 insertions(+), 69 deletions(-) create mode 100644 tests/test_excluding.py diff --git a/factory/__init__.py b/factory/__init__.py index 62042a2a..cc9d138e 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -14,6 +14,7 @@ from .declarations import ( ContainerAttribute, Dict, + Excluding, Iterator, LazyAttribute, LazyAttributeSequence, @@ -71,5 +72,5 @@ except ImportError: pass -__author__ = 'Raphaël Barrois ' +__author__ = "Raphaël Barrois " __version__ = importlib.metadata.version("factory_boy") diff --git a/factory/declarations.py b/factory/declarations.py index f835f0d2..01056d8b 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -7,7 +7,7 @@ from . import enums, errors, utils -logger = logging.getLogger('factory.generate') +logger = logging.getLogger("factory.generate") class BaseDeclaration(utils.OrderedBase): @@ -45,6 +45,7 @@ def unroll_context(self, instance, step, context): return full_context import factory.base + subfactory = factory.base.DictFactory return step.recurse(subfactory, full_context, force_sequence=step.sequence) @@ -76,7 +77,7 @@ def evaluate(self, instance, step, extra): extra (dict): additional, call-time added kwargs for the step. """ - raise NotImplementedError('This is an abstract method') + raise NotImplementedError("This is an abstract method") class OrderedDeclaration(BaseDeclaration): @@ -130,11 +131,12 @@ class Force: The forced value can be any declaration, and will be evaluated as if it had been passed instead of the Transformer declaration. """ + def __init__(self, forced_value): self.forced_value = forced_value def __repr__(self): - return f'Transformer.Force({repr(self.forced_value)})' + return f"Transformer.Force({repr(self.forced_value)})" def __init__(self, default, *, transform): super().__init__() @@ -162,6 +164,133 @@ def evaluate_pre(self, instance, step, overrides): return self.transform(value) +def _make_hashable(value: object) -> object: + """Recursively convert a value to a hashable type for set operations. + + Handles nested structures by recursively converting their contents. + Raises TypeError for types that cannot be made hashable. + """ + if isinstance(value, (str, bytes, int, float, bool, type(None))): + return value + if isinstance(value, tuple): + return tuple(_make_hashable(v) for v in value) + if isinstance(value, list): + return tuple(_make_hashable(v) for v in value) + if isinstance(value, dict): + return tuple( + sorted((_make_hashable(k), _make_hashable(v)) for k, v in value.items()) + ) + if isinstance(value, (set, frozenset)): + return frozenset(_make_hashable(v) for v in value) + try: + hash(value) + return value + except TypeError: + raise TypeError( + f"Cannot make {type(value).__name__} hashable for exclusion check" + ) + + +def _to_iterable(value: object) -> tuple | list | set | frozenset: + """Convert a value to an iterable, wrapping single values in a tuple.""" + if value is None: + return () + if isinstance(value, (list, tuple, set, frozenset)): + return value + return (value,) + + +class Excluding(OrderedDeclaration): + """Generate value from base declaration, excluding specified values. + + Retries generation when an excluded value is produced. Supports both + static exclusion lists and dynamic exclusions via declarations. + + Args: + base: The base declaration (or plain value) to generate values from + exclude: Value(s) to exclude. Can be: + - A single value + - An iterable of values (list, tuple, set) + - A declaration that returns value(s) to exclude (e.g., SelfAttribute) + max_retries: Maximum generation attempts before raising ValueError (default: 50) + + Raises: + ValueError: When max_retries exceeded without finding a non-excluded value + TypeError: When exclusion values cannot be made hashable + + Examples: + # Exclude specific values from FuzzyChoice + status = Excluding( + FuzzyChoice(['active', 'inactive', 'pending', 'banned']), + exclude=['banned', 'inactive'] + ) + + # Exclude a single value from Faker + email = Excluding( + Faker('email'), + exclude='admin@example.com' + ) + + # Dynamically exclude based on another field + class UserFactory(factory.Factory): + class Meta: + model = User + + primary_language = FuzzyChoice(['en', 'fr', 'de', 'es']) + secondary_language = Excluding( + FuzzyChoice(['en', 'fr', 'de', 'es']), + exclude=SelfAttribute('primary_language') + ) + """ + + def __init__( + self, + base: "BaseDeclaration | object", + exclude: "BaseDeclaration | object" = (), + max_retries: int = 50, + ) -> None: + super().__init__() + self.base = base + self.exclude_declaration = exclude + self.max_retries = max_retries + + def evaluate(self, instance: object, step: object, extra: dict) -> object: + """Evaluate with retry logic for excluded values.""" + excluded_values = self._resolve_exclusions(instance, step) + excluded_set = frozenset(_make_hashable(v) for v in excluded_values) + + if not isinstance(self.base, BaseDeclaration): + value = self.base + if _make_hashable(value) in excluded_set: + raise ValueError( + f"Plain value {value!r} is excluded and cannot be generated. " + f"Excluded: {self._format_exclusions(excluded_values)}." + ) + return value + + for _ in range(self.max_retries): + value = self.base.evaluate_pre(instance, step, {}) + if _make_hashable(value) not in excluded_set: + return value + + raise ValueError( + f"Could not generate a non-excluded value after {self.max_retries} attempts. " + f"Excluded: {self._format_exclusions(excluded_values)}. " + f"Consider reducing exclusions or increasing max_retries." + ) + + def _resolve_exclusions(self, instance: object, step: object) -> list: + """Resolve exclusion values, evaluating declarations if needed.""" + if isinstance(self.exclude_declaration, BaseDeclaration): + excluded_value = self.exclude_declaration.evaluate(instance, step, {}) + return list(_to_iterable(excluded_value)) + return list(_to_iterable(self.exclude_declaration)) + + def _format_exclusions(self, excluded_values: list) -> str: + """Format exclusion values for error messages.""" + return ", ".join(repr(v) for v in sorted(excluded_values, key=str)) + + class _UNSPECIFIED: pass @@ -183,8 +312,8 @@ def deepgetattr(obj, name, default=_UNSPECIFIED): AttributeError: if obj has no 'name' attribute. """ try: - if '.' in name: - attr, subname = name.split('.', 1) + if "." in name: + attr, subname = name.split(".", 1) return deepgetattr(getattr(obj, attr), subname, default) else: return getattr(obj, name) @@ -210,7 +339,7 @@ class SelfAttribute(BaseDeclaration): def __init__(self, attribute_name, default=_UNSPECIFIED): super().__init__() - depth = len(attribute_name) - len(attribute_name.lstrip('.')) + depth = len(attribute_name) - len(attribute_name.lstrip(".")) attribute_name = attribute_name[depth:] self.depth = depth @@ -224,11 +353,13 @@ def evaluate(self, instance, step, extra): else: target = instance - logger.debug("SelfAttribute: Picking attribute %r on %r", self.attribute_name, target) + logger.debug( + "SelfAttribute: Picking attribute %r on %r", self.attribute_name, target + ) return deepgetattr(target, self.attribute_name, self.default) def __repr__(self): - return '<%s(%r, default=%r)>' % ( + return "<%s(%r, default=%r)>" % ( self.__class__.__name__, self.attribute_name, self.default, @@ -251,7 +382,9 @@ def __init__(self, iterator, cycle=True, getter=None): self.iterator = None if cycle: - self.iterator_builder = lambda: utils.ResetableIterator(itertools.cycle(iterator)) + self.iterator_builder = lambda: utils.ResetableIterator( + itertools.cycle(iterator) + ) else: self.iterator_builder = lambda: utils.ResetableIterator(iterator) @@ -282,12 +415,17 @@ class Sequence(BaseDeclaration): function (function): A function, expecting the current sequence counter and returning the computed value. """ + def __init__(self, function): super().__init__() self.function = function def evaluate(self, instance, step, extra): - logger.debug("Sequence: Computing next value of %r for seq=%s", self.function, step.sequence) + logger.debug( + "Sequence: Computing next value of %r for seq=%s", + self.function, + step.sequence, + ) return self.function(int(step.sequence)) @@ -300,10 +438,14 @@ class LazyAttributeSequence(Sequence): type (function): A function converting an integer into the expected kind of counter for the 'function' attribute. """ + def evaluate(self, instance, step, extra): logger.debug( "LazyAttributeSequence: Computing next value of %r for seq=%s, obj=%r", - self.function, step.sequence, instance) + self.function, + step.sequence, + instance, + ) return self.function(instance, int(step.sequence)) @@ -316,6 +458,7 @@ class ContainerAttribute(BaseDeclaration): strict (bool): Whether evaluating should fail when the containers are not passed in (i.e used outside a SubFactory). """ + def __init__(self, function, strict=True): super().__init__() self.function = function @@ -336,7 +479,8 @@ def evaluate(self, instance, step, extra): if self.strict and not chain: raise TypeError( "A ContainerAttribute in 'strict' mode can only be used " - "within a SubFactory.") + "within a SubFactory." + ) return self.function(instance, chain) @@ -388,18 +532,20 @@ class _FactoryWrapper: Such args can be either a Factory subclass, or a fully qualified import path for that subclass (e.g 'myapp.factories.MyFactory'). """ + def __init__(self, factory_or_path): self.factory = None - self.module = self.name = '' + self.module = self.name = "" if isinstance(factory_or_path, type): self.factory = factory_or_path else: - if not (isinstance(factory_or_path, str) and '.' in factory_or_path): + if not (isinstance(factory_or_path, str) and "." in factory_or_path): raise ValueError( "A factory= argument must receive either a class " "or the fully qualified path to a Factory subclass; got " - "%r instead." % factory_or_path) - self.module, self.name = factory_or_path.rsplit('.', 1) + "%r instead." % factory_or_path + ) + self.module, self.name = factory_or_path.rsplit(".", 1) def get(self): if self.factory is None: @@ -411,9 +557,9 @@ def get(self): def __repr__(self): if self.factory is None: - return f'<_FactoryImport: {self.module}.{self.name}>' + return f"<_FactoryImport: {self.module}.{self.name}>" else: - return f'<_FactoryImport: {self.factory.__class__}>' + return f"<_FactoryImport: {self.factory.__class__}>" class SubFactory(BaseDeclaration): @@ -449,7 +595,8 @@ def evaluate(self, instance, step, extra): subfactory = self.get_factory() logger.debug( "SubFactory: Instantiating %s.%s(%s), create=%r", - subfactory.__module__, subfactory.__name__, + subfactory.__module__, + subfactory.__name__, utils.log_pprint(kwargs=extra), step, ) @@ -462,7 +609,7 @@ class Dict(SubFactory): FORCE_SEQUENCE = True - def __init__(self, params, dict_factory='factory.DictFactory'): + def __init__(self, params, dict_factory="factory.DictFactory"): super().__init__(dict_factory, **dict(params)) @@ -471,7 +618,7 @@ class List(SubFactory): FORCE_SEQUENCE = True - def __init__(self, params, list_factory='factory.ListFactory'): + def __init__(self, params, list_factory="factory.ListFactory"): params = {str(i): v for i, v in enumerate(params)} super().__init__(list_factory, **params) @@ -501,15 +648,19 @@ def __init__(self, decider, yes_declaration=SKIP, no_declaration=SKIP): self.no = no_declaration phases = { - 'yes_declaration': enums.get_builder_phase(yes_declaration), - 'no_declaration': enums.get_builder_phase(no_declaration), + "yes_declaration": enums.get_builder_phase(yes_declaration), + "no_declaration": enums.get_builder_phase(no_declaration), } used_phases = {phase for phase in phases.values() if phase is not None} if len(used_phases) > 1: raise TypeError(f"Inconsistent phases for {self!r}: {phases!r}") - self.FACTORY_BUILDER_PHASE = used_phases.pop() if used_phases else enums.BuilderPhase.ATTRIBUTE_RESOLUTION + self.FACTORY_BUILDER_PHASE = ( + used_phases.pop() + if used_phases + else enums.BuilderPhase.ATTRIBUTE_RESOLUTION + ) def evaluate_post(self, instance, step, overrides): """Handle post-generation declarations""" @@ -518,11 +669,13 @@ def evaluate_post(self, instance, step, overrides): # Note: we work on the *builder stub*, not on the actual instance. # This gives us access to all Params-level definitions. choice = self.decider.evaluate_pre( - instance=step.stub, step=step, overrides=overrides) + instance=step.stub, step=step, overrides=overrides + ) else: assert decider_phase == enums.BuilderPhase.POST_INSTANTIATION choice = self.decider.evaluate_post( - instance=instance, step=step, overrides={}) + instance=instance, step=step, overrides={} + ) target = self.yes if choice else self.no if enums.get_builder_phase(target) == enums.BuilderPhase.POST_INSTANTIATION: @@ -548,7 +701,7 @@ def evaluate_pre(self, instance, step, overrides): ) def __repr__(self): - return f'Maybe({self.decider!r}, yes={self.yes!r}, no={self.no!r})' + return f"Maybe({self.decider!r}, yes={self.yes!r}, no={self.no!r})" class Parameter(utils.OrderedBase): @@ -596,6 +749,7 @@ def wrap(cls, value): class Trait(Parameter): """The simplest complex parameter, it enables a bunch of new declarations based on a boolean flag.""" + def __init__(self, **overrides): super().__init__() self.overrides = overrides @@ -605,8 +759,9 @@ def as_declarations(self, field_name, declarations): for maybe_field, new_value in self.overrides.items(): overrides[maybe_field] = Maybe( decider=SelfAttribute( - '%s.%s' % ( - '.' * maybe_field.count(enums.SPLITTER), + "%s.%s" + % ( + "." * maybe_field.count(enums.SPLITTER), field_name, ), default=False, @@ -621,9 +776,9 @@ def get_revdeps(self, parameters): return [param for param in parameters if param in self.overrides] def __repr__(self): - return '%s(%s)' % ( + return "%s(%s)" % ( self.__class__.__name__, - ', '.join('%s=%r' % t for t in self.overrides.items()) + ", ".join("%s=%r" % t for t in self.overrides.items()), ) @@ -645,9 +800,9 @@ class PostGenerationDeclaration(BaseDeclaration): def evaluate_post(self, instance, step, overrides): context = self.unroll_context(instance, step, overrides) postgen_context = PostGenerationContext( - value_provided=bool('' in context), - value=context.get(''), - extra={k: v for k, v in context.items() if k != ''}, + value_provided=bool("" in context), + value=context.get(""), + extra={k: v for k, v in context.items() if k != ""}, ) return self.call(instance, step, postgen_context) @@ -665,6 +820,7 @@ def call(self, instance, step, context): # pragma: no cover class PostGeneration(PostGenerationDeclaration): """Calls a given function once the object has been generated.""" + def __init__(self, function): super().__init__() self.function = function @@ -680,8 +836,7 @@ def call(self, instance, step, context): ), ) create = step.builder.strategy == enums.CREATE_STRATEGY - return self.function( - instance, create, context.value, **context.extra) + return self.function(instance, create, context.value, **context.extra) class RelatedFactory(PostGenerationDeclaration): @@ -696,7 +851,7 @@ class RelatedFactory(PostGenerationDeclaration): UNROLL_CONTEXT_BEFORE_EVALUATION = False - def __init__(self, factory, factory_related_name='', **defaults): + def __init__(self, factory, factory_related_name="", **defaults): super().__init__() self.name = factory_related_name @@ -715,7 +870,8 @@ def call(self, instance, step, context): logger.debug( "RelatedFactory: Using provided %r instead of generating %s.%s.", context.value, - factory.__module__, factory.__name__, + factory.__module__, + factory.__name__, ) return context.value @@ -745,7 +901,7 @@ class RelatedFactoryList(RelatedFactory): returning a list of 'factory' objects w/ size 'size'. """ - def __init__(self, factory, factory_related_name='', size=2, **defaults): + def __init__(self, factory, factory_related_name="", size=2, **defaults): self.size = size super().__init__(factory, factory_related_name, **defaults) @@ -774,6 +930,7 @@ class UserFactory(factory.Factory): ... password = factory.PostGenerationMethodCall('set_pass', password='') """ + def __init__(self, method_name, *args, **kwargs): super().__init__() if len(args) > 1: diff --git a/factory/faker.py b/factory/faker.py index 88ae644c..2a4600d0 100644 --- a/factory/faker.py +++ b/factory/faker.py @@ -23,11 +23,14 @@ class Meta: class Faker(declarations.BaseDeclaration): - """Wrapper for 'faker' values. + """Wrapper for ' faker' values. Args: provider (str): the name of the Faker field locale (str): the locale to use for the faker + unique (bool): if True, use Faker's .unique to ensure all generated + values are globally unique for this provider/locale combination + (default: False) All other kwargs will be passed to the underlying provider (e.g ``factory.Faker('ean', length=10)`` @@ -35,18 +38,28 @@ class Faker(declarations.BaseDeclaration): Usage: >>> foo = factory.Faker('name') + >>> unique_email = factory.Faker('email', unique=True) + + Note: + When using unique=True, Faker maintains a global cache of generated values. + Clear it between independent tests with: factory.Faker._get_faker().unique.clear() """ + def __init__(self, provider, **kwargs): - locale = kwargs.pop('locale', None) + locale = kwargs.pop("locale", None) + unique = kwargs.pop("unique", False) self.provider = provider - super().__init__( - locale=locale, - **kwargs) + self.unique = unique + super().__init__(locale=locale, **kwargs) def evaluate(self, instance, step, extra): - locale = extra.pop('locale') + locale = extra.pop("locale") subfaker = self._get_faker(locale) - return subfaker.format(self.provider, **extra) + + if self.unique: + return subfaker.unique.format(self.provider, **extra) + else: + return subfaker.format(self.provider, **extra) _FAKER_REGISTRY: Dict[str, faker.Faker] = {} _DEFAULT_LOCALE = faker.config.DEFAULT_LOCALE diff --git a/tests/test_excluding.py b/tests/test_excluding.py new file mode 100644 index 00000000..0da2c026 --- /dev/null +++ b/tests/test_excluding.py @@ -0,0 +1,200 @@ +# Copyright: See the LICENSE file. + +import unittest + +import factory +from factory import declarations +from factory.fuzzy import FuzzyChoice + +from . import utils + + +class ExcludingTestCase(unittest.TestCase): + + def test_basic_exclusion(self): + """Test basic exclusion with FuzzyChoice.""" + decl = declarations.Excluding( + FuzzyChoice(["a", "b", "c", "d"]), exclude=["a", "b"] + ) + for _ in range(20): + value = utils.evaluate_declaration(decl) + self.assertIn(value, ["c", "d"]) + self.assertNotIn(value, ["a", "b"]) + + def test_single_exclusion(self): + """Test excluding a single value.""" + decl = declarations.Excluding(FuzzyChoice(["x", "y", "z"]), exclude="x") + for _ in range(20): + value = utils.evaluate_declaration(decl) + self.assertIn(value, ["y", "z"]) + + def test_no_exclusion(self): + """Test with no exclusions.""" + decl = declarations.Excluding(FuzzyChoice(["a", "b", "c"]), exclude=[]) + value = utils.evaluate_declaration(decl) + self.assertIn(value, ["a", "b", "c"]) + + def test_with_faker(self): + """Test Excluding works with Faker.""" + excluded_value = "fixed@example.com" + decl = declarations.Excluding(factory.Faker("email"), exclude=excluded_value) + for _ in range(10): + value = utils.evaluate_declaration(decl) + self.assertNotEqual(value, excluded_value) + self.assertIn("@", value) + + def test_with_lazy_attribute(self): + """Test Excluding works with LazyAttribute.""" + import random + + random.seed(42) + + decl = declarations.Excluding( + declarations.LazyAttribute(lambda x: random.choice(["a", "b", "c", "d"])), + exclude=["a", "b"], + ) + for _ in range(20): + value = utils.evaluate_declaration(decl) + self.assertIn(value, ["c", "d"]) + + def test_exhaustion_error(self): + """Test that ValueError is raised when all values are excluded.""" + decl = declarations.Excluding( + FuzzyChoice(["a", "b"]), exclude=["a", "b"], max_retries=10 + ) + with self.assertRaisesRegex(ValueError, r"Could not generate.*10 attempts"): + utils.evaluate_declaration(decl) + + def test_custom_max_retries(self): + """Test custom max_retries parameter.""" + decl = declarations.Excluding(FuzzyChoice(["a"]), exclude=["a"], max_retries=5) + with self.assertRaisesRegex(ValueError, r"5 attempts"): + utils.evaluate_declaration(decl) + + def test_plain_value_exclusion_error(self): + """Test that excluding a plain value raises immediately without retrying.""" + decl = declarations.Excluding("fixed_value", exclude="fixed_value") + with self.assertRaisesRegex(ValueError, r"Plain value.*is excluded"): + utils.evaluate_declaration(decl) + + def test_with_factory(self): + """Test Excluding in a full factory context.""" + + class MyModel: + def __init__(self, status, code): + self.status = status + self.code = code + + class MyFactory(factory.Factory): + class Meta: + model = MyModel + + status = declarations.Excluding( + FuzzyChoice(["active", "inactive", "pending", "banned"]), + exclude=["banned"], + ) + code = declarations.Excluding(factory.Faker("country_code"), exclude="US") + + for _ in range(10): + obj = MyFactory() + self.assertIn(obj.status, ["active", "inactive", "pending"]) + self.assertNotEqual(obj.status, "banned") + self.assertNotEqual(obj.code, "US") + + def test_with_self_attribute(self): + """Test dynamic exclusion using SelfAttribute.""" + + class MyModel: + def __init__(self, primary, secondary): + self.primary = primary + self.secondary = secondary + + class MyFactory(factory.Factory): + class Meta: + model = MyModel + + primary = FuzzyChoice(["en", "fr", "de", "es"]) + secondary = declarations.Excluding( + FuzzyChoice(["en", "fr", "de", "es"]), + exclude=declarations.SelfAttribute("primary"), + ) + + for _ in range(20): + obj = MyFactory() + self.assertIn(obj.primary, ["en", "fr", "de", "es"]) + self.assertIn(obj.secondary, ["en", "fr", "de", "es"]) + self.assertNotEqual(obj.primary, obj.secondary) + + def test_with_self_attribute_list(self): + """Test dynamic exclusion with multiple values via SelfAttribute.""" + + class MyModel: + def __init__(self, excluded_values, value): + self.excluded_values = excluded_values + self.value = value + + class MyFactory(factory.Factory): + class Meta: + model = MyModel + + excluded_values = factory.List( + [ + factory.LazyFunction(lambda: "a"), + factory.LazyFunction(lambda: "b"), + ] + ) + value = declarations.Excluding( + FuzzyChoice(["a", "b", "c", "d"]), + exclude=declarations.SelfAttribute("excluded_values"), + ) + + obj = MyFactory() + self.assertIn(obj.value, ["c", "d"]) + + def test_override_works(self): + """Test that factory overrides work correctly with Excluding.""" + + class MyModel: + def __init__(self, status): + self.status = status + + class MyFactory(factory.Factory): + class Meta: + model = MyModel + + status = declarations.Excluding( + FuzzyChoice(["active", "inactive"]), exclude="inactive" + ) + + obj = MyFactory(status="custom") + self.assertEqual(obj.status, "custom") + + def test_hashable_values(self): + """Test that unhashable values (lists, dicts) are properly handled.""" + decl = declarations.Excluding( + declarations.LazyFunction(lambda: [1, 2, 3]), + exclude=[[1, 2, 3]], + max_retries=5, + ) + with self.assertRaisesRegex(ValueError, r"5 attempts"): + utils.evaluate_declaration(decl) + + def test_mixed_exclude_types(self): + """Test excluding with a mix of types.""" + decl = declarations.Excluding(FuzzyChoice([1, 2, 3, 4, 5]), exclude=[1, 2]) + for _ in range(10): + value = utils.evaluate_declaration(decl) + self.assertIn(value, [3, 4, 5]) + + +class ExcludingOrderingTestCase(unittest.TestCase): + """Test that Excluding properly participates in declaration ordering.""" + + def test_ordering(self): + """Ensure Excluding is an OrderedDeclaration.""" + self.assertTrue( + isinstance( + declarations.Excluding(FuzzyChoice(["a", "b"])), + declarations.OrderedDeclaration, + ) + ) diff --git a/tests/test_faker.py b/tests/test_faker.py index d1a16da0..a39de906 100644 --- a/tests/test_faker.py +++ b/tests/test_faker.py @@ -49,8 +49,8 @@ def _setup_advanced_mock_faker(self, locale=None, **handlers): def test_simple_biased(self): self._setup_mock_faker(name="John Doe") - faker_field = factory.Faker('name') - self.assertEqual("John Doe", faker_field.evaluate(None, None, {'locale': None})) + faker_field = factory.Faker("name") + self.assertEqual("John Doe", faker_field.evaluate(None, None, {"locale": None})) def test_full_factory(self): class Profile: @@ -62,17 +62,25 @@ def __init__(self, first_name, last_name, email): class ProfileFactory(factory.Factory): class Meta: model = Profile - first_name = factory.Faker('first_name') - last_name = factory.Faker('last_name', locale='fr_FR') - email = factory.Faker('email') - self._setup_mock_faker(first_name="John", last_name="Doe", email="john.doe@example.org") - self._setup_mock_faker(first_name="Jean", last_name="Valjean", email="jvaljean@exemple.fr", locale='fr_FR') + first_name = factory.Faker("first_name") + last_name = factory.Faker("last_name", locale="fr_FR") + email = factory.Faker("email") + + self._setup_mock_faker( + first_name="John", last_name="Doe", email="john.doe@example.org" + ) + self._setup_mock_faker( + first_name="Jean", + last_name="Valjean", + email="jvaljean@exemple.fr", + locale="fr_FR", + ) profile = ProfileFactory() self.assertEqual("John", profile.first_name) self.assertEqual("Valjean", profile.last_name) - self.assertEqual('john.doe@example.org', profile.email) + self.assertEqual("john.doe@example.org", profile.email) def test_override_locale(self): class Profile: @@ -84,18 +92,20 @@ class ProfileFactory(factory.Factory): class Meta: model = Profile - first_name = factory.Faker('first_name') - last_name = factory.Faker('last_name', locale='fr_FR') + first_name = factory.Faker("first_name") + last_name = factory.Faker("last_name", locale="fr_FR") self._setup_mock_faker(first_name="John", last_name="Doe") - self._setup_mock_faker(first_name="Jean", last_name="Valjean", locale='fr_FR') - self._setup_mock_faker(first_name="Johannes", last_name="Brahms", locale='de_DE') + self._setup_mock_faker(first_name="Jean", last_name="Valjean", locale="fr_FR") + self._setup_mock_faker( + first_name="Johannes", last_name="Brahms", locale="de_DE" + ) profile = ProfileFactory() self.assertEqual("John", profile.first_name) self.assertEqual("Valjean", profile.last_name) - with factory.Faker.override_default_locale('de_DE'): + with factory.Faker.override_default_locale("de_DE"): profile = ProfileFactory() self.assertEqual("Johannes", profile.first_name) self.assertEqual("Valjean", profile.last_name) @@ -114,19 +124,19 @@ class FaceFactory(factory.Factory): class Meta: model = Face - smiley = factory.Faker('smiley') - french_smiley = factory.Faker('smiley', locale='fr_FR') + smiley = factory.Faker("smiley") + french_smiley = factory.Faker("smiley", locale="fr_FR") class SmileyProvider(faker.providers.BaseProvider): def smiley(self): - return ':)' + return ":)" class FrenchSmileyProvider(faker.providers.BaseProvider): def smiley(self): - return '(:' + return "(:" factory.Faker.add_provider(SmileyProvider) - factory.Faker.add_provider(FrenchSmileyProvider, 'fr_FR') + factory.Faker.add_provider(FrenchSmileyProvider, "fr_FR") face = FaceFactory() self.assertEqual(":)", face.smiley) @@ -134,7 +144,7 @@ def smiley(self): def test_faker_customization(self): """Factory declarations in Faker parameters should be accepted.""" - Trip = collections.namedtuple('Trip', ['departure', 'transfer', 'arrival']) + Trip = collections.namedtuple("Trip", ["departure", "transfer", "arrival"]) may_4th = datetime.date(1977, 5, 4) may_25th = datetime.date(1977, 5, 25) @@ -147,9 +157,9 @@ class Meta: departure = may_4th arrival = may_25th transfer = factory.Faker( - 'date_between_dates', - start_date=factory.SelfAttribute('..departure'), - end_date=factory.SelfAttribute('..arrival'), + "date_between_dates", + start_date=factory.SelfAttribute("..departure"), + end_date=factory.SelfAttribute("..arrival"), ) def fake_select_date(start_date, end_date): @@ -168,3 +178,115 @@ def fake_select_date(start_date, end_date): self.assertEqual(may_4th, trip.departure) self.assertEqual(october_19th, trip.transfer) self.assertEqual(may_25th, trip.arrival) + + +class FakerUniqueTests(unittest.TestCase): + """Tests for the unique parameter in Faker declarations.""" + + def setUp(self): + self._real_fakers = factory.Faker._FAKER_REGISTRY + factory.Faker._FAKER_REGISTRY = {} + try: + factory.Faker._get_faker().unique.clear() + except Exception: + pass + + def tearDown(self): + factory.Faker._FAKER_REGISTRY = self._real_fakers + try: + factory.Faker._get_faker().unique.clear() + except Exception: + pass + + def test_unique_faker_generates_unique_values(self): + """Test that Faker with unique=True generates unique values.""" + + class User: + def __init__(self, email): + self.email = email + + class UserFactory(factory.Factory): + class Meta: + model = User + + email = factory.Faker("email", unique=True) + + emails = {UserFactory().email for _ in range(10)} + self.assertEqual(len(emails), 10) + + def test_unique_false_allows_duplicates(self): + """Test that Faker with unique=False (default) can generate duplicates.""" + import random + + random.seed(42) + + class User: + def __init__(self, value): + self.value = value + + class UserFactory(factory.Factory): + class Meta: + model = User + + value = factory.Faker("random_int", min=1, max=3, unique=False) + + values = [UserFactory().value for _ in range(20)] + unique_values = set(values) + self.assertLess(len(unique_values), 20) + + def test_unique_clears_between_factories(self): + """Test that unique cache can be cleared between test runs.""" + + class User: + def __init__(self, email): + self.email = email + + class UserFactory(factory.Factory): + class Meta: + model = User + + email = factory.Faker("email", unique=True) + + emails_first = {UserFactory().email for _ in range(5)} + factory.Faker._get_faker().unique.clear() + + emails_second = {UserFactory().email for _ in range(5)} + + self.assertEqual(len(emails_first), 5) + self.assertEqual(len(emails_second), 5) + + def test_unique_with_locale(self): + """Test that unique works with custom locales.""" + + class User: + def __init__(self, name): + self.name = name + + class UserFactory(factory.Factory): + class Meta: + model = User + + name = factory.Faker("name", locale="fr_FR", unique=True) + + names = {UserFactory().name for _ in range(5)} + self.assertEqual(len(names), 5) + + def test_unique_exhaustion_raises_error(self): + """Test that exhausting unique values raises UniquenessException.""" + import faker.exceptions + + class User: + def __init__(self, value): + self.value = value + + class UserFactory(factory.Factory): + class Meta: + model = User + + value = factory.Faker("boolean", unique=True) + + UserFactory() + UserFactory() + + with self.assertRaises(faker.exceptions.UniquenessException): + UserFactory()