diff --git a/dataframely/_compat.py b/dataframely/_compat.py index a60ba3a..7acc643 100644 --- a/dataframely/_compat.py +++ b/dataframely/_compat.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause @@ -29,6 +29,7 @@ class DeltaTable: # type: ignore # noqa: N801 try: import sqlalchemy as sa import sqlalchemy.dialects.mssql as sa_mssql + import sqlalchemy.dialects.postgresql as sa_postgresql from sqlalchemy import Dialect from sqlalchemy.dialects.mssql.pyodbc import MSDialect_pyodbc from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 @@ -36,6 +37,7 @@ class DeltaTable: # type: ignore # noqa: N801 except ImportError: sa = _DummyModule("sqlalchemy") # type: ignore sa_mssql = _DummyModule("sqlalchemy") # type: ignore + sa_postgresql = _DummyModule("sqlalchemy") # type: ignore class sa_TypeEngine: # type: ignore # noqa: N801 pass @@ -81,6 +83,7 @@ class Dialect: # type: ignore # noqa: N801 "pydantic_core_schema", "pydantic", "sa_mssql", + "sa_postgresql", "sa_TypeEngine", "sa", ] diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py index b8aecf9..edf597b 100644 --- a/dataframely/columns/struct.py +++ b/dataframely/columns/struct.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -8,7 +8,7 @@ import polars as pl -from dataframely._compat import pa, sa, sa_TypeEngine +from dataframely._compat import pa, sa, sa_postgresql, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -107,8 +107,11 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - # NOTE: We might want to add support for PostgreSQL's JSON in the future. - raise NotImplementedError("SQL column cannot have 'Struct' type.") + match dialect.name: + case "postgresql": + return sa_postgresql.JSONB() + case _: + raise NotImplementedError("SQL column cannot have 'Struct' type.") @property def pyarrow_dtype(self) -> pa.DataType: diff --git a/tests/columns/test_sqlalchemy_columns.py b/tests/columns/test_sqlalchemy_columns.py index 732b395..1fc65a2 100644 --- a/tests/columns/test_sqlalchemy_columns.py +++ b/tests/columns/test_sqlalchemy_columns.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause import pytest @@ -95,6 +95,7 @@ def test_mssql_datatype(column: Column, datatype: str) -> None: (dy.String(regex="^[abc]{1,3}d$"), "VARCHAR(4)"), (dy.Enum(["foo", "bar"]), "CHAR(3)"), (dy.Enum(["a", "abc"]), "VARCHAR(3)"), + (dy.Struct({"a": dy.String(nullable=True)}), "JSONB"), ], ) def test_postgres_datatype(column: Column, datatype: str) -> None: @@ -152,7 +153,7 @@ def test_raise_for_array_column(dialect: Dialect) -> None: dy.Array(dy.String(nullable=True), 1).sqlalchemy_dtype(dialect) -@pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) +@pytest.mark.parametrize("dialect", [MSDialect_pyodbc()]) def test_raise_for_struct_column(dialect: Dialect) -> None: with pytest.raises( NotImplementedError, match="SQL column cannot have 'Struct' type."