diff --git a/mas_arena/evaluators/aime_evaluator.py b/mas_arena/evaluators/aime_evaluator.py index 4d40c1b..076e8f0 100644 --- a/mas_arena/evaluators/aime_evaluator.py +++ b/mas_arena/evaluators/aime_evaluator.py @@ -4,22 +4,14 @@ Standalone evaluator for AIME-style math problems. """ -import re -import time -from typing import Dict, Any, Tuple + +from typing import Dict, Any from pathlib import Path -from math import isclose -from sympy import N, simplify -from sympy.parsing.latex import parse_latex -from sympy.parsing.sympy_parser import parse_expr from langsmith.evaluation import RunEvaluator -from langsmith.schemas import Run -from mas_arena.evaluators.base_evaluator import BaseEvaluator +from mas_arena.evaluators.math_evaluator import MathEvaluator from mas_arena.evaluators.registry import register_benchmark -from mas_arena.evaluators.utils.math_equal import calculate_score -from mas_arena.evaluators.utils import extract_answer_numeric @register_benchmark( @@ -29,7 +21,7 @@ "solution": "answer", } ) -class AIMEEvaluator(BaseEvaluator): +class AIMEEvaluator(MathEvaluator): """ Evaluator for AIME-style math problems. @@ -58,179 +50,6 @@ def __init__(self, name: str = "aime", config: Dict[str, Any] = None): def from_config(cls, name: str, config: Dict[str, Any] = None): return cls(name, config) - def extract_answer(self, text: str) -> str: - if self.evaluate_type == 0: - return self.simple_extract_answer(text) - else: - return self.math_extract_answer(text) - - def math_extract_answer(self, text: str) -> str: - """ - Extract the answer from model output text (last number or string). - """ - return extract_answer_numeric(text) - - def simple_extract_answer(self, text: str) -> str: - """ - Extract the answer from model output text, looking for boxed answers or final statements. - - Args: - text: The model's output text - - Returns: - The extracted answer - """ - # Look for LaTeX boxed answers first - pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}" - boxed_matches = re.findall(pattern, text, re.DOTALL) - if boxed_matches: - return boxed_matches[-1].strip() - - # For AIME, also look for 3-digit numbers (000-999) - number_pattern = r"\b\d{3}\b" - number_matches = re.findall(number_pattern, text) - if number_matches: - return number_matches[-1] - - # If no boxed answer, try to extract the final conclusion - sentence_end_pattern = r"(? Tuple[int, str]: - """ - Calculate a score by comparing the expected and predicted answers. - - Args: - expected_output: The expected answer (solution) - prediction: The model's prediction - - Returns: - Tuple of (score, extracted_answer) where score is 1 for correct, 0 for incorrect - """ - extracted_expected = self.extract_answer(expected_output) - extracted_prediction = self.extract_answer(prediction) - - if self.math_equal(extracted_prediction, extracted_expected): - return 1, extracted_prediction - else: - return 0, extracted_prediction - - def calculate_score(self, expected_output: str, prediction: str) -> Tuple[int, str]: - return calculate_score(expected_output, prediction) - - def math_equal(self, prediction: Any, reference: Any) -> bool: - """ - Check if two mathematical expressions are equivalent. - - Args: - prediction: The predicted answer - reference: The reference answer - - Returns: - True if the expressions are equivalent, False otherwise - """ - # Direct string comparison - if str(prediction) == str(reference): - return True - - # For AIME, treat as integers and compare - try: - pred_int = int(prediction) - ref_int = int(reference) - return pred_int == ref_int - except ValueError: - pass - - # Numeric comparison - try: - if self.is_digit(prediction) and self.is_digit(reference): - prediction_val = self.parse_digits(prediction) - reference_val = self.parse_digits(reference) - return isclose(prediction_val, reference_val, abs_tol=1e-3) - except ValueError: - pass - - # Symbolic comparison - try: - return self.symbolic_equal(prediction, reference) - except Exception: - pass - - return False - - def is_digit(self, num): - """Check if a string can be parsed as a number""" - return self.parse_digits(num) is not None - - def parse_digits(self, num): - """Parse a string as a number, handling percentage and commas""" - num = str(num).replace(",", "") - try: - return float(num) - except ValueError: - if num.endswith("%"): - num = num[:-1] - if num.endswith("\\"): - num = num[:-1] - try: - return float(num) / 100 - except ValueError: - pass - return None - - def symbolic_equal(self, a, b): - """Check symbolic equality using SymPy""" - def _parse(s): - for f in [parse_latex, parse_expr]: - try: - return f(s) - except Exception: - pass - return s - - a = _parse(a) - b = _parse(b) - - try: - if simplify(a - b) == 0: - return True - except Exception: - pass - - try: - if isclose(N(a), N(b), abs_tol=1e-3): - return True - except Exception: - pass - - return False - - def extract_final_answer(self, messages: list) -> str: - """ - Extract the final answer from a list of messages. - - Args: - messages: List of messages from the agent conversation - - Returns: - The extracted final answer - """ - final_answer = "" - - if not messages: - return final_answer - - last_msg = messages[-1] - if isinstance(last_msg, tuple) and len(last_msg) > 1: - final_answer = last_msg[1] - elif hasattr(last_msg, "content"): - final_answer = last_msg.content - elif isinstance(last_msg, dict) and "content" in last_msg: - final_answer = last_msg["content"] - - return final_answer def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[str, Any]: """ @@ -245,14 +64,14 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[ """ # Extract the final answer from messages all_messages = run_result.get("messages", []) - final_answer = self.extract_final_answer(all_messages) + final_answer = super().extract_final_answer(all_messages) if self.evaluate_type == 0: # Use the new calculate_score method - score, extracted_answer = self.simple_calculate_score(problem["solution"], final_answer) + score, extracted_answer = super().simple_calculate_score(problem["solution"], final_answer) else: # Use the new calculate_score method - score, extracted_answer = self.calculate_score(problem["solution"], final_answer) + score, extracted_answer = super().calculate_score(problem["solution"], final_answer) # Return evaluation results return { diff --git a/mas_arena/evaluators/math_evaluator.py b/mas_arena/evaluators/math_evaluator.py index 73848cc..2b3f19b 100644 --- a/mas_arena/evaluators/math_evaluator.py +++ b/mas_arena/evaluators/math_evaluator.py @@ -3,15 +3,13 @@ This module provides a standalone evaluator for mathematical problems. """ -import asyncio import re import time -from typing import Dict, Any, Optional, List, Callable, Tuple +from typing import Dict, Any, Optional, List, Tuple from pathlib import Path from math import isclose from sympy import N, simplify -from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr from langsmith.evaluation import RunEvaluator from langsmith.schemas import Run @@ -19,9 +17,6 @@ from mas_arena.evaluators.base_evaluator import BaseEvaluator from mas_arena.evaluators.registry import register_benchmark from mas_arena.evaluators.utils.math_equal import calculate_score -from mas_arena.evaluators.utils.normalization import normalize_problem_keys - -# change @register_benchmark( name="math", @@ -31,7 +26,6 @@ "solution": "solution", } ) - class MathEvaluator(BaseEvaluator): """ Math Evaluator for evaluating math problems. @@ -135,7 +129,9 @@ def math_equal(self, prediction: Any, reference: Any) -> bool: if self.is_digit(prediction) and self.is_digit(reference): prediction_val = self.parse_digits(prediction) reference_val = self.parse_digits(reference) - return isclose(prediction_val, reference_val, abs_tol=1e-3) + # Check if both values are not None before using isclose + if prediction_val is not None and reference_val is not None: + return isclose(prediction_val, reference_val, abs_tol=1e-3) except ValueError: pass @@ -149,11 +145,31 @@ def math_equal(self, prediction: Any, reference: Any) -> bool: def is_digit(self, num): """Check if a string can be parsed as a number""" + num = str(num) + + # Handle numbers with commas (like {14{,}916}) + comma_pattern = r'\{(\d+)\{,\}(\d+)\}' + if re.search(comma_pattern, num): + num = re.sub(comma_pattern, r'\1\2', num) + elif "{,}" in num: + num = num.replace("{,}", "") + return self.parse_digits(num) is not None def parse_digits(self, num): """Parse a string as a number, handling percentage and commas""" - num = str(num).replace(",", "") + num = str(num) + + # Handle numbers with commas (like {14{,}916}) + comma_pattern = r'\{(\d+)\{,\}(\d+)\}' + if re.search(comma_pattern, num): + num = re.sub(comma_pattern, r'\1\2', num) + elif "{,}" in num: + num = num.replace("{,}", "") + + # Handle simple commas + num = num.replace(",", "") + try: return float(num) except ValueError: @@ -167,29 +183,127 @@ def parse_digits(self, num): pass return None + def latex_to_sympy(self, latex_str: str): + """ + Convert a LaTeX string to a SymPy-parsable expression. + This function handles various LaTeX commands and environments to make the string + compatible with SymPy's parser. + """ + latex_str = str(latex_str).strip() + + # Handle matrix/vector environments by extracting their content + pmatrix_match = re.search(r'\\begin{pmatrix}(.*?)\\end{pmatrix}', latex_str, re.DOTALL) + if pmatrix_match: + content = pmatrix_match.group(1).strip() + # Split by \\ and & and filter out empty strings + elements = [elem.strip() for elem in re.split(r'\\\\|&', content) if elem.strip()] + return f"({', '.join(elements)})" + + # Remove other environments + latex_str = re.sub(r'\\begin{asy}.*?\\end{asy}', '', latex_str, flags=re.DOTALL) + latex_str = re.sub(r'\\begin{tabular}.*?\\end{tabular}', '', latex_str, flags=re.DOTALL) + latex_str = re.sub(r'\\begin{align\*}.*?\\end{align\*}', '', latex_str, flags=re.DOTALL) + latex_str = re.sub(r'\\begin{align}.*?\\end{align}', '', latex_str, flags=re.DOTALL) + + # Handle \boxed{} + match = re.search(r'\\boxed{(.*)}', latex_str) + if match: + latex_str = match.group(1) + + # Remove text commands + latex_str = re.sub(r'\\text(normal)?\{.*?\}', '', latex_str) + latex_str = re.sub(r'\\textit\{.*?\}', '', latex_str) + + # Handle fractions + latex_str = re.sub(r'\\(d)?frac\{([^}]+)\}\{([^}]+)\}', r'((\2)/(\3))', latex_str) + + # Replacements for known LaTeX commands + replacements = { + r'\sin': 'sin', r'\cos': 'cos', r'\tan': 'tan', + r'\log': 'log', r'\ln': 'ln', + r'\sqrt': 'sqrt', r'\pi': 'pi', + r'\left': '', r'\right': '', + r'\cdot': '*', r'\times': '*', + r'\%': '/100', + r'^{\circ}': '', r'^\circ': '', + r'\$': '', r'\\,': '', r'\\!': '', r'\\#': '', + r'\allowbreak': '' + } + for old, new in replacements.items(): + latex_str = latex_str.replace(old, new) + + # Remove any other LaTeX commands that are just names + latex_str = re.sub(r'\\[a-zA-Z]+', '', latex_str) + + # remove subscripts like _{...} or _b + latex_str = re.sub(r'(_\{.*?\}|_b)', '', latex_str) + + # Handle numbers with commas + latex_str = re.sub(r'(\d),(\d)', r'\1\2', latex_str) + + # Handle repeating decimals + overline_match = re.search(r'(\d+)\.\\overline\{(\d+)\}', latex_str) + if overline_match: + integer_part = overline_match.group(1) + repeating_part = overline_match.group(2) + num = int(integer_part + repeating_part) - int(integer_part) + den = 10**len(repeating_part) - 1 + latex_str = f'({num}/{den})' + + overline_match = re.search(r'0\.\\overline\{(\d+)\}', latex_str) + if overline_match: + repeating_part = overline_match.group(1) + latex_str = f'({repeating_part}/(10**{len(repeating_part)}-1))' + + # Final cleanup of braces and backslashes + latex_str = latex_str.replace('{', '(').replace('}', ')').replace('\\', '').replace(' ', '') + + # Cleanup mismatched parentheses + while '()' in latex_str: + latex_str = latex_str.replace('()', '') + + return latex_str.strip() + def symbolic_equal(self, a, b): """Check symbolic equality using SymPy""" def _parse(s): - for f in [parse_latex, parse_expr]: - try: - return f(s) - except Exception: - pass - return s + s_str = str(s).strip() + if not s_str: + return None + + try: + latex_converted = self.latex_to_sympy(s_str) + if not latex_converted: + return None + return parse_expr(latex_converted) + except Exception: + # Fallback to direct parsing if latex conversion fails + pass + + try: + return parse_expr(s_str) + except Exception: + return None + + a_parsed = _parse(a) + b_parsed = _parse(b) - a = _parse(a) - b = _parse(b) + if a_parsed is None or b_parsed is None: + return False try: - if simplify(a - b) == 0: + # Simplify the difference + if simplify(a_parsed - b_parsed) == 0: return True except Exception: pass try: - if isclose(N(a), N(b), abs_tol=1e-3): + # Numerical evaluation + if isclose(N(a_parsed), N(b_parsed), abs_tol=1e-3): return True - except Exception: + except (TypeError, ValueError, Exception): + # This can fail if expressions are not numeric pass return False @@ -266,7 +380,7 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[ # Extract the final answer from messages all_messages = run_result.get("messages", []) final_answer = self.extract_final_answer(all_messages) - + if self.evaluate_type == 0: # Use the new calculate_score method score, extracted_answer = self.simple_calculate_score(problem["solution"], final_answer) @@ -284,3 +398,4 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[ "score": score, "extracted_answer": extracted_answer } +