-
Notifications
You must be signed in to change notification settings - Fork 63
fix(constraints): consider numeric/datetime extreme value clipping #700
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,8 @@ | |
|
|
||
| import hashlib | ||
| import logging | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import pandas as pd | ||
|
|
||
|
|
@@ -41,12 +43,13 @@ class InequalityHandler(ConstraintHandler): | |
|
|
||
| _DATETIME_EPOCH = pd.Timestamp("1970-01-01") # reference epoch for delta representation | ||
|
|
||
| def __init__(self, constraint: Inequality, table=None): | ||
| def __init__(self, constraint: Inequality, table=None, workspace_dir: Path | None = None): | ||
| self.constraint = constraint | ||
| self.table_name = constraint.table_name | ||
| self.low_column = constraint.low_column | ||
| self.high_column = constraint.high_column | ||
| self._delta_column = _generate_internal_column_name("INEQ_DELTA", [self.low_column, self.high_column]) | ||
| self.workspace_dir = workspace_dir | ||
|
|
||
| # determine if this is a datetime constraint based on table encoding types | ||
| self._is_datetime = False | ||
|
|
@@ -134,12 +137,63 @@ def to_original(self, df: pd.DataFrame) -> pd.DataFrame: | |
| violations = both_valid_for_check & (high < low) | ||
| df.loc[violations, self.high_column] = low[violations] | ||
|
|
||
| # clip to training data bounds | ||
| if self.workspace_dir is not None: | ||
| self._clip_to_training_bounds(df) | ||
|
|
||
| # convert back to original dtype | ||
| if pd.api.types.is_integer_dtype(high_dtype): | ||
| df[self.high_column] = df[self.high_column].astype(high_dtype) | ||
|
|
||
| return df.drop(columns=[self._delta_column]) | ||
|
|
||
| def _extract_min_max_from_stats(self, col_stats: dict) -> tuple[Any, Any]: | ||
| """extract min/max from column stats, handling all encoding types (same pattern as parse_min_max).""" | ||
| # try bins/min5/max5 arrays first (for binned/digit/datetime encoding) | ||
| values = col_stats.get("bins", []) + col_stats.get("min5", []) + col_stats.get("max5", []) | ||
| if values: | ||
| return min(values), max(values) | ||
| # fall back to direct min/max (for other encoding types or when arrays are empty) | ||
| return col_stats.get("min"), col_stats.get("max") | ||
|
|
||
| def _clip_to_training_bounds(self, df: pd.DataFrame) -> None: | ||
| """clip high column values to min/max from training data stats.""" | ||
| from mostlyai.engine._workspace import Workspace | ||
|
|
||
| workspace = Workspace(self.workspace_dir) | ||
| tgt_stats = workspace.tgt_stats.read() | ||
| if not tgt_stats or self.high_column not in tgt_stats.get("columns", {}): | ||
| return | ||
|
|
||
| col_stats = tgt_stats["columns"][self.high_column] | ||
| min_val, max_val = self._extract_min_max_from_stats(col_stats) | ||
| if min_val is None and max_val is None: | ||
| return | ||
|
|
||
| high = df[self.high_column] | ||
| low = df[self.low_column] | ||
| if self._is_datetime: | ||
| if min_val is not None: | ||
| min_val = pd.to_datetime(min_val) | ||
| df.loc[high.notna() & (high < min_val), self.high_column] = min_val | ||
| if max_val is not None: | ||
| max_val = pd.to_datetime(max_val) | ||
| df.loc[high.notna() & (high > max_val), self.high_column] = max_val | ||
| else: | ||
| if min_val is not None: | ||
| min_val = float(min_val) | ||
| df.loc[high.notna() & (high < min_val), self.high_column] = min_val | ||
| if max_val is not None: | ||
| max_val = float(max_val) | ||
| df.loc[high.notna() & (high > max_val), self.high_column] = max_val | ||
|
|
||
| # ensure constraint is still satisfied after clipping | ||
| # if clipping made high < low, set high = low | ||
| high = df[self.high_column] | ||
| both_valid = low.notna() & high.notna() | ||
| violations = both_valid & (high < low) | ||
| df.loc[violations, self.high_column] = low[violations] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clipping can be undone by constraint correctionMedium Severity The |
||
|
|
||
| def get_encoding_types(self) -> dict[str, str]: | ||
| # use TABULAR_DATETIME for datetime constraints to preserve precision | ||
| # use TABULAR_NUMERIC_AUTO for numeric constraints | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.