diff --git a/pyproject.toml b/pyproject.toml index db1b685..cd3f59b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dev-dependencies = [ "pytest>=8.3.4", "coverage>=7.6.9", "docker>=7.1.0", + "sqlmodel>=0.0.27", ] [build-system] diff --git a/src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py b/src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py index bf663a0..5202bec 100644 --- a/src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py +++ b/src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py @@ -61,7 +61,7 @@ def gen_table_args(model_ir: ModelIR) -> ast.Assign | None: return ast.Assign( targets=[ast.Name('__table_args__')], - value=ast.List( + value=ast.Tuple( elts=[table_arg.to_expr() for table_arg in model_ir.table_args] ) ) diff --git a/tests/helpers/helpers.py b/tests/helpers/helpers.py index 1119177..0482cdb 100644 --- a/tests/helpers/helpers.py +++ b/tests/helpers/helpers.py @@ -247,7 +247,11 @@ def collect_table_name(stat: ast.Assign) -> str | None: def collect_uniques(table_args: ast.expr) -> set[tuple[str]]: uniques: set[tuple[str]] = set() - if isinstance(table_args, ast.List): + # TODO: this shall support the parsing of all the possible + # types of values __table_args__ could possess, I remember + # also a dictionary being possible and maybe something else + # other than a tuple + if isinstance(table_args, ast.Tuple): for elt in table_args.elts: if isinstance(elt, ast.Call) and isinstance(elt.func, ast.Name) and elt.func.id == 'UniqueConstraint': uniques.add(tuple(arg.value for arg in elt.args if isinstance(arg, ast.Constant))) diff --git a/tests/test_gen_from_mysql.py b/tests/test_gen_from_mysql.py index e96d5cb..f1a6671 100644 --- a/tests/test_gen_from_mysql.py +++ b/tests/test_gen_from_mysql.py @@ -43,7 +43,7 @@ def test_mysql(): class Hero(SQLModel, table=True): __tablename__ = 'Hero' - __table_args__ = [UniqueConstraint('secret_name')] + __table_args__ = (UniqueConstraint('secret_name'), ) id: int | None name: str | None secret_name: str | None diff --git a/tests/test_gen_from_postgres.py b/tests/test_gen_from_postgres.py index be4d291..4a24557 100644 --- a/tests/test_gen_from_postgres.py +++ b/tests/test_gen_from_postgres.py @@ -80,7 +80,7 @@ def test_gen_code(): class Users(SQLModel, table=True): __tablename__ = 'users' - __table_args__ = [UniqueConstraint('email'), UniqueConstraint('name')] + __table_args__ = (UniqueConstraint('email'), UniqueConstraint('name')) id: UUID = Field(primary_key=True, default_factory=uuid4) email: str name: str @@ -101,7 +101,7 @@ class Participations(SQLModel, table=True): class Leagues(SQLModel, table=True): __tablename__ = 'leagues' - __table_args__ = [UniqueConstraint('name')] + __table_args__ = (UniqueConstraint('name'),) id: UUID = Field(primary_key=True, default_factory=uuid4) name: str public: bool diff --git a/tests/test_gen_from_sql.py b/tests/test_gen_from_sql.py index 8f4702d..40bfb82 100644 --- a/tests/test_gen_from_sql.py +++ b/tests/test_gen_from_sql.py @@ -91,16 +91,20 @@ def test_unique_single_column(): age INTEGER );''' - assert collect_code_info(gen_code_from_sql(sql)) == collect_code_info('''from sqlmodel import SQLModel, Field, UniqueConstraint + code = gen_code_from_sql(sql) + + assert collect_code_info(code) == collect_code_info('''from sqlmodel import SQLModel, Field, UniqueConstraint class Hero(SQLModel, table = True): \t__tablename__ = 'Hero' -\t__table_args__ = [UniqueConstraint('secret_name')] +\t__table_args__ = (UniqueConstraint('secret_name'), ) \tid: int = Field(primary_key=True) \tname: str \tsecret_name: str \tage: int | None''') + + exec(code, globals(), globals()) def test_datetime(): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 902e8e1..737b6c4 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -41,7 +41,7 @@ def test_collect_code_info(): class a_table(SQLModel, table = True): __tablename__ = 'a_table' - __table_args__ = [UniqueConstraint('name')] + __table_args__ = (UniqueConstraint('name'), ) id: int | None = Field(primary_key=True) name: str email: str | None''')