Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
89d94a7
add ty
pavelzw Dec 17, 2025
a54715a
introduce ty
pavelzw Dec 17, 2025
ea0ff71
WIP
pavelzw Dec 17, 2025
275f1c3
Address some ty complaints.
kklein Dec 17, 2025
867a959
Merge branch 'main' of github.com:Quantco/datajudge into ty
kklein Dec 17, 2025
816c849
Don't expect children to implement retrieve and compare.
kklein Dec 17, 2025
07e031d
use --no-progress
pavelzw Dec 17, 2025
79ea6a8
Address some more ty complaints.
kklein Dec 18, 2025
2596c26
Fix another ty complaint.
kklein Dec 18, 2025
48b8c86
Add type annotation.
kklein Dec 18, 2025
041ce7f
Remove type ignores.
kklein Dec 18, 2025
b7d158c
Bring back instance caching
kklein Dec 19, 2025
99222d3
Tell ty that the first argument to a method doesn't correspond to self.
kklein Dec 19, 2025
2ccfe58
Remove some mypy hacks
kklein Dec 20, 2025
226cef5
Update lock
kklein Jan 6, 2026
d700e93
Merge branch 'main' of github.com:Quantco/datajudge into ty
kklein Jan 14, 2026
979f81b
Merge branch 'main' of github.com:Quantco/datajudge into ty
kklein Jan 19, 2026
b174f39
Add annotation and fix typo
kklein Jan 19, 2026
fd60cfb
Fix method references
kklein Jan 20, 2026
309ddfa
Fix imports
kklein Jan 20, 2026
b27450a
Rework interfaces of select methods
kklein Jan 20, 2026
a58ac3c
Ensure that method overwrites are legitimate overwrites
kklein Jan 20, 2026
3d2af01
Fix field references
kklein Jan 20, 2026
e043bec
Fix more method references.
kklein Jan 20, 2026
e1b733f
Bring back mypy
kklein Jan 20, 2026
a09f601
Use default environment for mypy
kklein Jan 20, 2026
3750c16
Use updated lock
kklein Jan 20, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ repos:
# mypy
- id: mypy
name: mypy
entry: pixi run -e mypy mypy
entry: pixi run -e default mypy
language: system
types: [python]
require_serial: true
# ty
- id: ty
name: ty
entry: pixi run -e default ty check --no-progress
language: system
types: [python]
require_serial: true
Expand Down
3,520 changes: 2,691 additions & 829 deletions pixi.lock

Large diffs are not rendered by default.

14 changes: 5 additions & 9 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,16 @@ sqlalchemy = "2.*"
[feature.test.dependencies]
pytest-cov = "*"
pytest-xdist = "*"
pytest-html = "*"
ty = "*"
mypy = "*"
types-colorama = "*"
pandas-stubs = "*"

[feature.test.target.unix.dependencies]
pytest-memray = "*"
memray = "*"

[feature.mypy.dependencies]
mypy = "*"
types-setuptools = "*"
types-colorama = "*"
pandas-stubs = "*"
types-jinja2 = "*"

[feature.lint.dependencies]
pre-commit = "*"
docformatter = "*"
Expand Down Expand Up @@ -163,5 +161,3 @@ bigquery-py310 = ["bigquery", "py310", "test"]
bigquery-sa1 = ["bigquery", "sa1", "test"]

lint = { features = ["lint"], no-default-feature = true }

mypy = ["mypy"]
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ known-first-party = ["datajudge"]
quote-style = "double"
indent-style = "space"

[tool.ty.terminal]
error-on-warning = true

[tool.mypy]
python_version = '3.10'
no_implicit_optional = true
allow_empty_bodies = true
check_untyped_defs = true
disable_error_code = ["method-assign"]

[[tool.mypy.overrides]]
module = ["scipy.*", "pytest_html"]
Expand Down
35 changes: 23 additions & 12 deletions src/datajudge/constraints/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

import abc
from collections.abc import Callable, Collection
from collections.abc import Callable, Collection, Sequence
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, TypeVar

import sqlalchemy as sa
from sqlalchemy.sql import selectable

from ..db_access import DataReference
from ..formatter import Formatter
Expand All @@ -16,7 +15,8 @@

_DEFAULT_FORMATTER = Formatter()

_OptionalSelections = list[sa.sql.expression.Select] | None
_Select = selectable.Select | selectable.CompoundSelect
_OptionalSelections = Sequence[_Select] | None
_ToleranceGetter = Callable[[sa.engine.Engine], float]


