From 805208b5a5e499e4fdfb2302756709bdbd5463de Mon Sep 17 00:00:00 2001 From: Colton Date: Fri, 6 Feb 2026 11:59:03 -0500 Subject: [PATCH 1/3] Have fact loader use bool helper function --- pyreason/pyreason.py | 39 +++++++++------------------------------ 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py index bfc421df..3e36e1ac 100755 --- a/pyreason/pyreason.py +++ b/pyreason/pyreason.py @@ -460,9 +460,9 @@ def fp_version(self, value: bool) -> None: __rules: Optional[numba.typed.List] = None __clause_maps: Optional[dict] = None __node_facts: Optional[numba.typed.List] = None -__node_facts_name_set = set() # We want to warn the user if they add multiple facts with the same name -__rules_name_set = set() # We want to warn the user if they add multiple rules with the same name __edge_facts: Optional[numba.typed.List] = None +__facts_name_set = set() # We want to warn the user if they add multiple facts with the same name +__rules_name_set = set() # We want to warn the user if they add multiple rules with the same name __ipl: Optional[numba.typed.List] = None __specific_node_labels: Optional[numba.typed.List] = None __specific_edge_labels: Optional[numba.typed.List] = None @@ -486,12 +486,12 @@ def reset(): """Resets certain variables to None to be able to do pr.reason() multiple times in a program without memory blowing up """ - global __node_facts, __edge_facts, __graph, __node_facts_name_set + global __node_facts, __edge_facts, __graph, __facts_name_set # Facts __node_facts = None __edge_facts = None - __node_facts_name_set.clear() + __facts_name_set.clear() if __program is not None: __program.reset_facts() @@ -1032,28 +1032,7 @@ def _parse_and_validate_fact_params(idx, name_raw, start_time_raw, end_time_raw, end_time = 0 # Parse static as boolean - static = False - if static_raw is not None: - if isinstance(static_raw, bool): - static = static_raw - elif isinstance(static_raw, str): - static_str = 
static_raw.strip().lower() - if static_str in ('true', '1', 'yes', 't', 'y'): - static = True - elif static_str in ('false', '0', 'no', 'f', 'n', ''): - static = False - else: - if raise_errors: - raise ValueError(f"{item_label} {idx}: Invalid static value '{static_raw}'") - warnings.warn(f"{item_label} {idx}: Invalid static value '{static_raw}', using default value") - static = False - elif isinstance(static_raw, (int, float)): - static = bool(static_raw) - else: - if raise_errors: - raise ValueError(f"{item_label} {idx}: Invalid static value type '{type(static_raw).__name__}'") - warnings.warn(f"{item_label} {idx}: Invalid static value type '{type(static_raw).__name__}', using default value") - static = False + static = _parse_bool_param(static_raw, 'static', idx, raise_errors, item_label, default=False) return name, start_time, end_time, static @@ -1075,21 +1054,21 @@ def add_fact(pyreason_fact: Fact) -> None: if pyreason_fact.name is None: pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}' - if pyreason_fact.name in __node_facts_name_set: + if pyreason_fact.name in __facts_name_set: warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.") f = fact_node.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static) - __node_facts_name_set.add(pyreason_fact.name) + __facts_name_set.add(pyreason_fact.name) __node_facts.append(f) else: if pyreason_fact.name is None: pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}' - if pyreason_fact.name in __node_facts_name_set: + if pyreason_fact.name in __facts_name_set: warnings.warn(f"Fact {pyreason_fact.name} has already been added. 
Duplicate fact names will lead to an ambiguous node and atom traces.") f = fact_edge.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static) - __node_facts_name_set.add(pyreason_fact.name) + __facts_name_set.add(pyreason_fact.name) __edge_facts.append(f) From 6359b4210660817097507f0ef4e3a146becb6729 Mon Sep 17 00:00:00 2001 From: Colton Date: Sat, 7 Feb 2026 13:25:12 -0500 Subject: [PATCH 2/3] Rule loader must support thresholds as a list or a dict --- pyreason/pyreason.py | 76 +++++- tests/api_tests/test_pyreason_file_loading.py | 245 +++++++++--------- 2 files changed, 184 insertions(+), 137 deletions(-) diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py index 3e36e1ac..21e4e887 100755 --- a/pyreason/pyreason.py +++ b/pyreason/pyreason.py @@ -840,8 +840,9 @@ def add_rule_from_json(json_path: str, raise_errors: bool = True) -> None: - **name** (optional): The name of the rule. This will appear in the rule trace. - **infer_edges** (optional): Whether to infer new edges after edge rule fires (default: false). - **set_static** (optional): Whether to set the atom in the head as static if the rule fires (default: false). - - **custom_thresholds** (optional): A list of custom thresholds for the rule, or a mapping of - clause index to threshold. If not specified, the default thresholds for ANY are used. + - **custom_thresholds** (optional): A list of threshold objects (one per clause), or a dict + mapping clause index to threshold object (unspecified clauses get defaults). Each threshold + object has ``quantifier``, ``quantifier_type``, and ``thresh`` fields. - **weights** (optional): A list of weights for the rule clauses. This is passed to an annotation function. 
Example JSON format:: @@ -855,9 +856,19 @@ def add_rule_from_json(json_path: str, raise_errors: bool = True) -> None: }, { "rule_text": "ally(A, B) <- friend(A, B), common_interest(A, B)", - "name": "ally-rule", - "custom_thresholds": [0.5, 0.8], + "name": "ally-rule-list", + "custom_thresholds": [ + {"quantifier": "greater_equal", "quantifier_type": ["number", "total"], "thresh": 1}, + {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 100} + ], "weights": [1.0, 2.0] + }, + { + "rule_text": "ally(A, B) <- friend(A, B), common_interest(A, B)", + "name": "ally-rule-dict", + "custom_thresholds": { + "0": {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 50} + } } ] @@ -925,12 +936,9 @@ def add_rule_from_json(json_path: str, raise_errors: bool = True) -> None: # Extract advanced params (JSON-only) custom_thresholds_raw = rule_obj.get('custom_thresholds') custom_thresholds = None + found_threshold_error = False if custom_thresholds_raw is not None: - if not isinstance(custom_thresholds_raw, list): - if raise_errors: - raise ValueError(f"Item {idx}: 'custom_thresholds' must be a list of threshold objects") - warnings.warn(f"Item {idx}: 'custom_thresholds' must be a list of threshold objects, ignoring") - else: + if isinstance(custom_thresholds_raw, list): custom_thresholds = [] for t_idx, t_obj in enumerate(custom_thresholds_raw): if isinstance(t_obj, dict): @@ -943,15 +951,53 @@ def add_rule_from_json(json_path: str, raise_errors: bool = True) -> None: except (KeyError, ValueError, TypeError) as te: if raise_errors: raise ValueError(f"Item {idx}, threshold {t_idx}: Invalid threshold - {te}") - warnings.warn(f"Item {idx}, threshold {t_idx}: Invalid threshold - {te}, ignoring") - custom_thresholds = None + warnings.warn(f"Item {idx}, threshold {t_idx}: Invalid threshold - {te}, skipping rule") + found_threshold_error = True break else: if raise_errors: raise ValueError(f"Item {idx}, threshold {t_idx}: 
Expected object, got {type(t_obj).__name__}") - warnings.warn(f"Item {idx}, threshold {t_idx}: Expected object, got {type(t_obj).__name__}, ignoring all thresholds") - custom_thresholds = None + warnings.warn(f"Item {idx}, threshold {t_idx}: Expected object, got {type(t_obj).__name__}, skipping rule") + found_threshold_error = True + break + elif isinstance(custom_thresholds_raw, dict): + custom_thresholds = {} + for key_str, t_obj in custom_thresholds_raw.items(): + try: + clause_idx = int(key_str) + except (ValueError, TypeError): + if raise_errors: + raise ValueError(f"Item {idx}: custom_thresholds dict key '{key_str}' must be an integer clause index") + warnings.warn(f"Item {idx}: custom_thresholds dict key '{key_str}' must be an integer clause index, skipping rule") + found_threshold_error = True + break + if isinstance(t_obj, dict): + try: + custom_thresholds[clause_idx] = Threshold( + t_obj['quantifier'], + tuple(t_obj['quantifier_type']), + t_obj['thresh'] + ) + except (KeyError, ValueError, TypeError) as te: + if raise_errors: + raise ValueError(f"Item {idx}, threshold key '{key_str}': Invalid threshold - {te}") + warnings.warn(f"Item {idx}, threshold key '{key_str}': Invalid threshold - {te}, skipping rule") + found_threshold_error = True + break + else: + if raise_errors: + raise ValueError(f"Item {idx}, threshold key '{key_str}': Expected object, got {type(t_obj).__name__}") + warnings.warn(f"Item {idx}, threshold key '{key_str}': Expected object, got {type(t_obj).__name__}, skipping rule") + found_threshold_error = True break + else: + if raise_errors: + raise ValueError(f"Item {idx}: 'custom_thresholds' must be a list or dict of threshold objects") + warnings.warn(f"Item {idx}: 'custom_thresholds' must be a list or dict of threshold objects, skipping rule") + found_threshold_error = True + if found_threshold_error: + error_count += 1 + continue weights_raw = rule_obj.get('weights') weights = None @@ -959,7 +1005,9 @@ def add_rule_from_json(json_path: 
str, raise_errors: bool = True) -> None: if not isinstance(weights_raw, list): if raise_errors: raise ValueError(f"Item {idx}: 'weights' must be a list of numeric values") - warnings.warn(f"Item {idx}: 'weights' must be a list of numeric values, ignoring") + warnings.warn(f"Item {idx}: 'weights' must be a list of numeric values, skipping rule") + error_count += 1 + continue else: weights = weights_raw diff --git a/tests/api_tests/test_pyreason_file_loading.py b/tests/api_tests/test_pyreason_file_loading.py index ccc14281..392cf97d 100644 --- a/tests/api_tests/test_pyreason_file_loading.py +++ b/tests/api_tests/test_pyreason_file_loading.py @@ -677,7 +677,129 @@ def test_partial_failure_recovery(self): graph2.add_edge('C', 'D') pr.load_graph(graph2) +class TestRuleTrace: + """Test save_rule_trace() and get_rule_trace() functions.""" + + def setup_method(self): + """Clean state before each test.""" + pr.reset() + pr.reset_settings() + + def test_save_rule_trace_with_store_interpretation_changes_disabled(self): + """Test save_rule_trace() with store_interpretation_changes disabled.""" + pr.settings.store_interpretation_changes = False + + # Create a simple interpretation (empty for this test) + interpretation = {} + + with pytest.raises(AssertionError, match='store interpretation changes setting is off'): + pr.save_rule_trace(interpretation) + + def test_get_rule_trace_with_store_interpretation_changes_disabled(self): + """Test get_rule_trace() with store_interpretation_changes disabled.""" + pr.settings.store_interpretation_changes = False + + # Create a simple interpretation (empty for this test) + interpretation = {} + + with pytest.raises(AssertionError, match='store interpretation changes setting is off'): + pr.get_rule_trace(interpretation) + + def test_save_rule_trace_with_store_interpretation_changes_enabled(self): + """Test save_rule_trace() with store_interpretation_changes enabled.""" + pr.settings.store_interpretation_changes = True + + # Create a simple 
graph and run reasoning to get an interpretation + graph = nx.DiGraph() + graph.add_edge('A', 'B') + pr.load_graph(graph) + + # Add a simple fact and rule + pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) + pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) + + # Run reasoning to get interpretation + interpretation = pr.reason(1) + + # Test save_rule_trace with default folder + with tempfile.TemporaryDirectory() as temp_dir: + pr.save_rule_trace(interpretation, temp_dir) + # Check that files were created (exact files depend on implementation) + files_created = os.listdir(temp_dir) + assert len(files_created) > 0, "Expected files to be created in the trace folder" + + def test_save_rule_trace_with_custom_folder(self): + """Test save_rule_trace() with custom folder path.""" + pr.settings.store_interpretation_changes = True + + # Create a simple graph and run reasoning + graph = nx.DiGraph() + graph.add_edge('A', 'B') + pr.load_graph(graph) + + pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) + pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) + interpretation = pr.reason(1) + + # Test with custom folder + with tempfile.TemporaryDirectory() as temp_dir: + custom_folder = os.path.join(temp_dir, 'custom_trace') + os.makedirs(custom_folder, exist_ok=True) + + pr.save_rule_trace(interpretation, custom_folder) + files_created = os.listdir(custom_folder) + assert len(files_created) > 0, "Expected files to be created in the custom trace folder" + + def test_get_rule_trace_with_store_interpretation_changes_enabled(self): + """Test get_rule_trace() with store_interpretation_changes enabled.""" + pr.settings.store_interpretation_changes = True + + # Create a simple graph and run reasoning + graph = nx.DiGraph() + graph.add_edge('A', 'B') + pr.load_graph(graph) + + pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) + pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) + + interpretation = pr.reason(1) + + # Test get_rule_trace + node_trace, 
edge_trace = pr.get_rule_trace(interpretation) + + # Verify return types are DataFrames + assert isinstance(node_trace, pd.DataFrame), "Expected node_trace to be a pandas DataFrame" + assert isinstance(edge_trace, pd.DataFrame), "Expected edge_trace to be a pandas DataFrame" + + def test_get_rule_trace_returns_dataframes(self): + """Test that get_rule_trace() returns proper DataFrame structures.""" + pr.settings.store_interpretation_changes = True + + # Create a more complex scenario + graph = nx.DiGraph() + graph.add_edges_from([('A', 'B'), ('B', 'C'), ('C', 'A')]) + pr.load_graph(graph) + + # Add multiple facts and rules + pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) + pr.add_fact(pr.Fact('person(B)', 'B', 1, 1)) + pr.add_rule(Rule('friend(A, B) <- person(A)', 'rule1', False)) + pr.add_rule(Rule('likes(A, B) <- friend(A, B)', 'rule2', False)) + + interpretation = pr.reason(2) + + node_trace, edge_trace = pr.get_rule_trace(interpretation) + + # Basic structure verification + assert isinstance(node_trace, pd.DataFrame) + assert isinstance(edge_trace, pd.DataFrame) + + # DataFrames should have some basic expected structure + # (exact columns depend on implementation, but they should be valid DataFrames) + assert hasattr(node_trace, 'columns') + assert hasattr(edge_trace, 'columns') + class TestAddFactFromJSON: """Test add_fact_from_json() function for loading facts from JSON.""" @@ -987,129 +1109,6 @@ def test_add_fact_from_csv_empty_optional_fields(self): os.unlink(tmp_path) -class TestRuleTrace: - """Test save_rule_trace() and get_rule_trace() functions.""" - - def setup_method(self): - """Clean state before each test.""" - pr.reset() - pr.reset_settings() - - def test_save_rule_trace_with_store_interpretation_changes_disabled(self): - """Test save_rule_trace() with store_interpretation_changes disabled.""" - pr.settings.store_interpretation_changes = False - - # Create a simple interpretation (empty for this test) - interpretation = {} - - with 
pytest.raises(AssertionError, match='store interpretation changes setting is off'): - pr.save_rule_trace(interpretation) - - def test_get_rule_trace_with_store_interpretation_changes_disabled(self): - """Test get_rule_trace() with store_interpretation_changes disabled.""" - pr.settings.store_interpretation_changes = False - - # Create a simple interpretation (empty for this test) - interpretation = {} - - with pytest.raises(AssertionError, match='store interpretation changes setting is off'): - pr.get_rule_trace(interpretation) - - def test_save_rule_trace_with_store_interpretation_changes_enabled(self): - """Test save_rule_trace() with store_interpretation_changes enabled.""" - pr.settings.store_interpretation_changes = True - - # Create a simple graph and run reasoning to get an interpretation - graph = nx.DiGraph() - graph.add_edge('A', 'B') - pr.load_graph(graph) - - # Add a simple fact and rule - pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) - pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) - - # Run reasoning to get interpretation - interpretation = pr.reason(1) - - # Test save_rule_trace with default folder - with tempfile.TemporaryDirectory() as temp_dir: - pr.save_rule_trace(interpretation, temp_dir) - # Check that files were created (exact files depend on implementation) - files_created = os.listdir(temp_dir) - assert len(files_created) > 0, "Expected files to be created in the trace folder" - - def test_save_rule_trace_with_custom_folder(self): - """Test save_rule_trace() with custom folder path.""" - pr.settings.store_interpretation_changes = True - - # Create a simple graph and run reasoning - graph = nx.DiGraph() - graph.add_edge('A', 'B') - pr.load_graph(graph) - - pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) - pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) - - interpretation = pr.reason(1) - - # Test with custom folder - with tempfile.TemporaryDirectory() as temp_dir: - custom_folder = os.path.join(temp_dir, 
'custom_trace') - os.makedirs(custom_folder, exist_ok=True) - - pr.save_rule_trace(interpretation, custom_folder) - files_created = os.listdir(custom_folder) - assert len(files_created) > 0, "Expected files to be created in the custom trace folder" - - def test_get_rule_trace_with_store_interpretation_changes_enabled(self): - """Test get_rule_trace() with store_interpretation_changes enabled.""" - pr.settings.store_interpretation_changes = True - - # Create a simple graph and run reasoning - graph = nx.DiGraph() - graph.add_edge('A', 'B') - pr.load_graph(graph) - - pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) - pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) - - interpretation = pr.reason(1) - - # Test get_rule_trace - node_trace, edge_trace = pr.get_rule_trace(interpretation) - - # Verify return types are DataFrames - assert isinstance(node_trace, pd.DataFrame), "Expected node_trace to be a pandas DataFrame" - assert isinstance(edge_trace, pd.DataFrame), "Expected edge_trace to be a pandas DataFrame" - - def test_get_rule_trace_returns_dataframes(self): - """Test that get_rule_trace() returns proper DataFrame structures.""" - pr.settings.store_interpretation_changes = True - - # Create a more complex scenario - graph = nx.DiGraph() - graph.add_edges_from([('A', 'B'), ('B', 'C'), ('C', 'A')]) - pr.load_graph(graph) - - # Add multiple facts and rules - pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) - pr.add_fact(pr.Fact('person(B)', 'B', 1, 1)) - pr.add_rule(Rule('friend(A, B) <- person(A)', 'rule1', False)) - pr.add_rule(Rule('likes(A, B) <- friend(A, B)', 'rule2', False)) - - interpretation = pr.reason(2) - - node_trace, edge_trace = pr.get_rule_trace(interpretation) - - # Basic structure verification - assert isinstance(node_trace, pd.DataFrame) - assert isinstance(edge_trace, pd.DataFrame) - - # DataFrames should have some basic expected structure - # (exact columns depend on implementation, but they should be valid DataFrames) - assert 
hasattr(node_trace, 'columns') - assert hasattr(edge_trace, 'columns') - class TestAddRulesFromFile: """Test add_rules_from_file() function.""" From ac3bd10f3d04367218613889d17fb058c6d12729 Mon Sep 17 00:00:00 2001 From: Colton Date: Wed, 11 Feb 2026 14:06:55 -0500 Subject: [PATCH 3/3] Update function name from variable to component --- pyreason/scripts/utils/rule_parser.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyreason/scripts/utils/rule_parser.py b/pyreason/scripts/utils/rule_parser.py index f298b5f5..7f7d1874 100644 --- a/pyreason/scripts/utils/rule_parser.py +++ b/pyreason/scripts/utils/rule_parser.py @@ -102,7 +102,7 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, d # Validate head variable names for var in head_variables: - _validate_variable_name(var, "Head") + _validate_component_name(var, "Head") # Assign type of rule rule_type = 'node' if len(head_variables) == 1 else 'edge' @@ -132,7 +132,7 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, d variables += clause_str[start_idx+1:end_idx].split(',') # Validate body variable names for var in variables: - _validate_variable_name(var, "Body") + _validate_component_name(var, "Body") body_variables.append(variables) # Change infer edge parameter if it's a node rule @@ -489,12 +489,12 @@ def _validate_predicate_name(pred, context): raise ValueError(f"{context} predicate name '{pred}' contains invalid characters. Must match [a-zA-Z_][a-zA-Z0-9_]*") -def _validate_variable_name(var, context): +def _validate_component_name(var, context): """Validate that a variable name matches ^[a-zA-Z_][a-zA-Z0-9_]*$.""" if not _IDENTIFIER_RE.match(var): if var and var[0].isdigit(): - raise ValueError(f"{context} variable name '{var}' cannot start with a digit") - raise ValueError(f"{context} variable name '{var}' contains invalid characters. 
Must match [a-zA-Z_][a-zA-Z0-9_]*") + raise ValueError(f"{context} component name '{var}' cannot start with a digit") + raise ValueError(f"{context} component name '{var}' contains invalid characters. Must match [a-zA-Z_][a-zA-Z0-9_]*") def _str_bound_to_bound(str_bound):