From fd28a43cf7e057af8f1f3ffa76b2b591ac42268c Mon Sep 17 00:00:00 2001 From: fuzi233 Date: Thu, 31 Jul 2025 21:37:27 +0800 Subject: [PATCH 1/4] modified: mas_arena/evaluators/math_evaluator.py --- mas_arena/evaluators/math_evaluator.py | 95 +++++++++++++++++++++++--- 1 file changed, 87 insertions(+), 8 deletions(-) diff --git a/mas_arena/evaluators/math_evaluator.py b/mas_arena/evaluators/math_evaluator.py index 73848cc..5e6df97 100644 --- a/mas_arena/evaluators/math_evaluator.py +++ b/mas_arena/evaluators/math_evaluator.py @@ -135,7 +135,9 @@ def math_equal(self, prediction: Any, reference: Any) -> bool: if self.is_digit(prediction) and self.is_digit(reference): prediction_val = self.parse_digits(prediction) reference_val = self.parse_digits(reference) - return isclose(prediction_val, reference_val, abs_tol=1e-3) + # Check if both values are not None before using isclose + if prediction_val is not None and reference_val is not None: + return isclose(prediction_val, reference_val, abs_tol=1e-3) except ValueError: pass @@ -149,11 +151,32 @@ def math_equal(self, prediction: Any, reference: Any) -> bool: def is_digit(self, num): """Check if a string can be parsed as a number""" + num = str(num) + + # Handle numbers with commas (like {14{,}916}) + comma_pattern = r'\{(\d+)\{,\}(\d+)\}' + if re.search(comma_pattern, num): + num = re.sub(comma_pattern, r'\1\2', num) + elif "{,}" in num: + num = num.replace("{,}", "") + + print("parse_digits:", num) return self.parse_digits(num) is not None def parse_digits(self, num): """Parse a string as a number, handling percentage and commas""" - num = str(num).replace(",", "") + num = str(num) + + # Handle numbers with commas (like {14{,}916}) + comma_pattern = r'\{(\d+)\{,\}(\d+)\}' + if re.search(comma_pattern, num): + num = re.sub(comma_pattern, r'\1\2', num) + elif "{,}" in num: + num = num.replace("{,}", "") + + # Handle simple commas + num = num.replace(",", "") + try: return float(num) except ValueError: @@ -167,18 +190,62 @@ def parse_digits(self, num): pass return None + def latex_to_sympy(self, latex_str): + """Convert LaTeX to SymPy expression""" + latex_str = str(latex_str).strip() + + # Handle common LaTeX patterns + # Remove outer braces + if latex_str.startswith('{') and latex_str.endswith('}'): + latex_str = latex_str[1:-1] + + # Handle fractions + frac_pattern = r'\\frac\{([^}]+)\}\{([^}]+)\}' + if re.search(frac_pattern, latex_str): + latex_str = re.sub(frac_pattern, r'(\1)/(\2)', latex_str) + + # Handle dfrac (same as frac) + dfrac_pattern = r'\\dfrac\{([^}]+)\}\{([^}]+)\}' + if re.search(dfrac_pattern, latex_str): + latex_str = re.sub(dfrac_pattern, r'(\1)/(\2)', latex_str) + + # Handle numbers with commas (like {14{,}916}) + comma_pattern = r'\{(\d+)\{,\}(\d+)\}' + if re.search(comma_pattern, latex_str): + latex_str = re.sub(comma_pattern, r'\1\2', latex_str) + + # Handle simple numbers in braces + num_brace_pattern = r'\{(\d+)\}' + if re.search(num_brace_pattern, latex_str): + latex_str = re.sub(num_brace_pattern, r'\1', latex_str) + + return latex_str + def symbolic_equal(self, a, b): """Check symbolic equality using SymPy""" def _parse(s): - for f in [parse_latex, parse_expr]: - try: - return f(s) - except Exception: - pass + s_str = str(s) + + # Try to parse as LaTeX first + try: + latex_converted = self.latex_to_sympy(s_str) + return parse_expr(latex_converted) + except Exception: + pass + + # Try direct parsing + try: + return parse_expr(s_str) + except Exception as e: + print("error:", str(e)) + pass + return s a = _parse(a) b = _parse(b) + print("a:", a) + print("b:", b) try: if simplify(a - b) == 0: @@ -266,7 +333,7 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[ # Extract the final answer from messages all_messages = run_result.get("messages", []) final_answer = self.extract_final_answer(all_messages) - + if self.evaluate_type == 0: # Use the new calculate_score method score, extracted_answer = self.simple_calculate_score(problem["solution"], final_answer) @@ -284,3 +351,15 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[ "score": score, "extracted_answer": extracted_answer } + +if __name__ == "__main__": + evaluator = MathEvaluator("math") + data_1 = "{\frac{1}{2}}" + data_2 = "1/2" + data_3 = "{\dfrac{1}{2}}" + print(data_1) + print(data_2) + print(data_3) + print(evaluator.math_equal(data_1, data_2)) + print(evaluator.math_equal(data_1, data_3)) + print(evaluator.math_equal(data_2, data_3)) From e91c86b3ff5f0327d74225759dc286111b8b28ed Mon Sep 17 00:00:00 2001 From: fuzi233 Date: Fri, 1 Aug 2025 16:28:37 +0800 Subject: [PATCH 2/4] modified: mas_arena/evaluators/math_evaluator.py --- mas_arena/evaluators/math_evaluator.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/mas_arena/evaluators/math_evaluator.py b/mas_arena/evaluators/math_evaluator.py index 5e6df97..2fc972b 100644 --- a/mas_arena/evaluators/math_evaluator.py +++ b/mas_arena/evaluators/math_evaluator.py @@ -160,7 +160,6 @@ def is_digit(self, num): elif "{,}" in num: num = num.replace("{,}", "") - print("parse_digits:", num) return self.parse_digits(num) is not None def parse_digits(self, num): @@ -244,8 +243,6 @@ def _parse(s): a = _parse(a) b = _parse(b) - print("a:", a) - print("b:", b) try: if simplify(a - b) == 0: @@ -352,14 +349,3 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[ "extracted_answer": extracted_answer } -if __name__ == "__main__": - evaluator = MathEvaluator("math") - data_1 = "{\frac{1}{2}}" - data_2 = "1/2" - data_3 = "{\dfrac{1}{2}}" - print(data_1) - print(data_2) - print(data_3) - print(evaluator.math_equal(data_1, data_2)) - print(evaluator.math_equal(data_1, data_3)) - print(evaluator.math_equal(data_2, data_3)) From 9b2057bf157e884bf1b091876a47342c37a82381 Mon Sep 17 00:00:00 2001 From: fuzi233 Date: Tue, 5 Aug 2025 10:40:51 +0800 Subject: [PATCH 3/4] modified: mas_arena/evaluators/aime_evaluator.py modified: mas_arena/evaluators/math_evaluator.py --- mas_arena/evaluators/aime_evaluator.py | 199 +------------------------ mas_arena/evaluators/math_evaluator.py | 132 +++++++++++----- 2 files changed, 101 insertions(+), 230 deletions(-) diff --git a/mas_arena/evaluators/aime_evaluator.py b/mas_arena/evaluators/aime_evaluator.py index 83c5f29..2140389 100644 --- a/mas_arena/evaluators/aime_evaluator.py +++ b/mas_arena/evaluators/aime_evaluator.py @@ -4,21 +4,15 @@ Standalone evaluator for AIME-style math problems. """ -import re + import time -from typing import Dict, Any, Tuple +from typing import Dict, Any from pathlib import Path -from math import isclose -from sympy import N, simplify -from sympy.parsing.latex import parse_latex -from sympy.parsing.sympy_parser import parse_expr from langsmith.evaluation import RunEvaluator -from langsmith.schemas import Run -from mas_arena.evaluators.base_evaluator import BaseEvaluator +from mas_arena.evaluators.math_evaluator import MathEvaluator from mas_arena.evaluators.registry import register_benchmark -from mas_arena.evaluators.utils.math_equal import calculate_score @register_benchmark( @@ -28,7 +22,7 @@ "solution": "answer", } ) -class AIMEEvaluator(BaseEvaluator): +class AIMEEvaluator(MathEvaluator): """ Evaluator for AIME-style math problems. @@ -57,185 +51,6 @@ def __init__(self, name: str = "aime", config: Dict[str, Any] = None): def from_config(cls, name: str, config: Dict[str, Any] = None): return cls(name, config) - def extract_answer(self, text: str) -> str: - if self.evaluate_type == 0: - return self.simple_extract_answer(text) - else: - return self.math_extract_answer(text) - - def math_extract_answer(self, text: str) -> str: - """ - Extract the answer from model output text (last number or string). - """ - # Try to extract the last number (int/float) - matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text)) - if matches: - return matches[-1].replace(",", "").strip() - # Fallback: last non-empty line - lines = [line.strip() for line in str(text).splitlines() if line.strip()] - return lines[-1] if lines else str(text).strip() - - def simple_extract_answer(self, text: str) -> str: - """ - Extract the answer from model output text, looking for boxed answers or final statements. - - Args: - text: The model's output text - - Returns: - The extracted answer - """ - # Look for LaTeX boxed answers first - pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}" - boxed_matches = re.findall(pattern, text, re.DOTALL) - if boxed_matches: - return boxed_matches[-1].strip() - - # For AIME, also look for 3-digit numbers (000-999) - number_pattern = r"\b\d{3}\b" - number_matches = re.findall(number_pattern, text) - if number_matches: - return number_matches[-1] - - # If no boxed answer, try to extract the final conclusion - sentence_end_pattern = r"(? Tuple[int, str]: - """ - Calculate a score by comparing the expected and predicted answers. - - Args: - expected_output: The expected answer (solution) - prediction: The model's prediction - - Returns: - Tuple of (score, extracted_answer) where score is 1 for correct, 0 for incorrect - """ - extracted_expected = self.extract_answer(expected_output) - extracted_prediction = self.extract_answer(prediction) - - if self.math_equal(extracted_prediction, extracted_expected): - return 1, extracted_prediction - else: - return 0, extracted_prediction - - def calculate_score(self, expected_output: str, prediction: str) -> Tuple[int, str]: - return calculate_score(expected_output, prediction) - - def math_equal(self, prediction: Any, reference: Any) -> bool: - """ - Check if two mathematical expressions are equivalent. - - Args: - prediction: The predicted answer - reference: The reference answer - - Returns: - True if the expressions are equivalent, False otherwise - """ - # Direct string comparison - if str(prediction) == str(reference): - return True - - # For AIME, treat as integers and compare - try: - pred_int = int(prediction) - ref_int = int(reference) - return pred_int == ref_int - except ValueError: - pass - - # Numeric comparison - try: - if self.is_digit(prediction) and self.is_digit(reference): - prediction_val = self.parse_digits(prediction) - reference_val = self.parse_digits(reference) - return isclose(prediction_val, reference_val, abs_tol=1e-3) - except ValueError: - pass - - # Symbolic comparison - try: - return self.symbolic_equal(prediction, reference) - except Exception: - pass - - return False - - def is_digit(self, num): - """Check if a string can be parsed as a number""" - return self.parse_digits(num) is not None - - def parse_digits(self, num): - """Parse a string as a number, handling percentage and commas""" - num = str(num).replace(",", "") - try: - return float(num) - except ValueError: - if num.endswith("%"): - num = num[:-1] - if num.endswith("\\"): - num = num[:-1] - try: - return float(num) / 100 - except ValueError: - pass - return None - - def symbolic_equal(self, a, b): - """Check symbolic equality using SymPy""" - def _parse(s): - for f in [parse_latex, parse_expr]: - try: - return f(s) - except Exception: - pass - return s - - a = _parse(a) - b = _parse(b) - - try: - if simplify(a - b) == 0: - return True - except Exception: - pass - - try: - if isclose(N(a), N(b), abs_tol=1e-3): - return True - except Exception: - pass - - return False - - def extract_final_answer(self, messages: list) -> str: - """ - Extract the final answer from a list of messages. - - Args: - messages: List of messages from the agent conversation - - Returns: - The extracted final answer - """ - final_answer = "" - - if not messages: - return final_answer - - last_msg = messages[-1] - if isinstance(last_msg, tuple) and len(last_msg) > 1: - final_answer = last_msg[1] - elif hasattr(last_msg, "content"): - final_answer = last_msg.content - elif isinstance(last_msg, dict) and "content" in last_msg: - final_answer = last_msg["content"] - - return final_answer def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[str, Any]: """ @@ -250,14 +65,14 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[ """ # Extract the final answer from messages all_messages = run_result.get("messages", []) - final_answer = self.extract_final_answer(all_messages) + final_answer = super().extract_final_answer(all_messages) if self.evaluate_type == 0: # Use the new calculate_score method - score, extracted_answer = self.simple_calculate_score(problem["solution"], final_answer) + score, extracted_answer = super().simple_calculate_score(problem["solution"], final_answer) else: # Use the new calculate_score method - score, extracted_answer = self.calculate_score(problem["solution"], final_answer) + score, extracted_answer = super().calculate_score(problem["solution"], final_answer) # Return evaluation results return { diff --git a/mas_arena/evaluators/math_evaluator.py b/mas_arena/evaluators/math_evaluator.py index 2fc972b..bac309b 100644 --- a/mas_arena/evaluators/math_evaluator.py +++ b/mas_arena/evaluators/math_evaluator.py @@ -189,71 +189,127 @@ def parse_digits(self, num): pass return None - def latex_to_sympy(self, latex_str): - """Convert LaTeX to SymPy expression""" + def latex_to_sympy(self, latex_str: str): + """ + Convert a LaTeX string to a SymPy-parsable expression. + This function handles various LaTeX commands and environments to make the string + compatible with SymPy's parser. + """ latex_str = str(latex_str).strip() - - # Handle common LaTeX patterns - # Remove outer braces - if latex_str.startswith('{') and latex_str.endswith('}'): - latex_str = latex_str[1:-1] - + + # Handle matrix/vector environments by extracting their content + pmatrix_match = re.search(r'\\begin{pmatrix}(.*?)\\end{pmatrix}', latex_str, re.DOTALL) + if pmatrix_match: + content = pmatrix_match.group(1).strip() + # Split by \\ and & and filter out empty strings + elements = [elem.strip() for elem in re.split(r'\\\\|&', content) if elem.strip()] + return f"({', '.join(elements)})" + + # Remove other environments + latex_str = re.sub(r'\\begin{asy}.*?\\end{asy}', '', latex_str, flags=re.DOTALL) + latex_str = re.sub(r'\\begin{tabular}.*?\\end{tabular}', '', latex_str, flags=re.DOTALL) + latex_str = re.sub(r'\\begin{align\*}.*?\\end{align\*}', '', latex_str, flags=re.DOTALL) + latex_str = re.sub(r'\\begin{align}.*?\\end{align}', '', latex_str, flags=re.DOTALL) + + # Handle \boxed{} + match = re.search(r'\\boxed{(.*)}', latex_str) + if match: + latex_str = match.group(1) + + # Remove text commands + latex_str = re.sub(r'\\text(normal)?\{.*?\}', '', latex_str) + latex_str = re.sub(r'\\textit\{.*?\}', '', latex_str) + # Handle fractions - frac_pattern = r'\\frac\{([^}]+)\}\{([^}]+)\}' - if re.search(frac_pattern, latex_str): - latex_str = re.sub(frac_pattern, r'(\1)/(\2)', latex_str) + latex_str = re.sub(r'\\(d)?frac\{([^}]+)\}\{([^}]+)\}', r'((\2)/(\3))', latex_str) - # Handle dfrac (same as frac) - dfrac_pattern = r'\\dfrac\{([^}]+)\}\{([^}]+)\}' - if re.search(dfrac_pattern, latex_str): - latex_str = re.sub(dfrac_pattern, r'(\1)/(\2)', latex_str) + # Replacements for known LaTeX commands + replacements = { + r'\sin': 'sin', r'\cos': 'cos', r'\tan': 'tan', + r'\log': 'log', r'\ln': 'ln', + r'\sqrt': 'sqrt', r'\pi': 'pi', + r'\left': '', r'\right': '', + r'\cdot': '*', r'\times': '*', + r'\%': '/100', + r'^{\circ}': '', r'^\circ': '', + r'\$': '', r'\\,': '', r'\\!': '', r'\\#': '', + r'\allowbreak': '' + } + for old, new in replacements.items(): + latex_str = latex_str.replace(old, new) - # Handle numbers with commas (like {14{,}916}) - comma_pattern = r'\{(\d+)\{,\}(\d+)\}' - if re.search(comma_pattern, latex_str): - latex_str = re.sub(comma_pattern, r'\1\2', latex_str) + # Remove any other LaTeX commands that are just names + latex_str = re.sub(r'\\[a-zA-Z]+', '', latex_str) - # Handle simple numbers in braces - num_brace_pattern = r'\{(\d+)\}' - if re.search(num_brace_pattern, latex_str): - latex_str = re.sub(num_brace_pattern, r'\1', latex_str) + # remove subscripts like _{...} or _b + latex_str = re.sub(r'(_\{.*?\}|_b)', '', latex_str) + + # Handle numbers with commas + latex_str = re.sub(r'(\d),(\d)', r'\1\2', latex_str) + + # Handle repeating decimals + overline_match = re.search(r'(\d+)\.\\overline\{(\d+)\}', latex_str) + if overline_match: + integer_part = overline_match.group(1) + repeating_part = overline_match.group(2) + num = int(integer_part + repeating_part) - int(integer_part) + den = 10**len(repeating_part) - 1 + latex_str = f'({num}/{den})' + + overline_match = re.search(r'0\.\\overline\{(\d+)\}', latex_str) + if overline_match: + repeating_part = overline_match.group(1) + latex_str = f'({repeating_part}/(10**{len(repeating_part)}-1))' + + # Final cleanup of braces and backslashes + latex_str = latex_str.replace('{', '(').replace('}', ')').replace('\\', '').replace(' ', '') - return latex_str + # Cleanup mismatched parentheses + while '()' in latex_str: + latex_str = latex_str.replace('()', '') + + return latex_str.strip() def symbolic_equal(self, a, b): """Check symbolic equality using SymPy""" def _parse(s): - s_str = str(s) - - # Try to parse as LaTeX first + s_str = str(s).strip() + if not s_str: + return None + try: latex_converted = self.latex_to_sympy(s_str) + if not latex_converted: + return None return parse_expr(latex_converted) except Exception: + # Fallback to direct parsing if latex conversion fails pass - # Try direct parsing try: return parse_expr(s_str) - except Exception as e: - print("error:", str(e)) - pass - - return s + except Exception: + return None + + a_parsed = _parse(a) + b_parsed = _parse(b) - a = _parse(a) - b = _parse(b) + if a_parsed is None or b_parsed is None: + return False try: - if simplify(a - b) == 0: + # Simplify the difference + if simplify(a_parsed - b_parsed) == 0: return True except Exception: pass try: - if isclose(N(a), N(b), abs_tol=1e-3): + # Numerical evaluation + if isclose(N(a_parsed), N(b_parsed), abs_tol=1e-3): return True - except Exception: + except (TypeError, ValueError, Exception): + # This can fail if expressions are not numeric pass return False From 08a1a1717342272ab125983b3aa6951fa1f1f323 Mon Sep 17 00:00:00 2001 From: jiaqi Date: Wed, 6 Aug 2025 20:33:51 +0800 Subject: [PATCH 4/4] remove useless package --- mas_arena/evaluators/aime_evaluator.py | 1 - mas_arena/evaluators/math_evaluator.py | 8 +------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/mas_arena/evaluators/aime_evaluator.py b/mas_arena/evaluators/aime_evaluator.py index 2140389..076e8f0 100644 --- a/mas_arena/evaluators/aime_evaluator.py +++ b/mas_arena/evaluators/aime_evaluator.py @@ -5,7 +5,6 @@ """ -import time from typing import Dict, Any from pathlib import Path diff --git a/mas_arena/evaluators/math_evaluator.py b/mas_arena/evaluators/math_evaluator.py index bac309b..2b3f19b 100644 --- a/mas_arena/evaluators/math_evaluator.py +++ b/mas_arena/evaluators/math_evaluator.py @@ -3,15 +3,13 @@ This module provides a standalone evaluator for mathematical problems. """ -import asyncio import re import time -from typing import Dict, Any, Optional, List, Callable, Tuple +from typing import Dict, Any, Optional, List, Tuple from pathlib import Path from math import isclose from sympy import N, simplify -from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr from langsmith.evaluation import RunEvaluator from langsmith.schemas import Run @@ -19,9 +17,6 @@ from mas_arena.evaluators.base_evaluator import BaseEvaluator from mas_arena.evaluators.registry import register_benchmark from mas_arena.evaluators.utils.math_equal import calculate_score -from mas_arena.evaluators.utils.normalization import normalize_problem_keys - -# change @register_benchmark( name="math", @@ -31,7 +26,6 @@ "solution": "solution", } ) - class MathEvaluator(BaseEvaluator): """ Math Evaluator for evaluating math problems.