Expand Down Expand Up @@ -149,17 +149,25 @@ def __init__(
not isinstance(output_processors, list)
):
output_processors = [output_processors]
self._output_processors = output_processors

self._output_processors: list[OutputProcessor] | None = output_processors

self._cache_size = cache_size
self._setup_caching()

def _setup_caching(self):
# this has an added benefit of allowing the class to be garbage collected
# We don't use cache or lru_cache decorators since those would lead
# to class-based, not instance-based caching.
#
# Using this approach has the added benefit of allowing instances of the class to be garbage collected
# according to https://rednafi.com/python/lru_cache_on_methods/
# and https://docs.astral.sh/ruff/rules/cached-instance-method/
self._get_factual_value = lru_cache(self._cache_size)(self._get_factual_value) # type: ignore[method-assign]
self._get_target_value = lru_cache(self._cache_size)(self._get_target_value) # type: ignore[method-assign]
self._get_factual_value: Callable[[Constraint, sa.engine.Engine], Any] = (
lru_cache(self._cache_size)(self._get_factual_value)
)
self._get_target_value: Callable[[Constraint, sa.engine.Engine], Any] = (
lru_cache(self._cache_size)(self._get_target_value)
)

def _check_if_valid_between_or_within(
self,
Expand All @@ -179,13 +187,11 @@ def _check_if_valid_between_or_within(
f"{class_name}. Use exactly either of them."
)

# @lru_cache(maxsize=None), see _setup_caching()
# Retrieve the factual value for ``self._ref`` via ``self._retrieve`` and
# return it. As a side effect, the selections returned alongside the value
# are stored on ``self._factual_selections``.
# NOTE(review): ``_setup_caching`` rebinds this method to a per-instance
# ``lru_cache`` wrapper, so repeated calls with the same engine presumably
# skip the retrieval -- confirm against the enclosing class.
def _get_factual_value(self, engine: sa.engine.Engine) -> Any:
# ``_retrieve`` is unpacked as a (value, selections) pair for the reference.
factual_value, factual_selections = self._retrieve(engine, self._ref)
# Keep the selections on the instance; only the value is returned.
self._factual_selections = factual_selections
return factual_value

# @lru_cache(maxsize=None), see _setup_caching()
def _get_target_value(self, engine: sa.engine.Engine) -> Any:
if self._ref2 is None:
return self._ref_value
Expand Down Expand Up @@ -246,9 +252,14 @@ def _compare(
raise NotImplementedError()

def test(self, engine: sa.engine.Engine) -> TestResult:
value_factual = self._get_factual_value(engine)
value_target = self._get_target_value(engine)
# ty can't figure out that this is a method and that self is passed
# as the first argument.
value_factual = self._get_factual_value(engine=engine) # type: ignore[missing-argument]
# ty can't figure out that this is a method and that self is passed
# as the first argument.
value_target = self._get_target_value(engine=engine) # type: ignore[missing-argument]
is_success, assertion_message = self._compare(value_factual, value_target)

if is_success:
return TestResult.success()

Expand Down
2 changes: 1 addition & 1 deletion src/datajudge/constraints/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _retrieve(
# side effects. This should be removed as soon as snowflake column capitalization
# is fixed by snowflake-sqlalchemy.
if is_snowflake(engine) and self._ref_value is not None:
self._ref_value = lowercase_column_names(self._ref_value) # type: ignore
self._ref_value = lowercase_column_names(self._ref_value)
return db_access.get_column_names(engine, ref)


Expand Down
4 changes: 2 additions & 2 deletions src/datajudge/constraints/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .. import db_access
from ..db_access import DataReference
from .base import Constraint, _OptionalSelections
from .interval import NoGapConstraint, NoOverlapConstraint
from .interval import NoGapConstraint, NoOverlapConstraint, _Selects

_INPUT_DATE_FORMAT = "'%Y-%m-%d'"

Expand Down Expand Up @@ -220,7 +220,7 @@ def _compare(
class DateNoGap(NoGapConstraint):
_DIMENSIONS = 1

def select(self, engine: sa.engine.Engine, ref: DataReference):
def _select(self, engine: sa.engine.Engine, ref: DataReference) -> _Selects:
sample_selection, n_violations_selection = db_access.get_date_gaps(
engine,
ref,
Expand Down
23 changes: 11 additions & 12 deletions src/datajudge/constraints/interval.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

import abc
from typing import Any
from typing import Any, TypeAlias

import sqlalchemy as sa

from .. import db_access
from ..db_access import DataReference
from .base import Constraint, _OptionalSelections
from .base import Constraint, _OptionalSelections, _Select

# Both sa.Select and sa.CompoundSelect inherit from sa.GenerativeSelect.
_Selects: TypeAlias = tuple[_Select, _Select]


class IntervalConstraint(Constraint):
Expand All @@ -31,8 +34,7 @@ def __init__(
self._validate_dimensions()

@abc.abstractmethod
def select(self, engine: sa.engine.Engine, ref: DataReference):
pass
def _select(self, engine: sa.engine.Engine, ref: DataReference) -> _Selects: ...

def _validate_dimensions(self):
if (length := len(self._start_columns)) != self._DIMENSIONS:
Expand All @@ -56,7 +58,7 @@ def _retrieve(
engine, keys_ref
)

sample_selection, n_violations_selection = self.select(engine, ref)
sample_selection, n_violations_selection = self._select(engine, ref)
with engine.connect() as connection:
self.sample = connection.execute(sample_selection).first()
n_violation_keys = int(
Expand Down Expand Up @@ -90,7 +92,7 @@ def __init__(
cache_size=cache_size,
)

def select(self, engine: sa.engine.Engine, ref: DataReference):
def _select(self, engine: sa.engine.Engine, ref: DataReference) -> _Selects:
sample_selection, n_violations_selection = db_access.get_interval_overlaps_nd(
engine,
ref,
Expand All @@ -106,8 +108,7 @@ def select(self, engine: sa.engine.Engine, ref: DataReference):
@abc.abstractmethod
def _compare(
self, value_factual: Any, value_target: Any
) -> tuple[bool, str | None]:
pass
) -> tuple[bool, str | None]: ...


class NoGapConstraint(IntervalConstraint):
Expand All @@ -134,11 +135,9 @@ def __init__(
)

@abc.abstractmethod
def select(self, engine: sa.engine.Engine, ref: DataReference):
pass
def _select(self, engine: sa.engine.Engine, ref: DataReference) -> _Selects: ...

@abc.abstractmethod
def _compare(
self, value_factual: tuple[int, int], value_target: Any
) -> tuple[bool, str | None]:
pass
) -> tuple[bool, str | None]: ...
26 changes: 12 additions & 14 deletions src/datajudge/constraints/miscs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,19 @@ def _retrieve(

# Note: Exact equality!
def _compare(
self, primary_keys_factual: set[str], primary_keys_target: set[str]
self, value_factual: set[str], value_target: set[str]
) -> tuple[bool, str | None]:
assertion_message = ""
result = True
# If both are true, just report one.
if len(primary_keys_factual.difference(primary_keys_target)) > 0:
example_key = next(
iter(primary_keys_factual.difference(primary_keys_target))
)
if len(value_factual.difference(value_target)) > 0:
example_key = next(iter(value_factual.difference(value_target)))
assertion_message = (
f"{self._ref} incorrectly includes {example_key} as primary key."
)
result = False
if len(primary_keys_target.difference(primary_keys_factual)) > 0:
example_key = next(
iter(primary_keys_target.difference(primary_keys_factual))
)
if len(value_target.difference(value_factual)) > 0:
example_key = next(iter(value_target.difference(value_factual)))
assertion_message = (
f"{self._ref} doesn't include {example_key} as primary key."
)
Expand Down Expand Up @@ -101,7 +97,9 @@ def test(self, engine: sa.engine.Engine) -> TestResult:
self.target_selections = unique_selections
if row_count == 0:
return TestResult(True, "No occurrences.")
tolerance_kind, tolerance_value = self._ref_value # type: ignore

tolerance_kind, tolerance_value = self._ref_value

if tolerance_kind == "relative":
result = unique_count >= row_count * (1 - tolerance_value)
elif tolerance_kind == "absolute":
Expand Down Expand Up @@ -182,12 +180,12 @@ def _retrieve(self, engine: sa.engine.Engine, ref: DataReference):
return db_access.get_missing_fraction(engine=engine, ref=ref)

def _compare(
self, missing_fraction_factual: float, missing_fracion_target: float
self, value_factual: float, value_target: float
) -> tuple[bool, str | None]:
threshold = missing_fracion_target * (1 + self.max_relative_deviation)
result = missing_fraction_factual <= threshold
threshold = value_target * (1 + self.max_relative_deviation)
result = value_factual <= threshold
assertion_text = (
f"{missing_fraction_factual} of {self._ref} values are NULL "
f"{value_factual} of {self._ref} values are NULL "
f"while only {self._target_prefix}{threshold} were allowed to be NULL."
)
return result, assertion_text
Loading
Loading