Merged
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,15 @@ Types of changes:

## [Unreleased]

### Fixed

- Improve nested/dotted column handling: when data is loaded from external sources via database accessors, nested
columns (e.g., `value.shopId`) are now aliased with underscores (`value_shopId`) for consistent querying. Native
DuckDB struct columns continue to use dot notation. This ensures proper handling in WHERE clauses, JOIN conditions,
and SELECT statements across all check types.
- Date-type filters are now ORed together (instead of ANDed) when fetching data into memory, so that multiple
  date conditions can match correctly.

## [0.11.2] - 2026-01-23

### Fixed
115 changes: 78 additions & 37 deletions src/koality/checks.py
@@ -114,13 +114,20 @@ def __init__(
def in_memory_column(self) -> str:
"""Return the column name to reference in in-memory queries.

If a configured column references a nested field (e.g. "value.shopId"),
the in-memory representation uses the last segment ("shopId"). This
property provides that flattened name without modifying the original
If a configured column references a nested field (e.g. "value.shopId"):
- When querying data loaded via database_accessor: uses underscores ("value_shopId")
because the executor flattens struct columns with underscore aliases
- When querying existing DuckDB tables (no accessor): keeps dots ("value.shopId")
to support native DuckDB struct column syntax

This property provides the appropriate name without modifying the original
configured `self.check_column`, which is still used for result writing.
"""
if isinstance(self.check_column, str) and "." in self.check_column:
return self.check_column.split(".")[-1]
if isinstance(self.check_column, str) and "." in self.check_column: # noqa: SIM102
# Only convert to underscores if data was loaded via database_accessor
# (which flattens structs). For native DuckDB tables, keep dotted notation.
if self.database_accessor:
return self.check_column.replace(".", "_")
return self.check_column
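For illustration, a minimal standalone replica of the property's branching; the `DataQualityCheck` plumbing is omitted and the accessor string `"bigquery"` is a hypothetical example, but the naming rule matches the code above:

```python
# Standalone replica of the in_memory_column logic -- illustration only.
def in_memory_column(check_column: str, database_accessor: str | None) -> str:
    if "." in check_column:
        # Accessor-loaded data is flattened with underscore aliases;
        # native DuckDB tables keep dotted struct syntax.
        if database_accessor:
            return check_column.replace(".", "_")
    return check_column

assert in_memory_column("value.shopId", "bigquery") == "value_shopId"
assert in_memory_column("value.shopId", None) == "value.shopId"
assert in_memory_column("amount", "bigquery") == "amount"
```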

@property
@@ -391,15 +398,25 @@ def get_identifier_filter(filters: dict[str, dict[str, Any]]) -> tuple[str, dict
return None

@staticmethod
def assemble_where_statement(filters: dict[str, dict[str, Any]], *, strip_dotted_columns: bool = True) -> str:
def assemble_where_statement( # noqa: C901
filters: dict[str, dict[str, Any]],
*,
strip_dotted_columns: bool = True,
database_accessor: str | None = None,
) -> str:
"""Generate the where statement for the check query using the specified filters.

Args:
filters: A dict containing filter specifications, e.g.,
strip_dotted_columns: When True (default), dotted column names (e.g. "a.b") are
reduced to their last component ("b") for WHERE clauses. If False, the full
dotted expression is preserved. This is useful when querying source databases
that expect the original dotted column syntax.
transformed based on the data source:
- With database_accessor: converted to underscores ("a_b") for flattened data
- Without database_accessor: kept as dots ("a.b") for native DuckDB structs
If False, the full dotted expression is preserved regardless (used when
querying source databases that expect the original dotted column syntax).
database_accessor: Optional database accessor string. When provided and non-empty, it
indicates that the data was loaded from an external source and that dotted columns
should be converted to underscores.

Example filters:
`{
@@ -445,17 +462,22 @@ def assemble_where_statement(filters: dict[str, dict[str, Any]], *, strip_dotted
# If column is not provided we cannot build a WHERE condition
if column is None:
continue
# If the column references a nested field (e.g. "value.shopId"),
# databases that flatten JSON may have the column stored without the
# prefix. By default we use the last component after the dot for the
# WHERE clause, but callers can disable this behavior by setting
# `strip_dotted_columns=False` (used when querying source DBs so the
# original dotted expression is preserved).
if isinstance(column, str) and "." in column and strip_dotted_columns:
column = column.split(".")[-1]
# If the column references a nested field (e.g. "value.shopId"):
# - With database_accessor: convert to underscores (value_shopId) for flattened data
# - Without database_accessor: keep dots (value.shopId) for native DuckDB structs
# Callers can disable this by setting `strip_dotted_columns=False`
# (used when querying source DBs where the original dotted expression is needed).
if isinstance(column, str) and "." in column and strip_dotted_columns and database_accessor:
column = column.replace(".", "_")
# else: keep dotted notation for native DuckDB struct support

operator = filter_dict.get("operator", "=")

# Cast date columns for proper comparison
is_date_filter = filter_dict.get("type") == "date"
if is_date_filter:
column = f"CAST({column} AS DATE)"

# Handle NULL values with IS NULL / IS NOT NULL
if value is None:
if operator == "!=":
@@ -465,6 +487,9 @@
continue

formatted_value = format_filter_value(value, operator)
# Prefix the literal with DATE for date-type filters
if is_date_filter and operator not in ("BETWEEN", "IN", "NOT IN"):
formatted_value = f"DATE {formatted_value}"
filters_statements.append(f" {column} {operator} {formatted_value}")

if len(filters_statements) == 0:
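As a sketch of what the date branch emits; the exact quoting comes from `format_filter_value`, which is not shown in this diff, so the literal below is an assumption:

```python
# Hypothetical filter spec mirroring the docstring's shape.
filters = {
    "day": {"column": "created_at", "type": "date", "operator": ">=", "value": "2026-01-01"},
}
where = DataQualityCheck.assemble_where_statement(filters, database_accessor=None)
# Expected shape (modulo whitespace):
# WHERE
#  CAST(created_at AS DATE) >= DATE '2026-01-01'
```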
@@ -547,7 +572,7 @@ def assemble_query(self) -> str:
if isinstance(self, IqrOutlierCheck):
filters = {name: cfg for name, cfg in filters.items() if cfg.get("type") != "date"}

if where_statement := self.assemble_where_statement(filters):
if where_statement := self.assemble_where_statement(filters, database_accessor=self.database_accessor):
return main_query + "\n" + where_statement

return main_query
@@ -561,7 +586,7 @@ def assemble_data_exists_query(self) -> str:
"{self.table}"
"""

if where_statement := self.assemble_where_statement(self.filters):
if where_statement := self.assemble_where_statement(self.filters, database_accessor=self.database_accessor):
return f"{data_exists_query}\n{where_statement}"

return data_exists_query
@@ -861,7 +886,7 @@ def assemble_query(self) -> str:
f"CAST({date_col} AS DATE) BETWEEN (DATE '{date_val}' - INTERVAL 14 DAY) AND DATE '{date_val}'"
) # TODO: maybe parameterize interval days

if where_statement := self.assemble_where_statement(self.filters):
if where_statement := self.assemble_where_statement(self.filters, database_accessor=self.database_accessor):
return main_query + "\nAND\n" + where_statement.removeprefix("WHERE\n")

return main_query
@@ -1220,7 +1245,7 @@ def assemble_query(self) -> str:
order = {"max": "DESC", "min": "ASC"}[self.max_or_min]
return f"""
{self.query_boilerplate(self.transformation_statement())}
{self.assemble_where_statement(self.filters)}
{self.assemble_where_statement(self.filters, database_accessor=self.database_accessor)}
GROUP BY {self.in_memory_column}
ORDER BY {self.name} {order}
LIMIT 1 -- only the first entry is needed
@@ -1343,14 +1368,27 @@ def assemble_name(self) -> str:

def assemble_query(self) -> str:
"""Assemble the SQL query for calculating match rate between tables."""
right_column_statement = ",\n ".join(self.join_columns_right)

join_on_statement = "\n AND\n ".join(
[
f"lefty.{left_col} = righty.{right_col.split('.')[-1]}"
for left_col, right_col in zip(self.join_columns_left, self.join_columns_right, strict=False)
],
)
# Transform dotted column names based on data source:
# - With database_accessor: convert to underscores (value.shopId → value_shopId) for flattened data
# - Without database_accessor: keep dots for SELECT (value.shopId), use last part for JOIN (shopId)
if self.database_accessor:
right_column_statement = ",\n ".join([col.replace(".", "_") for col in self.join_columns_right])
join_on_statement = "\n AND\n ".join(
[
f"lefty.{left_col.replace('.', '_')} = righty.{right_col.replace('.', '_')}"
for left_col, right_col in zip(self.join_columns_left, self.join_columns_right, strict=False)
],
)
else:
# For native DuckDB struct columns: SELECT uses dotted notation,
# but DuckDB names the result column as just the last part
right_column_statement = ",\n ".join(self.join_columns_right)
join_on_statement = "\n AND\n ".join(
[
f"lefty.{left_col.split('.')[-1]} = righty.{right_col.split('.')[-1]}"
for left_col, right_col in zip(self.join_columns_left, self.join_columns_right, strict=False)
],
)
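A standalone replica of the two join branches makes the naming difference concrete; the function name and wiring are illustrative, not part of the PR:

```python
def join_on(left_cols: list[str], right_cols: list[str], *, flattened: bool) -> str:
    # flattened=True corresponds to data loaded via a database_accessor.
    if flattened:
        pairs = [
            f"lefty.{left.replace('.', '_')} = righty.{right.replace('.', '_')}"
            for left, right in zip(left_cols, right_cols, strict=False)
        ]
    else:
        # Native DuckDB structs: the result column is named by the last segment.
        pairs = [
            f"lefty.{left.split('.')[-1]} = righty.{right.split('.')[-1]}"
            for left, right in zip(left_cols, right_cols, strict=False)
        ]
    return "\n AND\n ".join(pairs)

print(join_on(["shopId"], ["value.shopId"], flattened=True))   # lefty.shopId = righty.value_shopId
print(join_on(["shopId"], ["value.shopId"], flattened=False))  # lefty.shopId = righty.shopId
```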

return f"""
WITH
@@ -1360,14 +1398,14 @@ def assemble_query(self) -> str:
TRUE AS in_right_table
FROM
"{self.right_table}"
{self.assemble_where_statement(self.filters_right)}
{self.assemble_where_statement(self.filters_right, database_accessor=self.database_accessor)}
),
lefty AS (
SELECT
*
FROM
"{self.left_table}"
{self.assemble_where_statement(self.filters_left)}
{self.assemble_where_statement(self.filters_left, database_accessor=self.database_accessor)}
)

SELECT
@@ -1397,15 +1435,15 @@ def assemble_data_exists_query(self) -> str:
COUNT(*) AS right_counter,
FROM
"{self.right_table}"
{self.assemble_where_statement(self.filters_right)}
{self.assemble_where_statement(self.filters_right, database_accessor=self.database_accessor)}
),

lefty AS (
SELECT
COUNT(*) AS left_counter,
FROM
"{self.left_table}"
{self.assemble_where_statement(self.filters_left)}
{self.assemble_where_statement(self.filters_left, database_accessor=self.database_accessor)}
)

SELECT
@@ -1508,7 +1546,10 @@ def assemble_name(self) -> str:

def assemble_query(self) -> str:
"""Assemble the SQL query for calculating relative count change."""
where_statement = self.assemble_where_statement(self.filters).replace("WHERE", "AND")
where_statement = self.assemble_where_statement(self.filters, database_accessor=self.database_accessor).replace(
"WHERE",
"AND",
)
date_col = self.date_filter["column"]
date_val = self.date_filter["value"]

@@ -1573,7 +1614,7 @@ def assemble_data_exists_query(self) -> str:
date_col = self.date_filter["column"]
date_val = self.date_filter["value"]

where_statement = self.assemble_where_statement(self.filters)
where_statement = self.assemble_where_statement(self.filters, database_accessor=self.database_accessor)
if where_statement:
return f"{data_exists_query}\n{where_statement} AND CAST({date_col} AS DATE) = DATE '{date_val}'"
return f"{data_exists_query}\nWHERE CAST({date_col} AS DATE) = DATE '{date_val}'"
@@ -1693,7 +1734,7 @@ def transformation_statement(self) -> str:
if filters:
filter_columns = ",\n".join([v["column"] for v in filters.values()])
filter_columns = ",\n" + filter_columns
where_statement = self.assemble_where_statement(filters)
where_statement = self.assemble_where_statement(filters, database_accessor=self.database_accessor)
where_statement = "\nAND\n" + where_statement.removeprefix("WHERE\n")
return f"""
WITH
@@ -1770,7 +1811,7 @@ def assemble_data_exists_query(self) -> str:

filters = {k: v for k, v in self.filters.items() if v["type"] != "date"}

where_statement = self.assemble_where_statement(filters)
where_statement = self.assemble_where_statement(filters, database_accessor=self.database_accessor)
if where_statement:
where_statement = f"{where_statement} AND CAST({date_col} AS DATE) = DATE '{date_val}'"
else:
30 changes: 25 additions & 5 deletions src/koality/executor.py
Expand Up @@ -340,7 +340,9 @@ def fetch_data_into_memory(self, data_requirements: defaultdict[str, defaultdict
select_parts.append("*")
continue
if isinstance(col, str) and "." in col:
flat = col.split(".")[-1]
# Replace dots with underscores for deterministic aliasing
# e.g., "value.shopId" becomes "value_shopId"
flat = col.replace(".", "_")
# Make flattened name unique if duplicate arises
base = flat
idx = 1
@@ -364,13 +366,15 @@
select_parts.append(col)
columns = ", ".join(select_parts)
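The aliasing itself can be sketched in isolation; the uniqueness loop's exact form is collapsed above, so this is an assumed reconstruction:

```python
def flatten_alias(col: str, taken: set[str]) -> str:
    flat = col.replace(".", "_")  # "value.shopId" -> "value_shopId"
    base, idx = flat, 1
    while flat in taken:          # disambiguate duplicates deterministically
        flat = f"{base}_{idx}"
        idx += 1
    taken.add(flat)
    return flat

taken: set[str] = set()
assert flatten_alias("value.shopId", taken) == "value_shopId"
assert flatten_alias("value_shopId", taken) == "value_shopId_1"
```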

# Combine all unique filter groups. Treat date-range filters specially and
# combine them with other filters using AND (date range applies to all other filters).
# Combine all unique filter groups. Separate date filters from other filters.
# All date-related conditions (BETWEEN ranges and date equality) should be ORed.
# Non-date filters should be ANDed with the date conditions.
date_filters_sql = set()
other_filters_sql = set()

for filter_group in requirements["filters"]:
filter_dict = {}
date_filter_dict = {}
for item in filter_group:
# Expect each item to be a (name, frozenset(cfg_items)) tuple
if not (isinstance(item, tuple) and len(item) == _DATE_RANGE_TUPLE_SIZE):
@@ -390,8 +394,24 @@
date_filters_sql.add(f"({cond})")
# date_range handled; continue to next group
continue
filter_dict[name] = dict(cfg)
# Separate date-type filters from other filters
if cfg.get("type") == "date":
date_filter_dict[name] = dict(cfg)
else:
filter_dict[name] = dict(cfg)

# Process date filters separately and add to date_filters_sql
if date_filter_dict:
where_clause = DataQualityCheck.assemble_where_statement(
date_filter_dict,
strip_dotted_columns=False,
)
if where_clause.strip().startswith("WHERE"):
conditions = where_clause.strip()[len("WHERE") :].strip()
if conditions:
date_filters_sql.add(f"({conditions})")

# Process non-date filters
if filter_dict:
# When fetching from the source DB, preserve dotted column expressions
# (e.g., "value.shopId") in the WHERE so the source provider sees the
@@ -403,7 +423,7 @@
if conditions:
other_filters_sql.add(f"({conditions})")

# Build final WHERE clause: if we have date filters, AND them with other filters (if any).
# Build final WHERE clause: OR all date filters together, AND with other filters.
final_where_clause = ""
if date_filters_sql and other_filters_sql:
date_part = " OR ".join(sorted(date_filters_sql))
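Concretely, with two date conditions and one other filter, the combined clause looks roughly like this; the final assembly is collapsed above, so the f-string is an assumption:

```python
date_filters_sql = {
    "(CAST(day AS DATE) = DATE '2026-01-01')",
    "(CAST(day AS DATE) = DATE '2026-01-02')",
}
other_filters_sql = {"(shop_id = 42)"}

date_part = " OR ".join(sorted(date_filters_sql))
other_part = " AND ".join(sorted(other_filters_sql))
final_where_clause = f"WHERE ({date_part}) AND {other_part}"
# WHERE ((CAST(day AS DATE) = DATE '2026-01-01')
#     OR (CAST(day AS DATE) = DATE '2026-01-02')) AND (shop_id = 42)
```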
26 changes: 26 additions & 0 deletions src/koality/models.py
@@ -192,6 +192,32 @@ def persist_results(self) -> bool:
class _Check(_LocalDefaults):
"""Base model for all check configurations."""

@model_validator(mode="after")
def validate_filters_have_columns(self) -> Self:
"""Validate that all filters with concrete values have columns specified.

This validation runs after defaults merging, ensuring the final filter
configuration is complete.
"""
for filter_name, filter_config in self.filters.items():
# Skip identifier filters without concrete values (naming-only)
if filter_config.type == "identifier" and (filter_config.value is None or filter_config.value == "*"):
continue

# Skip partial filters with no value
if filter_config.value is None:
continue

# Filter has a value but no column - this is an error
if filter_config.column is None:
msg = (
f"Filter '{filter_name}' has value '{filter_config.value}' "
f"but no column specified. Add 'column: <column_name>' to the filter definition."
)
raise ValueError(msg)

return self
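A minimal pydantic v2 sketch of the same validator, runnable in isolation; the field names `type`, `column`, and `value` mirror the attribute access above, but the class names and defaults here are stand-ins, not the real models:

```python
from typing import Any

from pydantic import BaseModel, ValidationError, model_validator

class FilterConfig(BaseModel):
    type: str = "value"
    column: str | None = None
    value: Any = None

class CheckConfig(BaseModel):
    filters: dict[str, FilterConfig] = {}

    @model_validator(mode="after")
    def validate_filters_have_columns(self) -> "CheckConfig":
        for name, cfg in self.filters.items():
            # Naming-only identifier filters may omit a column.
            if cfg.type == "identifier" and (cfg.value is None or cfg.value == "*"):
                continue
            if cfg.value is None:
                continue
            if cfg.column is None:
                msg = f"Filter '{name}' has value '{cfg.value}' but no column specified."
                raise ValueError(msg)
        return self

CheckConfig(filters={"shop": {"type": "value", "value": 42, "column": "shop_id"}})  # passes
try:
    CheckConfig(filters={"shop": {"type": "value", "value": 42}})
except ValidationError as exc:
    print(exc)  # mentions: Filter 'shop' has value '42' but no column specified.
```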


class _SingleTableCheck(_Check):
"""Base model for checks that operate on a single table."""