diff --git a/mostlyai/engine/_tabular/generation.py b/mostlyai/engine/_tabular/generation.py index e36426bb..c8ec78f6 100644 --- a/mostlyai/engine/_tabular/generation.py +++ b/mostlyai/engine/_tabular/generation.py @@ -820,6 +820,7 @@ def generate( if not enable_flexible_generation: check_column_order(gen_column_order, trn_column_order) + _LOG.info(f"{rare_category_replacement_method=}") rare_token_fixed_probs = fix_rare_token_probs(tgt_stats, rare_category_replacement_method) imputation_fixed_probs = _fix_imputation_probs(tgt_stats, imputation) diff --git a/mostlyai/engine/_tabular/probability.py b/mostlyai/engine/_tabular/probability.py index a97a9f3e..43c59a5e 100644 --- a/mostlyai/engine/_tabular/probability.py +++ b/mostlyai/engine/_tabular/probability.py @@ -363,15 +363,14 @@ def predict_proba( ) ) - # Get seed column names (needed for column order check and _generate_marginal_probs) seed_columns = list(seed_data.columns) - # Check column order when flexible generation is disabled if not enable_flexible_generation: seed_columns_argn = get_argn_column_names(tgt_stats["columns"], seed_columns) target_columns_argn = get_argn_column_names(tgt_stats["columns"], target_columns) - gen_column_order = seed_columns_argn + target_columns_argn - check_column_order(gen_column_order, all_columns) + columns_to_check = seed_columns_argn + target_columns_argn + expected_order = [col for col in all_columns if col in columns_to_check] + check_column_order(columns_to_check, expected_order) # Encode seed data (features to condition on) - common for both single and multi-target # seed_data should NOT include any target columns diff --git a/tests/end_to_end/test_tabular_interface.py b/tests/end_to_end/test_tabular_interface.py index 601049db..be6d9ee4 100644 --- a/tests/end_to_end/test_tabular_interface.py +++ b/tests/end_to_end/test_tabular_interface.py @@ -347,11 +347,15 @@ def test_predict_proba_multi_target( # Numeric binned values (may be bin labels or ranges) assert len(col_values) >= 3 # At least some bins present - def test_predict_proba_wrong_column_order_raises(self, classification_data, tmp_path_factory): - """Test predict_proba raises error with different column order when flexible generation is disabled.""" - data = classification_data - X = data[["feature1", "feature2"]] - y = data["target"] + def test_wrong_column_order_raises(self, tmp_path_factory): + """Test that wrong column order raises error when flexible generation is disabled.""" + data = pd.DataFrame( + { + "col_a": ["x", "y", "z"] * 20, + "col_b": ["p", "q", "r"] * 20, + "col_c": ["1", "2", "3"] * 20, + } + ) argn = TabularARGN( model="MOSTLY_AI/Small", @@ -360,13 +364,33 @@ def test_predict_proba_wrong_column_order_raises(self, classification_data, tmp_ enable_flexible_generation=False, workspace_dir=tmp_path_factory.mktemp("workspace"), ) - argn.fit(X=X, y=y) + argn.fit(X=data) - # Reorder columns in test data - test_X = X.head(10)[["feature2", "feature1"]] + # Wrong seed order for sample + X_wrong_seed = data.head(5)[["col_b", "col_a"]] # wrong: should be col_a, col_b + with pytest.raises(ValueError, match="(?i)column order.*does not match"): + argn.sample(n_samples=5, seed_data=X_wrong_seed) + + # Wrong seed order for predict_proba + with pytest.raises(ValueError, match="(?i)column order.*does not match"): + argn.predict_proba(X_wrong_seed, target="col_c") + # Wrong seed order for predict with pytest.raises(ValueError, match="(?i)column order.*does not match"): - argn.predict_proba(test_X, target="target") + argn.predict(X_wrong_seed, target="col_c") + + # Wrong target order for predict_proba (computes joint probabilities in order) + X_seed = data.head(5)[["col_a"]] + with pytest.raises(ValueError, match="(?i)column order.*does not match"): + argn.predict_proba(X_seed, target=["col_c", "col_b"]) # wrong: should be col_b, col_c + + # predict() doesn't require target order - it generates all columns and extracts targets + result = argn.predict(X_seed, target=["col_c", "col_b"]) + assert list(result.columns) == ["col_c", "col_b"] + + # predict() works even with targets completely out of original order + result = argn.predict(X_seed, target=["col_c", "col_b", "col_a"]) + assert list(result.columns) == ["col_c", "col_b", "col_a"] class TestTabularARGNRegression: