Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions sqlalchemy_history/model_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Model Builder module build Versioned Models"""

from copy import copy

import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import column_property
from sqlalchemy.orm import MappedColumn, column_property
from sqlalchemy_utils.functions import get_declarative_base, get_primary_keys
from sqlalchemy_utils.models import generic_repr

Expand Down Expand Up @@ -82,11 +83,13 @@ def copy_mapper_args(model):
args[arg] = model.__mapper_args__[arg]

if "polymorphic_on" in model.__mapper_args__:
column = model.__mapper_args__["polymorphic_on"]
if isinstance(column, str):
args["polymorphic_on"] = column
discriminator_column = model.__mapper_args__["polymorphic_on"]
if isinstance(discriminator_column, str):
args["polymorphic_on"] = discriminator_column
elif isinstance(discriminator_column, MappedColumn):
args["polymorphic_on"] = discriminator_column.column.key
else:
args["polymorphic_on"] = column.key
args["polymorphic_on"] = discriminator_column.key
return args


Expand Down Expand Up @@ -244,7 +247,10 @@ def mapper_args(cls):
name = "%sVersion" % (self.model.__name__,)
version_cls = type(name, self.base_classes(), args)
if option(self.model, "base_classes") is None:
primary_keys = list(get_primary_keys(self.model).keys()) + ["transaction_id", "operation_type"]
primary_keys = list(get_primary_keys(self.model).keys()) + [
"transaction_id",
"operation_type",
]
version_cls = generic_repr(*primary_keys)(version_cls)
return version_cls

Expand Down
35 changes: 35 additions & 0 deletions tests/builders/test_model_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from tests import TestCase
from sqlalchemy_history.model_builder import copy_mapper_args
from sqlalchemy.orm import MappedColumn


class TestVersionModelBuilder(TestCase):
Expand Down Expand Up @@ -49,3 +51,36 @@ def test_version_cls_repr(self):
self.session.add(article)
self.session.commit()
assert repr(article.versions[0]) == "Class_ArticleVersion(id=1)"


class TestCopyMapperArgs(TestCase):
def test_copy_mapper_args_with_mapped_column_polymorphic_on(self):
# Test that copy_mapper_args handles MappedColumn for polymorphic_on
import sqlalchemy as sa

class MockModel:
__mapper_args__ = {"polymorphic_on": MappedColumn(sa.String(50), name="type")}

args = copy_mapper_args(MockModel)
assert args["polymorphic_on"] == "type"

def test_copy_mapper_args_with_str_polymorphic_on(self):
# Test that copy_mapper_args handles str for polymorphic_on

class MockModel:
__mapper_args__ = {"polymorphic_on": "type"}

args = copy_mapper_args(MockModel)
assert args["polymorphic_on"] == "type"

def test_copy_mapper_args_with_column_polymorphic_on(self):
# Test that copy_mapper_args handles Column for polymorphic_on
import sqlalchemy as sa

column = sa.Column(sa.String(50), name="type")

class MockModel:
__mapper_args__ = {"polymorphic_on": column}

args = copy_mapper_args(MockModel)
assert args["polymorphic_on"] == "type"