diff --git a/sqlalchemy_history/model_builder.py b/sqlalchemy_history/model_builder.py index 6695671..48af755 100644 --- a/sqlalchemy_history/model_builder.py +++ b/sqlalchemy_history/model_builder.py @@ -1,9 +1,10 @@ """Model Builder module build Versioned Models""" from copy import copy + import sqlalchemy as sa from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import column_property +from sqlalchemy.orm import MappedColumn, column_property from sqlalchemy_utils.functions import get_declarative_base, get_primary_keys from sqlalchemy_utils.models import generic_repr @@ -82,11 +83,13 @@ def copy_mapper_args(model): args[arg] = model.__mapper_args__[arg] if "polymorphic_on" in model.__mapper_args__: - column = model.__mapper_args__["polymorphic_on"] - if isinstance(column, str): - args["polymorphic_on"] = column + discriminator_column = model.__mapper_args__["polymorphic_on"] + if isinstance(discriminator_column, str): + args["polymorphic_on"] = discriminator_column + elif isinstance(discriminator_column, MappedColumn): + args["polymorphic_on"] = discriminator_column.column.key else: - args["polymorphic_on"] = column.key + args["polymorphic_on"] = discriminator_column.key return args @@ -244,7 +247,10 @@ def mapper_args(cls): name = "%sVersion" % (self.model.__name__,) version_cls = type(name, self.base_classes(), args) if option(self.model, "base_classes") is None: - primary_keys = list(get_primary_keys(self.model).keys()) + ["transaction_id", "operation_type"] + primary_keys = list(get_primary_keys(self.model).keys()) + [ + "transaction_id", + "operation_type", + ] version_cls = generic_repr(*primary_keys)(version_cls) return version_cls diff --git a/tests/builders/test_model_builder.py b/tests/builders/test_model_builder.py index 19e9066..31c127a 100644 --- a/tests/builders/test_model_builder.py +++ b/tests/builders/test_model_builder.py @@ -1,4 +1,6 @@ from tests import TestCase +from sqlalchemy_history.model_builder import copy_mapper_args +from sqlalchemy.orm import MappedColumn class TestVersionModelBuilder(TestCase): @@ -49,3 +51,36 @@ def test_version_cls_repr(self): self.session.add(article) self.session.commit() assert repr(article.versions[0]) == "Class_ArticleVersion(id=1)" + + +class TestCopyMapperArgs(TestCase): + def test_copy_mapper_args_with_mapped_column_polymorphic_on(self): + # Test that copy_mapper_args handles MappedColumn for polymorphic_on + import sqlalchemy as sa + + class MockModel: + __mapper_args__ = {"polymorphic_on": MappedColumn(sa.String(50), name="type")} + + args = copy_mapper_args(MockModel) + assert args["polymorphic_on"] == "type" + + def test_copy_mapper_args_with_str_polymorphic_on(self): + # Test that copy_mapper_args handles str for polymorphic_on + + class MockModel: + __mapper_args__ = {"polymorphic_on": "type"} + + args = copy_mapper_args(MockModel) + assert args["polymorphic_on"] == "type" + + def test_copy_mapper_args_with_column_polymorphic_on(self): + # Test that copy_mapper_args handles Column for polymorphic_on + import sqlalchemy as sa + + column = sa.Column(sa.String(50), name="type") + + class MockModel: + __mapper_args__ = {"polymorphic_on": column} + + args = copy_mapper_args(MockModel) + assert args["polymorphic_on"] == "type"