diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py index ee78df92..4945b96a 100755 --- a/pyreason/pyreason.py +++ b/pyreason/pyreason.py @@ -460,8 +460,9 @@ def fp_version(self, value: bool) -> None: __rules: Optional[numba.typed.List] = None __clause_maps: Optional[dict] = None __node_facts: Optional[numba.typed.List] = None -__node_facts_name_set = set() # We want to warn the user if they add multiple facts with the same name __edge_facts: Optional[numba.typed.List] = None +__facts_name_set = set() # We want to warn the user if they add multiple facts with the same name +__rules_name_set = set() # We want to warn the user if they add multiple rules with the same name __ipl: Optional[numba.typed.List] = None __specific_node_labels: Optional[numba.typed.List] = None __specific_edge_labels: Optional[numba.typed.List] = None @@ -485,12 +486,12 @@ def reset(): """Resets certain variables to None to be able to do pr.reason() multiple times in a program without memory blowing up """ - global __node_facts, __edge_facts, __graph, __node_facts_name_set + global __node_facts, __edge_facts, __graph, __facts_name_set # Facts __node_facts = None __edge_facts = None - __node_facts_name_set.clear() + __facts_name_set.clear() if __program is not None: __program.reset_facts() @@ -516,6 +517,7 @@ def reset_rules(): """ global __rules, __annotation_functions, __head_functions __rules = None + __rules_name_set.clear() __annotation_functions = [] __head_functions = [] if __program is not None: @@ -606,24 +608,439 @@ def add_rule(pr_rule: Rule) -> None: if pr_rule.rule.get_rule_name() is None: pr_rule.rule.set_rule_name(f'rule_{len(__rules)}') + if pr_rule.rule.get_rule_name() in __rules_name_set: + warnings.warn(f"Rule {pr_rule.rule.get_rule_name()} has already been added. Duplicate rule names will lead to an ambiguous rule trace.") + + __rules_name_set.add(pr_rule.rule.get_rule_name()) __rules.append(pr_rule.rule) -def add_rules_from_file(file_path: str, infer_edges: bool = False) -> None: - """ Add a set of rules from a text file +def add_rules_from_file(file_path: str, infer_edges: bool = False, raise_errors: bool = False) -> None: + """Add a set of rules from a text file. + + Each non-empty, non-comment line is treated as a rule in text format. + Lines starting with ``#`` are treated as comments and skipped. + The ``infer_edges`` parameter is applied uniformly to all rules loaded from the file. :param file_path: Path to the text file containing rules :type file_path: str :param infer_edges: Whether to infer edges on these rules if an edge doesn't exist between head variables and the body of the rule is satisfied :type infer_edges: bool + :param raise_errors: If True, raise on invalid rules. If False, warn and skip them. + :type raise_errors: bool :return: None + :raises FileNotFoundError: If the text file doesn't exist + :raises ValueError: If rule parsing fails and raise_errors is True """ with open(file_path, 'r') as file: rules = [line.rstrip() for line in file if line.rstrip() != '' and line.rstrip()[0] != '#'] + loaded_count = 0 + error_count = 0 + rule_offset = 0 if __rules is None else len(__rules) for i, r in enumerate(rules): - add_rule(Rule(r, f'rule_{i+rule_offset}', infer_edges)) + try: + add_rule(Rule(r, f'rule_{i+rule_offset}', infer_edges)) + loaded_count += 1 + except Exception as e: + if raise_errors: + raise ValueError(f"Line {i + 1}: Failed to parse rule '{r}' - {e}") from e + error_count += 1 + warnings.warn(f"Line {i + 1}: Failed to parse rule '{r}' - {e}") + + if settings.verbose: + print(f"Loaded {loaded_count} rules from {file_path}") + if error_count > 0: + print(f"Failed to load {error_count} rules due to errors") + + +def _parse_bool_param(raw_value, param_name, idx, raise_errors, item_label="Item", default=False): + """Private helper to parse a raw value as a boolean. + + :param raw_value: Raw value to parse (can be None, str, bool, int, float) + :param param_name: Name of the parameter (for error messages) + :param idx: Index of the item being parsed (for error messages) + :param raise_errors: Whether to raise errors or just warn + :param item_label: Label for error messages (e.g., "Item", "Row") + :param default: Default value if raw_value is None or empty + :return: Parsed boolean value + :raises ValueError: If validation fails and raise_errors is True + """ + if raw_value is None: + return default + if isinstance(raw_value, bool): + return raw_value + if isinstance(raw_value, str): + val_str = raw_value.strip().lower() + if val_str in ('true', '1', 'yes', 't', 'y'): + return True + elif val_str in ('false', '0', 'no', 'f', 'n', ''): + return default if val_str == '' else False + else: + if raise_errors: + raise ValueError(f"{item_label} {idx}: Invalid {param_name} value '{raw_value}'") + warnings.warn(f"{item_label} {idx}: Invalid {param_name} value '{raw_value}', using default value") + return default + if isinstance(raw_value, (int, float)): + return bool(raw_value) + if raise_errors: + raise ValueError(f"{item_label} {idx}: Invalid {param_name} value type '{type(raw_value).__name__}'") + warnings.warn(f"{item_label} {idx}: Invalid {param_name} value type '{type(raw_value).__name__}', using default value") + return default + + +def _parse_and_validate_rule_params(idx, name_raw, infer_edges_raw, set_static_raw, raise_errors, item_label="Item"): + """Private helper to parse and validate rule parameters. + + :param idx: Index of the item being parsed (for error messages) + :param name_raw: Raw name value (can be None, str, or other types) + :param infer_edges_raw: Raw infer_edges value + :param set_static_raw: Raw set_static value + :param raise_errors: Whether to raise errors or just warn + :param item_label: Label for error messages (e.g., "Item", "Row") + :return: Tuple of (name, infer_edges, set_static) + :raises ValueError: If validation fails and raise_errors is True + """ + # Parse name + name = None + if name_raw is not None: + name = str(name_raw).strip() if str(name_raw).strip() else None + + # Parse infer_edges + infer_edges = _parse_bool_param(infer_edges_raw, 'infer_edges', idx, raise_errors, item_label, default=False) + + # Parse set_static + set_static = _parse_bool_param(set_static_raw, 'set_static', idx, raise_errors, item_label, default=False) + + return name, infer_edges, set_static + + +def add_rule_from_csv(csv_path: str, raise_errors: bool = True) -> None: + """Load multiple rules from a CSV file. + + Each row should have up to 4 comma-separated values in this order: + ``rule_text, name, infer_edges, set_static`` + + - **rule_text** (required): The rule in text format, e.g., ``friend(A, B) <- knows(A, B)`` + or ``"ally(A, B) <- friend(A, B), common_interest(A, B)"`` for rules with commas. + - **name** (optional): A unique name for the rule (can be empty). + - **infer_edges** (optional): Whether to infer new edges after edge rule fires (default: False). + Accepts: True/False, 1/0, yes/no (case-insensitive). + - **set_static** (optional): Whether to set the atom in the head as static if the rule fires (default: False). + Accepts: True/False, 1/0, yes/no (case-insensitive). + + A header row is optional. If included, it must be exactly:: + + rule_text,name,infer_edges,set_static + + Any other header format will be treated as a data row and will likely raise a parsing error. + + Example CSV:: + + rule_text,name,infer_edges,set_static + friend(A, B) <- knows(A, B),friendship-rule,False,False + "ally(A, B) <- friend(A, B), common_interest(A, B)",ally-rule,False,False + connected(A, B) <- link(A, B),connected-rule,True,False + + :param csv_path: Path to the CSV file containing rules + :type csv_path: str + :param raise_errors: If True, raise on invalid rows. If False, warn and skip them. + :type raise_errors: bool + :return: None + :raises FileNotFoundError: If the CSV file doesn't exist + :raises ValueError: If rule parsing fails or CSV format is invalid + """ + try: + df = pd.read_csv(csv_path, header=None, dtype=str, keep_default_na=False) + except FileNotFoundError: + raise FileNotFoundError(f"CSV file not found: {csv_path}") + except pd.errors.EmptyDataError: + warnings.warn(f"CSV file {csv_path} is empty, no rules loaded") + return + except Exception as e: + raise ValueError(f"Error reading CSV file {csv_path}: {e}") + + if df.empty: + warnings.warn(f"CSV file {csv_path} is empty, no rules loaded") + return + + # Skip first row if it exactly matches the expected header + expected_header = ['rule_text', 'name', 'infer_edges', 'set_static'] + first_row = [str(v).strip() for v in df.iloc[0]] if len(df) > 0 else [] + has_header = first_row == expected_header + start_idx = 1 if has_header else 0 + + # Track loaded rules for reporting + loaded_count = 0 + error_count = 0 + loaded_name_set = set() + + # Process each row + for idx, row in df.iloc[start_idx:].iterrows(): + try: + # Extract rule_text (required, column 0) + if len(row) < 1 or not str(row[0]).strip(): + if raise_errors: + raise ValueError(f"Row {idx + 1}: Missing required 'rule_text'") + warnings.warn(f"Row {idx + 1}: Missing required 'rule_text', skipping row") + error_count += 1 + continue + + rule_text = str(row[0]).strip() + + # Parse and validate parameters using shared helper + name, infer_edges, set_static = _parse_and_validate_rule_params( + idx + 1, + row[1] if len(row) > 1 else None, + row[2] if len(row) > 2 else None, + row[3] if len(row) > 3 else None, + raise_errors, + "Row" + ) + + # Check for duplicate names + if name and name in loaded_name_set: + if raise_errors: + raise ValueError(f"Row {idx + 1}: Loaded name '{name}' is a duplicate - all rule names must be unique.") + warnings.warn(f"Row {idx + 1}: Loaded name '{name}' is a duplicate - all rule names must be unique.") + error_count += 1 + continue + if name: + loaded_name_set.add(name) + + # Create and add the rule + r = Rule(rule_text=rule_text, name=name, infer_edges=infer_edges, set_static=set_static) + add_rule(r) + loaded_count += 1 + + except ValueError as e: + if raise_errors: + raise ValueError(f"Row {idx + 1}: Failed to parse rule - {e}") from e + error_count += 1 + warnings.warn(f"Row {idx + 1}: Failed to parse rule - {e}") + except Exception as e: + if raise_errors: + raise Exception(f"Row {idx + 1}: Unexpected error - {e}") from e + error_count += 1 + warnings.warn(f"Row {idx + 1}: Unexpected error - {e}") + + if settings.verbose: + print(f"Loaded {loaded_count} rules from {csv_path}") + if error_count > 0: + print(f"Failed to load {error_count} rules due to errors") + + +def add_rule_from_json(json_path: str, raise_errors: bool = True) -> None: + """Load multiple rules from a JSON file. + + The JSON should be an array of objects, where each object represents a Rule with these fields: + + - **rule_text** (required): The rule in text format, e.g., ``"friend(A, B) <- knows(A, B)"`` + - **name** (optional): The name of the rule. This will appear in the rule trace. + - **infer_edges** (optional): Whether to infer new edges after edge rule fires (default: false). + - **set_static** (optional): Whether to set the atom in the head as static if the rule fires (default: false). + - **custom_thresholds** (optional): A list of threshold objects (one per clause), or a dict + mapping clause index to threshold object (unspecified clauses get defaults). Each threshold + object has ``quantifier``, ``quantifier_type``, and ``thresh`` fields. + - **weights** (optional): A list of weights for the rule clauses. This is passed to an annotation function. + + Example JSON format:: + + [ + { + "rule_text": "friend(A, B) <- knows(A, B)", + "name": "friendship-rule", + "infer_edges": false, + "set_static": false + }, + { + "rule_text": "ally(A, B) <- friend(A, B), common_interest(A, B)", + "name": "ally-rule-list", + "custom_thresholds": [ + {"quantifier": "greater_equal", "quantifier_type": ["number", "total"], "thresh": 1}, + {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 100} + ], + "weights": [1.0, 2.0] + }, + { + "rule_text": "ally(A, B) <- friend(A, B), common_interest(A, B)", + "name": "ally-rule-dict", + "custom_thresholds": { + "0": {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 50} + } + } + ] + + :param json_path: Path to the JSON file containing rules + :type json_path: str + :param raise_errors: If True, raise on invalid items. If False, warn and skip them. + :type raise_errors: bool + :return: None + :raises FileNotFoundError: If the JSON file doesn't exist + :raises ValueError: If rule parsing fails or JSON format is invalid + """ + try: + with open(json_path, 'r') as f: + data = json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"JSON file not found: {json_path}") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format in file {json_path}: {e}") + except Exception as e: + raise ValueError(f"Error reading JSON file {json_path}: {e}") + + if not isinstance(data, list): + raise ValueError(f"JSON file must contain an array of rule objects, got {type(data).__name__}") + + if len(data) == 0: + warnings.warn(f"JSON file {json_path} contains an empty array, no rules loaded") + return + + # Track loaded rules for reporting + loaded_count = 0 + error_count = 0 + loaded_name_set = set() + + # Process each rule object + for idx, rule_obj in enumerate(data): + try: + if not isinstance(rule_obj, dict): + if raise_errors: + raise ValueError(f"Item {idx}: Expected object, got {type(rule_obj).__name__}") + warnings.warn(f"Item {idx}: Expected object, got {type(rule_obj).__name__}, skipping item") + error_count += 1 + continue + + # Extract rule_text (required) + rule_text = rule_obj.get('rule_text') + if not rule_text or not str(rule_text).strip(): + if raise_errors: + raise ValueError(f"Item {idx}: Missing required 'rule_text'") + warnings.warn(f"Item {idx}: Missing required 'rule_text', skipping item") + error_count += 1 + continue + + rule_text = str(rule_text).strip() + + # Parse and validate parameters using shared helper + name, infer_edges, set_static = _parse_and_validate_rule_params( + idx, + rule_obj.get('name'), + rule_obj.get('infer_edges', False), + rule_obj.get('set_static', False), + raise_errors, + "Item" + ) + + # Extract advanced params (JSON-only) + custom_thresholds_raw = rule_obj.get('custom_thresholds') + custom_thresholds = None + found_threshold_error = False + if custom_thresholds_raw is not None: + if isinstance(custom_thresholds_raw, list): + custom_thresholds = [] + for t_idx, t_obj in enumerate(custom_thresholds_raw): + if isinstance(t_obj, dict): + try: + custom_thresholds.append(Threshold( + t_obj['quantifier'], + tuple(t_obj['quantifier_type']), + t_obj['thresh'] + )) + except (KeyError, ValueError, TypeError) as te: + if raise_errors: + raise ValueError(f"Item {idx}, threshold {t_idx}: Invalid threshold - {te}") + warnings.warn(f"Item {idx}, threshold {t_idx}: Invalid threshold - {te}, skipping rule") + found_threshold_error = True + break + else: + if raise_errors: + raise ValueError(f"Item {idx}, threshold {t_idx}: Expected object, got {type(t_obj).__name__}") + warnings.warn(f"Item {idx}, threshold {t_idx}: Expected object, got {type(t_obj).__name__}, skipping rule") + found_threshold_error = True + break + elif isinstance(custom_thresholds_raw, dict): + custom_thresholds = {} + for key_str, t_obj in custom_thresholds_raw.items(): + try: + clause_idx = int(key_str) + except (ValueError, TypeError): + if raise_errors: + raise ValueError(f"Item {idx}: custom_thresholds dict key '{key_str}' must be an integer clause index") + warnings.warn(f"Item {idx}: custom_thresholds dict key '{key_str}' must be an integer clause index, skipping rule") + found_threshold_error = True + break + if isinstance(t_obj, dict): + try: + custom_thresholds[clause_idx] = Threshold( + t_obj['quantifier'], + tuple(t_obj['quantifier_type']), + t_obj['thresh'] + ) + except (KeyError, ValueError, TypeError) as te: + if raise_errors: + raise ValueError(f"Item {idx}, threshold key '{key_str}': Invalid threshold - {te}") + warnings.warn(f"Item {idx}, threshold key '{key_str}': Invalid threshold - {te}, skipping rule") + found_threshold_error = True + break + else: + if raise_errors: + raise ValueError(f"Item {idx}, threshold key '{key_str}': Expected object, got {type(t_obj).__name__}") + warnings.warn(f"Item {idx}, threshold key '{key_str}': Expected object, got {type(t_obj).__name__}, skipping rule") + found_threshold_error = True + break + else: + if raise_errors: + raise ValueError(f"Item {idx}: 'custom_thresholds' must be a list or dict of threshold objects") + warnings.warn(f"Item {idx}: 'custom_thresholds' must be a list or dict of threshold objects, skipping rule") + found_threshold_error = True + if found_threshold_error: + error_count += 1 + continue + + weights_raw = rule_obj.get('weights') + weights = None + if weights_raw is not None: + if not isinstance(weights_raw, list): + if raise_errors: + raise ValueError(f"Item {idx}: 'weights' must be a list of numeric values") + warnings.warn(f"Item {idx}: 'weights' must be a list of numeric values, skipping rule") + error_count += 1 + continue + else: + weights = weights_raw + + # Check for duplicate names + if name and name in loaded_name_set: + if raise_errors: + raise ValueError(f"Item {idx}: Loaded name '{name}' is a duplicate - all rule names must be unique.") + warnings.warn(f"Item {idx}: Loaded name '{name}' is a duplicate - all rule names must be unique.") + error_count += 1 + continue + if name: + loaded_name_set.add(name) + + # Create and add the rule + r = Rule(rule_text=rule_text, name=name, infer_edges=infer_edges, set_static=set_static, custom_thresholds=custom_thresholds, weights=weights) + add_rule(r) + loaded_count += 1 + + except ValueError as e: + if raise_errors: + raise ValueError(f"Item {idx}: Failed to parse rule - {e}") from e + error_count += 1 + warnings.warn(f"Item {idx}: Failed to parse rule - {e}") + except Exception as e: + if raise_errors: + raise Exception(f"Item {idx}: Unexpected error - {e}") from e + error_count += 1 + warnings.warn(f"Item {idx}: Unexpected error - {e}") + + if settings.verbose: + print(f"Loaded {loaded_count} rules from {json_path}") + if error_count > 0: + print(f"Failed to load {error_count} rules due to errors") def _parse_and_validate_fact_params(idx, name_raw, start_time_raw, end_time_raw, static_raw, raise_errors, item_label="Item"): @@ -663,28 +1080,7 @@ def _parse_and_validate_fact_params(idx, name_raw, start_time_raw, end_time_raw, end_time = start_time # Parse static as boolean - static = False - if static_raw is not None: - if isinstance(static_raw, bool): - static = static_raw - elif isinstance(static_raw, str): - static_str = static_raw.strip().lower() - if static_str in ('true', 'yes', 't', 'y'): - static = True - elif static_str in ('false', 'no', 'f', 'n', ''): - static = False - else: - if raise_errors: - raise ValueError(f"{item_label} {idx}: Invalid static value '{static_raw}'") - warnings.warn(f"{item_label} {idx}: Invalid static value '{static_raw}', using default value") - static = False - elif isinstance(static_raw, (int, float)): - static = bool(static_raw) - else: - if raise_errors: - raise ValueError(f"{item_label} {idx}: Invalid static value type '{type(static_raw).__name__}'") - warnings.warn(f"{item_label} {idx}: Invalid static value type '{type(static_raw).__name__}', using default value") - static = False + static = _parse_bool_param(static_raw, 'static', idx, raise_errors, item_label, default=False) return name, start_time, end_time, static @@ -706,21 +1102,21 @@ def add_fact(pyreason_fact: Fact) -> None: if pyreason_fact.name is None: pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}' - if pyreason_fact.name in __node_facts_name_set: + if pyreason_fact.name in __facts_name_set: warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.") f = fact_node.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static) - __node_facts_name_set.add(pyreason_fact.name) + __facts_name_set.add(pyreason_fact.name) __node_facts.append(f) else: if pyreason_fact.name is None: pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}' - if pyreason_fact.name in __node_facts_name_set: + if pyreason_fact.name in __facts_name_set: warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.") f = fact_edge.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static) - __node_facts_name_set.add(pyreason_fact.name) + __facts_name_set.add(pyreason_fact.name) __edge_facts.append(f) diff --git a/pyreason/scripts/utils/rule_parser.py b/pyreason/scripts/utils/rule_parser.py index f298b5f5..7f7d1874 100644 --- a/pyreason/scripts/utils/rule_parser.py +++ b/pyreason/scripts/utils/rule_parser.py @@ -102,7 +102,7 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, d # Validate head variable names for var in head_variables: - _validate_variable_name(var, "Head") + _validate_component_name(var, "Head") # Assign type of rule rule_type = 'node' if len(head_variables) == 1 else 'edge' @@ -132,7 +132,7 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, d variables += clause_str[start_idx+1:end_idx].split(',') # Validate body variable names for var in variables: - _validate_variable_name(var, "Body") + _validate_component_name(var, "Body") body_variables.append(variables) # Change infer edge parameter if it's a node rule @@ -489,12 +489,12 @@ def _validate_predicate_name(pred, context): raise ValueError(f"{context} predicate name '{pred}' contains invalid characters. Must match [a-zA-Z_][a-zA-Z0-9_]*") -def _validate_variable_name(var, context): +def _validate_component_name(var, context): """Validate that a variable name matches ^[a-zA-Z_][a-zA-Z0-9_]*$.""" if not _IDENTIFIER_RE.match(var): if var and var[0].isdigit(): - raise ValueError(f"{context} variable name '{var}' cannot start with a digit") - raise ValueError(f"{context} variable name '{var}' contains invalid characters. Must match [a-zA-Z_][a-zA-Z0-9_]*") + raise ValueError(f"{context} component name '{var}' cannot start with a digit") + raise ValueError(f"{context} component name '{var}' contains invalid characters. Must match [a-zA-Z_][a-zA-Z0-9_]*") def _str_bound_to_bound(str_bound): diff --git a/tests/api_tests/test_files/example_rules.csv b/tests/api_tests/test_files/example_rules.csv new file mode 100644 index 00000000..dcb99f95 --- /dev/null +++ b/tests/api_tests/test_files/example_rules.csv @@ -0,0 +1,14 @@ +rule_text,name,infer_edges,set_static +"friend(A, B) <- knows(A, B)",friendship-rule,False,False +"enemy(A, B) <- ~friend(A, B)",enemy-rule,false,false +"ally(A, B) <- friend(A, B), common_interest(A, B)",ally-rule,0,0 +"likes(A, B) <- knows(A, B), friend(A, B)",likes-rule,True,False +"connected(A, B) <- link(A, B)",connected-rule,true,false +"trusted(A, B) <- friend(A, B), ally(A, B)",trusted-rule,1,0 +"close(A, B) <- friend(A, B), likes(A, B)",close-rule,yes,no +"popular(x) <- friend(x, y)",popular-rule,False,True +,empty-rule-text,False,False +InvalidRuleSyntax,bad-syntax,False,False +"friend(A, B) <- knows(A, B)",bad-infer,invalid,False +"friend(A, B) <- knows(A, B)",bad-static,False,invalid +"friend(A, B) <- knows(A, B)",,, diff --git a/tests/api_tests/test_files/example_rules.json b/tests/api_tests/test_files/example_rules.json new file mode 100644 index 00000000..c2ac0274 --- /dev/null +++ b/tests/api_tests/test_files/example_rules.json @@ -0,0 +1,100 @@ +[ + { + "rule_text": "friend(A, B) <- knows(A, B)", + "name": "friendship-rule", + "infer_edges": false, + "set_static": false + }, + { + "rule_text": "enemy(A, B) <- ~friend(A, B)", + "name": "enemy-rule", + "infer_edges": false, + "set_static": false + }, + { + "rule_text": "ally(A, B) <- friend(A, B), common_interest(A, B)", + "name": "ally-rule", + "infer_edges": false, + "set_static": false + }, + { + "rule_text": "likes(A, B) <- knows(A, B), friend(A, B)", + "name": "likes-rule", + "infer_edges": true, + "set_static": false + }, + { + "rule_text": "connected(A, B) <- link(A, B)", + "name": "connected-rule", + "infer_edges": true, + "set_static": false + }, + { + "rule_text": "trusted(A, B) <- friend(A, B), ally(A, B)", + "name": "trusted-rule", + "infer_edges": false, + "set_static": true + }, + { + "rule_text": "popular(x) <- friend(x, y)", + "name": "popular-rule", + "infer_edges": false, + "set_static": true + }, + { + "rule_text": "close(A, B) <- friend(A, B), likes(A, B)", + "name": "close-rule", + "custom_thresholds": [ + {"quantifier": "greater_equal", "quantifier_type": ["number", "total"], "thresh": 1}, + {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 100} + ], + "weights": [1.0, 2.0] + }, + { + "rule_text": "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)", + "name": "threshold-rule", + "custom_thresholds": [ + {"quantifier": "greater_equal", "quantifier_type": ["number", "total"], "thresh": 1}, + {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 100} + ] + }, + { + "rule_text": "", + "name": "empty-rule-text" + }, + { + "rule_text": "InvalidRuleSyntax", + "name": "bad-syntax" + }, + { + "rule_text": "friend(A, B) <- knows(A, B)", + "name": "bad-infer", + "infer_edges": "invalid" + }, + { + "rule_text": "friend(A, B) <- knows(A, B)", + "name": "bad-static", + "set_static": "invalid" + }, + { + "rule_text": "friend(A, B) <- knows(A, B)", + "name": "bad-thresholds", + "custom_thresholds": "not-a-list" + }, + { + "rule_text": "friend(A, B) <- knows(A, B)", + "name": "bad-threshold-item", + "custom_thresholds": [ + {"quantifier": "invalid_quantifier", "quantifier_type": ["number", "total"], "thresh": 1} + ] + }, + { + "rule_text": "friend(A, B) <- knows(A, B)", + "name": "bad-weights", + "weights": "not-a-list" + }, + { + "rule_text": "friend(A, B) <- knows(A, B)" + }, + "not-an-object" +] diff --git a/tests/api_tests/test_files/example_rules_no_headers.csv b/tests/api_tests/test_files/example_rules_no_headers.csv new file mode 100644 index 00000000..d986180d --- /dev/null +++ b/tests/api_tests/test_files/example_rules_no_headers.csv @@ -0,0 +1,6 @@ +"friend(A, B) <- knows(A, B)",friendship-rule,False,False +"enemy(A, B) <- ~friend(A, B)",enemy-rule,False,False +"ally(A, B) <- friend(A, B), common_interest(A, B)",ally-rule,False,False +,empty-rule,False,False +InvalidRuleSyntax,bad-syntax,False,False +"friend(A, B) <- knows(A, B)",good-rule,invalid,False diff --git a/tests/api_tests/test_pyreason_file_loading.py b/tests/api_tests/test_pyreason_file_loading.py index 5d7b54bb..392cf97d 100644 --- a/tests/api_tests/test_pyreason_file_loading.py +++ b/tests/api_tests/test_pyreason_file_loading.py @@ -677,7 +677,129 @@ def test_partial_failure_recovery(self): graph2.add_edge('C', 'D') pr.load_graph(graph2) +class TestRuleTrace: + """Test save_rule_trace() and get_rule_trace() functions.""" + + def setup_method(self): + """Clean state before each test.""" + pr.reset() + pr.reset_settings() + + def test_save_rule_trace_with_store_interpretation_changes_disabled(self): + """Test save_rule_trace() with store_interpretation_changes disabled.""" + pr.settings.store_interpretation_changes = False + + # Create a simple interpretation (empty for this test) + interpretation = {} + + with pytest.raises(AssertionError, match='store interpretation changes setting is off'): + pr.save_rule_trace(interpretation) + + def test_get_rule_trace_with_store_interpretation_changes_disabled(self): + """Test get_rule_trace() with store_interpretation_changes disabled.""" + pr.settings.store_interpretation_changes = False + + # Create a simple interpretation (empty for this test) + interpretation = {} + + with pytest.raises(AssertionError, match='store interpretation changes setting is off'): + pr.get_rule_trace(interpretation) + + def test_save_rule_trace_with_store_interpretation_changes_enabled(self): + """Test save_rule_trace() with store_interpretation_changes enabled.""" + pr.settings.store_interpretation_changes = True + + # Create a simple graph and run reasoning to get an interpretation + graph = nx.DiGraph() + graph.add_edge('A', 'B') + pr.load_graph(graph) + + # Add a simple fact and rule + pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) + pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) + + # Run reasoning to get interpretation + interpretation = pr.reason(1) + + # Test save_rule_trace with default folder + with tempfile.TemporaryDirectory() as temp_dir: + pr.save_rule_trace(interpretation, temp_dir) + # Check that files were created (exact files depend on implementation) + files_created = os.listdir(temp_dir) + assert len(files_created) > 0, "Expected files to be created in the trace folder" + + def test_save_rule_trace_with_custom_folder(self): + """Test save_rule_trace() with custom folder path.""" + pr.settings.store_interpretation_changes = True + + # Create a simple graph and run reasoning + graph = nx.DiGraph() + graph.add_edge('A', 'B') + pr.load_graph(graph) + + pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) + pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) + + interpretation = pr.reason(1) + # Test with custom folder + with tempfile.TemporaryDirectory() as temp_dir: + custom_folder = os.path.join(temp_dir, 'custom_trace') + os.makedirs(custom_folder, exist_ok=True) + + pr.save_rule_trace(interpretation, custom_folder) + files_created = os.listdir(custom_folder) + assert len(files_created) > 0, "Expected files to be created in the custom trace folder" + + def test_get_rule_trace_with_store_interpretation_changes_enabled(self): + """Test get_rule_trace() with store_interpretation_changes enabled.""" + pr.settings.store_interpretation_changes = True + + # Create a simple graph and run reasoning + graph = nx.DiGraph() + graph.add_edge('A', 'B') + pr.load_graph(graph) + + pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) + pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) + + interpretation = pr.reason(1) + + # Test get_rule_trace + node_trace, edge_trace = pr.get_rule_trace(interpretation) + + # Verify return types are DataFrames + assert isinstance(node_trace, pd.DataFrame), "Expected node_trace to be a pandas DataFrame" + assert isinstance(edge_trace, pd.DataFrame), "Expected edge_trace to be a pandas DataFrame" + + def test_get_rule_trace_returns_dataframes(self): + """Test that get_rule_trace() returns proper DataFrame structures.""" + pr.settings.store_interpretation_changes = True + + # Create a more complex scenario + graph = nx.DiGraph() + graph.add_edges_from([('A', 'B'), ('B', 'C'), ('C', 'A')]) + pr.load_graph(graph) + + # Add multiple facts and rules + pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) + pr.add_fact(pr.Fact('person(B)', 'B', 1, 1)) + pr.add_rule(Rule('friend(A, B) <- person(A)', 'rule1', False)) + pr.add_rule(Rule('likes(A, B) <- friend(A, B)', 'rule2', False)) + + interpretation = pr.reason(2) + + node_trace, edge_trace = pr.get_rule_trace(interpretation) + + # Basic structure verification + assert isinstance(node_trace, pd.DataFrame) + assert isinstance(edge_trace, pd.DataFrame) + + # DataFrames should have some basic expected structure + # (exact columns depend on implementation, but they should be valid DataFrames) + assert hasattr(node_trace, 'columns') + assert hasattr(edge_trace, 'columns') + class TestAddFactFromJSON: """Test add_fact_from_json() function for loading facts from JSON.""" @@ -986,166 +1108,6 @@ def test_add_fact_from_csv_empty_optional_fields(self): finally: os.unlink(tmp_path) - def test_add_fact_from_csv_nonexistent_file(self): - """Test add_fact_from_csv() with nonexistent file.""" - with pytest.raises(FileNotFoundError): - pr.add_fact_from_csv('nonexistent_facts.csv') - - def test_add_fact_from_csv_empty_file(self): - """Test loading facts from empty CSV file.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: - tmp.write('') - tmp_path = tmp.name - - try: - # Empty file should trigger a warning - with pytest.warns(UserWarning, match="empty"): - pr.add_fact_from_csv(tmp_path) - finally: - os.unlink(tmp_path) - - def test_add_fact_from_csv_multiple_calls(self): - """Test multiple calls to add_fact_from_csv accumulate facts.""" - csv1_content = """Viewed(User1),fact1,0,3,False""" - csv2_content = """Viewed(User2),fact2,0,3,False""" - - with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp1: - tmp1.write(csv1_content) - tmp1_path = tmp1.name - - with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp2: - tmp2.write(csv2_content) - tmp2_path = tmp2.name - - try: - pr.add_fact_from_csv(tmp1_path) - pr.add_fact_from_csv(tmp2_path) - finally: - os.unlink(tmp1_path) - os.unlink(tmp2_path) - -class TestRuleTrace: - """Test save_rule_trace() and get_rule_trace() functions.""" - - def setup_method(self): - """Clean state before each test.""" - pr.reset() - pr.reset_settings() - - def test_save_rule_trace_with_store_interpretation_changes_disabled(self): - """Test save_rule_trace() with store_interpretation_changes disabled.""" - pr.settings.store_interpretation_changes = False - - # Create a simple interpretation (empty for this test) - interpretation = {} - - with pytest.raises(AssertionError, match='store interpretation changes setting is off'): - pr.save_rule_trace(interpretation) - - def test_get_rule_trace_with_store_interpretation_changes_disabled(self): - """Test get_rule_trace() with store_interpretation_changes disabled.""" - pr.settings.store_interpretation_changes = False - - # Create a simple interpretation (empty for this test) - interpretation = {} - - with pytest.raises(AssertionError, match='store interpretation changes setting is off'): - pr.get_rule_trace(interpretation) - - def test_save_rule_trace_with_store_interpretation_changes_enabled(self): - """Test save_rule_trace() with store_interpretation_changes enabled.""" - pr.settings.store_interpretation_changes = True - - # Create a simple graph and run reasoning to get an interpretation - graph = nx.DiGraph() - graph.add_edge('A', 'B') - pr.load_graph(graph) - - # Add a simple fact and rule - pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) - pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) - - # Run reasoning to get interpretation - interpretation = pr.reason(1) - - # Test save_rule_trace with default folder - with tempfile.TemporaryDirectory() as temp_dir: - pr.save_rule_trace(interpretation, temp_dir) - # Check that files were created (exact files depend on implementation) - files_created = os.listdir(temp_dir) - assert len(files_created) > 0, "Expected files to be created in the trace folder" - - def test_save_rule_trace_with_custom_folder(self): - """Test save_rule_trace() with custom folder path.""" - pr.settings.store_interpretation_changes = True - - # Create a simple graph and run reasoning - graph = nx.DiGraph() - graph.add_edge('A', 'B') - pr.load_graph(graph) - - pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) - pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) - - interpretation = pr.reason(1) - - # Test with custom folder - with tempfile.TemporaryDirectory() as temp_dir: - custom_folder = os.path.join(temp_dir, 'custom_trace') - os.makedirs(custom_folder, exist_ok=True) - - pr.save_rule_trace(interpretation, custom_folder) - files_created = os.listdir(custom_folder) - assert len(files_created) > 0, "Expected files to be created in the custom trace folder" - - def test_get_rule_trace_with_store_interpretation_changes_enabled(self): - """Test get_rule_trace() with store_interpretation_changes enabled.""" - pr.settings.store_interpretation_changes = True - - # Create a simple graph and run reasoning - graph = nx.DiGraph() - graph.add_edge('A', 'B') - pr.load_graph(graph) - - pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) - pr.add_rule(Rule('friend(A, B) <- person(A)', 'test_rule', False)) - - interpretation = pr.reason(1) - - # Test get_rule_trace - node_trace, edge_trace = pr.get_rule_trace(interpretation) - - # Verify return types are DataFrames - assert isinstance(node_trace, pd.DataFrame), "Expected node_trace to be a pandas DataFrame" - assert isinstance(edge_trace, pd.DataFrame), "Expected edge_trace to be a pandas DataFrame" - - def test_get_rule_trace_returns_dataframes(self): - """Test that get_rule_trace() returns proper DataFrame structures.""" - pr.settings.store_interpretation_changes = True - - # Create a more complex scenario - graph = nx.DiGraph() - graph.add_edges_from([('A', 'B'), ('B', 'C'), ('C', 'A')]) - pr.load_graph(graph) - - # Add multiple facts and rules - pr.add_fact(pr.Fact('person(A)', 'A', 1, 1)) - pr.add_fact(pr.Fact('person(B)', 'B', 1, 1)) - pr.add_rule(Rule('friend(A, B) <- person(A)', 'rule1', False)) - pr.add_rule(Rule('likes(A, B) <- friend(A, B)', 'rule2', False)) - - interpretation = pr.reason(2) - - node_trace, edge_trace = pr.get_rule_trace(interpretation) - - # Basic structure verification - assert isinstance(node_trace, pd.DataFrame) - assert isinstance(edge_trace, pd.DataFrame) - - # DataFrames should have some basic expected structure - # (exact columns depend on implementation, but they should be valid DataFrames) - assert hasattr(node_trace, 'columns') - assert hasattr(edge_trace, 'columns') class TestAddRulesFromFile: """Test add_rules_from_file() function.""" @@ -1311,9 +1273,446 @@ def test_add_rules_from_file_after_existing_rules(self): finally: os.unlink(tmp_path) + def test_add_rules_from_file_raise_errors_true(self): + """Test that raise_errors=True raises on invalid rules.""" + rules_content = """friend(A, B) <- knows(A, B) +InvalidRuleSyntax""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write(rules_content) + tmp_path = tmp.name + + try: + with pytest.raises(ValueError, match="Failed to parse rule"): + pr.add_rules_from_file(tmp_path, raise_errors=True) + finally: + os.unlink(tmp_path) + + def test_add_rules_from_file_raise_errors_false_warns(self): + """Test that raise_errors=False warns on invalid rules and continues.""" + rules_content = """friend(A, B) <- knows(A, B) +InvalidRuleSyntax +enemy(A, B) <- ~friend(A, B)""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write(rules_content) + tmp_path = tmp.name + + try: + with pytest.warns(UserWarning, match="Failed to parse rule"): + pr.add_rules_from_file(tmp_path, raise_errors=False) + # Should have loaded 2 valid rules despite the invalid one + rules = pr.get_rules() + assert len(rules) == 2 + finally: + os.unlink(tmp_path) + + +class TestAddRuleFromCSV: + """Test add_rule_from_csv() function for loading rules from CSV.""" + + def setup_method(self): + """Clean state before each test.""" + pr.reset() + pr.reset_settings() + + def test_add_rule_from_csv_comprehensive(self): + """Test loading rules from CSV with various valid and invalid scenarios. + + This test uses example_rules.csv which contains: + - Valid rules with various boolean formats for infer_edges/set_static + - Multi-clause rules with quoted rule_text + - Empty rule_text (should warn) + - Invalid syntax (should warn) + - Invalid infer_edges value (should warn) + - Invalid set_static value (should warn) + - Empty optional fields + """ + csv_path = os.path.join(os.path.dirname(__file__), 'test_files', 'example_rules.csv') + + with pytest.warns(UserWarning) as warning_list: + pr.add_rule_from_csv(csv_path, raise_errors=False) + + # Verify that we got warnings from the invalid rows: + # - Row 10: empty rule_text -> "Missing required 'rule_text'" + # - Row 11: invalid syntax -> "Failed to parse rule" + # - Row 12: invalid infer_edges -> "Invalid infer_edges value" + # - Row 13: invalid set_static -> "Invalid set_static value" + assert len(warning_list) >= 4, f"Expected at least 4 warnings, got {len(warning_list)}: {[str(w.message) for w in warning_list]}" + + warning_messages = [str(w.message) for w in warning_list] + + assert any("Missing required 'rule_text'" in msg for msg in warning_messages), \ + "Expected warning about missing rule_text" + + assert any("Failed to parse rule" in msg for msg in warning_messages), \ + "Expected warning about invalid syntax" + + assert any("Invalid infer_edges value" in msg for msg in warning_messages), \ + "Expected warning about invalid infer_edges" + + assert any("Invalid set_static value" in msg for msg in warning_messages), \ + "Expected warning about invalid set_static" + + def test_add_rule_from_csv_duplicate_names_raises_error(self): + """Test that duplicate rule names in CSV raise error when raise_errors=True.""" + csv_content = """"friend(A, B) <- knows(A, B)",duplicate-name,False,False +"enemy(A, B) <- ~friend(A, B)",duplicate-name,False,False""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write(csv_content) + tmp_path = tmp.name + + try: + with pytest.raises(ValueError, match="duplicate"): + pr.add_rule_from_csv(tmp_path, raise_errors=True) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_csv_duplicate_names_warns(self): + """Test that duplicate rule names in CSV warn when raise_errors=False.""" + csv_content = """"friend(A, B) <- knows(A, B)",duplicate-name,False,False +"enemy(A, B) <- ~friend(A, B)",duplicate-name,False,False""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write(csv_content) + tmp_path = tmp.name + + try: + with pytest.warns(UserWarning, match="duplicate"): + pr.add_rule_from_csv(tmp_path, raise_errors=False) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_csv_no_header_file(self): + """Test loading rules from CSV file without header using example_rules_no_headers.csv. + + The file contains: + - Row 1-3: Valid rules + - Row 4: empty rule_text - should warn + - Row 5: invalid syntax - should warn + - Row 6: invalid infer_edges - should warn + """ + csv_path = os.path.join(os.path.dirname(__file__), 'test_files', 'example_rules_no_headers.csv') + + with pytest.warns(UserWarning) as warning_list: + pr.add_rule_from_csv(csv_path, raise_errors=False) + + assert len(warning_list) >= 3, f"Expected at least 3 warnings, got {len(warning_list)}: {[str(w.message) for w in warning_list]}" + + warning_messages = [str(w.message) for w in warning_list] + + assert any("Missing required 'rule_text'" in msg for msg in warning_messages), \ + "Expected warning about missing rule_text" + + assert any("Failed to parse rule" in msg for msg in warning_messages), \ + "Expected warning about invalid syntax" + + assert any("Invalid infer_edges value" in msg for msg in warning_messages), \ + "Expected warning about invalid infer_edges" + + def test_add_rule_from_csv_empty_optional_fields(self): + """Test loading rules with empty optional fields.""" + csv_content = """rule_text,name,infer_edges,set_static +"friend(A, B) <- knows(A, B)",,, +"enemy(A, B) <- ~friend(A, B)",enemy-rule,,""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write(csv_content) + tmp_path = tmp.name + + try: + pr.add_rule_from_csv(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_csv_nonexistent_file(self): + """Test add_rule_from_csv() with nonexistent file.""" + with pytest.raises(FileNotFoundError): + pr.add_rule_from_csv('nonexistent_rules.csv') + + def test_add_rule_from_csv_empty_file(self): + """Test loading rules from empty CSV file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write('') + tmp_path = tmp.name + + try: + with pytest.warns(UserWarning, match="is empty"): + pr.add_rule_from_csv(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_csv_multiple_calls(self): + """Test multiple calls to add_rule_from_csv accumulate rules.""" + csv1_content = """"friend(A, B) <- knows(A, B)",rule1,False,False""" + csv2_content = """"enemy(A, B) <- ~friend(A, B)",rule2,False,False""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp1: + tmp1.write(csv1_content) + tmp1_path = tmp1.name + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp2: + tmp2.write(csv2_content) + tmp2_path = tmp2.name + + try: + pr.add_rule_from_csv(tmp1_path) + pr.add_rule_from_csv(tmp2_path) + rules = pr.get_rules() + assert len(rules) == 2 + finally: + os.unlink(tmp1_path) + os.unlink(tmp2_path) + + def test_add_rule_from_csv_boolean_formats(self): + """Test various boolean formats for infer_edges and set_static.""" + csv_content = """rule_text,name,infer_edges,set_static +"friend(A, B) <- knows(A, B)",r1,True,False +"enemy(A, B) <- ~friend(A, B)",r2,true,false +"ally(A, B) <- friend(A, B)",r3,1,0 +"likes(A, B) <- knows(A, B)",r4,yes,no +"connected(A, B) <- link(A, B)",r5,t,f +"popular(x) <- friend(x, y)",r6,y,n""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write(csv_content) + tmp_path = tmp.name - def test_add_inconsistent_predicates(self): - """Test adding inconsistent predicate pairs""" - pr.add_inconsistent_predicate("pred1", "pred2") - pr.add_inconsistent_predicate("pred3", "pred4") - # Should not raise exceptions \ No newline at end of file + try: + pr.add_rule_from_csv(tmp_path) + rules = pr.get_rules() + assert len(rules) == 6 + finally: + os.unlink(tmp_path) + + +class TestAddRuleFromJSON: + """Test add_rule_from_json() function for loading rules from JSON.""" + + def setup_method(self): + """Clean state before each test.""" + pr.reset() + pr.reset_settings() + + def test_add_rule_from_json_comprehensive(self): + """Test loading rules from JSON with various valid and invalid scenarios. + + This test uses example_rules.json which contains: + - Valid rules with various settings + - Rules with custom_thresholds and weights + - Empty rule_text (should warn) + - Invalid syntax (should warn) + - Invalid infer_edges value (should warn) + - Invalid set_static value (should warn) + - Invalid custom_thresholds format (should warn) + - Invalid threshold item (should warn) + - Invalid weights format (should warn) + - Empty optional fields + - Non-object item (should warn) + """ + json_path = os.path.join(os.path.dirname(__file__), 'test_files', 'example_rules.json') + + with pytest.warns(UserWarning) as warning_list: + pr.add_rule_from_json(json_path, raise_errors=False) + + # Verify warnings from invalid items: + # - Item 9: empty rule_text + # - Item 10: invalid syntax + # - Item 11: invalid infer_edges + # - Item 12: invalid set_static + # - Item 13: invalid custom_thresholds (not a list) + # - Item 14: invalid threshold item (bad quantifier) + # - Item 15: invalid weights (not a list) + # - Item 17: not an object + assert len(warning_list) >= 8, f"Expected at least 8 warnings, got {len(warning_list)}: {[str(w.message) for w in warning_list]}" + + warning_messages = [str(w.message) for w in warning_list] + + assert any("Missing required 'rule_text'" in msg for msg in warning_messages), \ + "Expected warning about missing rule_text" + + assert any("Failed to parse rule" in msg for msg in warning_messages), \ + "Expected warning about invalid syntax" + + assert any("Invalid infer_edges value" in msg for msg in warning_messages), \ + "Expected warning about invalid infer_edges" + + assert any("Invalid set_static value" in msg for msg in warning_messages), \ + "Expected warning about invalid set_static" + + assert any("custom_thresholds" in msg for msg in warning_messages), \ + "Expected warning about invalid custom_thresholds" + + assert any("weights" in msg for msg in warning_messages), \ + "Expected warning about invalid weights" + + assert any("Expected object" in msg for msg in warning_messages), \ + "Expected warning about non-object item" + + def test_add_rule_from_json_duplicate_names_raises_error(self): + """Test that duplicate rule names in JSON raise error when raise_errors=True.""" + json_content = """[ + {"rule_text": "friend(A, B) <- knows(A, B)", "name": "duplicate-name"}, + {"rule_text": "enemy(A, B) <- ~friend(A, B)", "name": "duplicate-name"} + ]""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp: + tmp.write(json_content) + tmp_path = tmp.name + + try: + with pytest.raises(ValueError, match="duplicate"): + pr.add_rule_from_json(tmp_path, raise_errors=True) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_json_duplicate_names_warns(self): + """Test that duplicate rule names in JSON warn when raise_errors=False.""" + json_content = """[ + {"rule_text": "friend(A, B) <- knows(A, B)", "name": "duplicate-name"}, + {"rule_text": "enemy(A, B) <- ~friend(A, B)", "name": "duplicate-name"} + ]""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp: + tmp.write(json_content) + tmp_path = tmp.name + + try: + with pytest.warns(UserWarning, match="duplicate"): + pr.add_rule_from_json(tmp_path, raise_errors=False) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_json_nonexistent_file(self): + """Test add_rule_from_json() with nonexistent file.""" + with pytest.raises(FileNotFoundError): + pr.add_rule_from_json('nonexistent_rules.json') + + def test_add_rule_from_json_empty_array(self): + """Test loading rules from JSON file with empty array.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp: + tmp.write('[]') + tmp_path = tmp.name + + try: + with pytest.warns(UserWarning, match="contains an empty array"): + pr.add_rule_from_json(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_json_invalid_json(self): + """Test loading rules from invalid JSON file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp: + tmp.write('{ invalid json }') + tmp_path = tmp.name + + try: + with pytest.raises(ValueError, match="Invalid JSON format"): + pr.add_rule_from_json(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_json_not_array(self): + """Test loading rules from JSON file that's not an array.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp: + tmp.write('{"rule_text": "friend(A, B) <- knows(A, B)"}') + tmp_path = tmp.name + + try: + with pytest.raises(ValueError, match="must contain an array"): + pr.add_rule_from_json(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rule_from_json_multiple_calls(self): + """Test multiple calls to add_rule_from_json accumulate rules.""" + json1 = """[{"rule_text": "friend(A, B) <- knows(A, B)", "name": "rule1"}]""" + json2 = """[{"rule_text": "enemy(A, B) <- ~friend(A, B)", "name": "rule2"}]""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp1: + tmp1.write(json1) + tmp1_path = tmp1.name + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp2: + tmp2.write(json2) + tmp2_path = tmp2.name + + try: + pr.add_rule_from_json(tmp1_path) + pr.add_rule_from_json(tmp2_path) + rules = pr.get_rules() + assert len(rules) == 2 + finally: + os.unlink(tmp1_path) + os.unlink(tmp2_path) + + def test_add_rule_from_json_with_custom_thresholds(self): + """Test loading rules with custom thresholds from JSON.""" + json_content = """[ + { + "rule_text": "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)", + "name": "threshold-rule", + "custom_thresholds": [ + {"quantifier": "greater_equal", "quantifier_type": ["number", "total"], "thresh": 1}, + {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 100} + ] + } + ]""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp: + tmp.write(json_content) + tmp_path = tmp.name + + try: + pr.add_rule_from_json(tmp_path) + rules = pr.get_rules() + assert len(rules) == 1 + finally: + os.unlink(tmp_path) + + def test_add_rule_from_json_with_weights(self): + """Test loading rules with weights from JSON.""" + json_content = """[ + { + "rule_text": "close(A, B) <- friend(A, B), likes(A, B)", + "name": "weighted-rule", + "weights": [1.0, 2.0] + } + ]""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp: + tmp.write(json_content) + tmp_path = tmp.name + + try: + pr.add_rule_from_json(tmp_path) + rules = pr.get_rules() + assert len(rules) == 1 + finally: + os.unlink(tmp_path) + + def test_add_rule_from_json_with_thresholds_and_weights(self): + """Test loading rules with both custom thresholds and weights.""" + json_content = """[ + { + "rule_text": "close(A, B) <- friend(A, B), likes(A, B)", + "name": "full-rule", + "custom_thresholds": [ + {"quantifier": "greater_equal", "quantifier_type": ["number", "total"], "thresh": 1}, + {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 100} + ], + "weights": [1.0, 2.0] + } + ]""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp: + tmp.write(json_content) + tmp_path = tmp.name + + try: + pr.add_rule_from_json(tmp_path) + rules = pr.get_rules() + assert len(rules) == 1 + finally: + os.unlink(tmp_path) \ No newline at end of file