195 changes: 7 additions & 188 deletions mas_arena/evaluators/aime_evaluator.py
@@ -4,22 +4,14 @@
Standalone evaluator for AIME-style math problems.
"""

import re
import time
from typing import Dict, Any, Tuple

from typing import Dict, Any
from pathlib import Path
from math import isclose

from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from langsmith.evaluation import RunEvaluator
from langsmith.schemas import Run

from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.math_evaluator import MathEvaluator
from mas_arena.evaluators.registry import register_benchmark
from mas_arena.evaluators.utils.math_equal import calculate_score
from mas_arena.evaluators.utils import extract_answer_numeric


@register_benchmark(
@@ -29,7 +21,7 @@
"solution": "answer",
}
)
class AIMEEvaluator(BaseEvaluator):
class AIMEEvaluator(MathEvaluator):
"""
Evaluator for AIME-style math problems.

@@ -58,179 +50,6 @@ def __init__(self, name: str = "aime", config: Dict[str, Any] = None):
def from_config(cls, name: str, config: Dict[str, Any] = None):
return cls(name, config)

def extract_answer(self, text: str) -> str:
if self.evaluate_type == 0:
return self.simple_extract_answer(text)
else:
return self.math_extract_answer(text)

def math_extract_answer(self, text: str) -> str:
"""
Extract the answer from model output text (last number or string).
"""
return extract_answer_numeric(text)
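
A quick standalone sketch of what the delegated helper does, per the docstring. extract_answer_numeric itself lives in mas_arena.evaluators.utils and is not shown in this diff, so the regex below is an illustrative assumption, not the actual implementation:

import re

def last_number_or_text(text: str) -> str:
    # Illustrative only: collect every integer/decimal token and keep the
    # last one, falling back to the stripped text when nothing numeric exists.
    matches = re.findall(r"-?\d+(?:\.\d+)?", text)
    return matches[-1] if matches else text.strip()

assert last_number_or_text("The total is 15 + 189 = 204.") == "204"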

def simple_extract_answer(self, text: str) -> str:
"""
Extract the answer from model output text, looking for boxed answers or final statements.

Args:
text: The model's output text

Returns:
The extracted answer
"""
# Look for LaTeX boxed answers first
pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}"
boxed_matches = re.findall(pattern, text, re.DOTALL)
if boxed_matches:
return boxed_matches[-1].strip()

# For AIME, also look for 3-digit numbers (000-999)
number_pattern = r"\b\d{3}\b"
number_matches = re.findall(number_pattern, text)
if number_matches:
return number_matches[-1]

# If no boxed answer, try to extract the final conclusion
sentence_end_pattern = r"(?<!\d)[.!?]\s+"
sentences = re.split(sentence_end_pattern, text)
sentences = [s.strip() for s in sentences if s.strip()]
return sentences[-1] if sentences else ""
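
The boxed-answer pattern tolerates exactly one level of brace nesting (enough for forms like \frac{1}{7}, not deeper), and the 3-digit fallback encodes the AIME convention that answers are integers from 000 to 999. A quick illustration of the precedence:

import re

text = r"So m/n = \boxed{\frac{1}{7}} and the requested sum is 008."
boxed = re.findall(r"\\boxed{((?:[^{}]|{[^{}]*})*)}", text, re.DOTALL)
print(boxed[-1])                        # \frac{1}{7} -- boxed answers win
print(re.findall(r"\b\d{3}\b", text))   # ['008'] -- the 3-digit fallback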

def simple_calculate_score(self, expected_output: str, prediction: str) -> Tuple[int, str]:
"""
Calculate a score by comparing the expected and predicted answers.

Args:
expected_output: The expected answer (solution)
prediction: The model's prediction

Returns:
Tuple of (score, extracted_answer) where score is 1 for correct, 0 for incorrect
"""
extracted_expected = self.extract_answer(expected_output)
extracted_prediction = self.extract_answer(prediction)

if self.math_equal(extracted_prediction, extracted_expected):
return 1, extracted_prediction
else:
return 0, extracted_prediction
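
Hypothetical usage, assuming a default-constructed evaluator takes the evaluate_type == 0 path:

ev = AIMEEvaluator()  # assumes the default config selects the simple (type 0) path
score, extracted = ev.simple_calculate_score(
    expected_output=r"\boxed{204}",
    prediction="After telescoping the sum, the answer is 204.",
)
# score == 1, extracted == "204": both sides normalize to the same string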

def calculate_score(self, expected_output: str, prediction: str) -> Tuple[int, str]:
return calculate_score(expected_output, prediction)

def math_equal(self, prediction: Any, reference: Any) -> bool:
"""
Check if two mathematical expressions are equivalent.

Args:
prediction: The predicted answer
reference: The reference answer

Returns:
True if the expressions are equivalent, False otherwise
"""
# Direct string comparison
if str(prediction) == str(reference):
return True

# For AIME, treat as integers and compare
try:
pred_int = int(prediction)
ref_int = int(reference)
return pred_int == ref_int
except ValueError:
pass

# Numeric comparison
try:
if self.is_digit(prediction) and self.is_digit(reference):
prediction_val = self.parse_digits(prediction)
reference_val = self.parse_digits(reference)
return isclose(prediction_val, reference_val, abs_tol=1e-3)
except ValueError:
pass

# Symbolic comparison
try:
return self.symbolic_equal(prediction, reference)
except Exception:
pass

return False
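
The cascade runs from cheapest to most expensive: literal string match, integer comparison (which normalizes AIME's zero-padded answers), float comparison at 1e-3 absolute tolerance, then SymPy. A standalone sketch of the first three tiers:

from math import isclose

def loose_equal(pred, ref) -> bool:
    if str(pred) == str(ref):          # tier 1: literal match
        return True
    try:
        return int(pred) == int(ref)   # tier 2: "042" == "42"
    except ValueError:
        pass
    try:                               # tier 3: small numeric tolerance
        return isclose(float(pred), float(ref), abs_tol=1e-3)
    except ValueError:
        return False

assert loose_equal("042", "42")
assert loose_equal("0.5", "0.5004")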

def is_digit(self, num):
"""Check if a string can be parsed as a number"""
return self.parse_digits(num) is not None

def parse_digits(self, num):
"""Parse a string as a number, handling percentage and commas"""
num = str(num).replace(",", "")
try:
return float(num)
except ValueError:
if num.endswith("%"):
num = num[:-1]
if num.endswith("\\"):
num = num[:-1]
try:
return float(num) / 100
except ValueError:
pass
return None
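
Behavior of the removed helper on a few inputs. Constructing an instance here is illustrative and assumes MathEvaluator now provides the same logic:

ev = AIMEEvaluator()
assert ev.parse_digits("1,234") == 1234.0
assert ev.parse_digits("75%") == 0.75
assert ev.parse_digits(r"75\%") == 0.75   # the % is stripped first, then the stray backslash
assert ev.parse_digits("n/a") is None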

def symbolic_equal(self, a, b):
"""Check symbolic equality using SymPy"""
def _parse(s):
for f in [parse_latex, parse_expr]:
try:
return f(s)
except Exception:
pass
return s

a = _parse(a)
b = _parse(b)

try:
if simplify(a - b) == 0:
return True
except Exception:
pass

try:
if isclose(N(a), N(b), abs_tol=1e-3):
return True
except Exception:
pass

return False
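
_parse tries parse_latex first (which needs the antlr4 Python runtime installed) and falls back to parse_expr, so both LaTeX and plain SymPy syntax are accepted. A minimal demonstration of the simplify-to-zero test:

from sympy import simplify
from sympy.parsing.sympy_parser import parse_expr

a = parse_expr("sqrt(8)")
b = parse_expr("2*sqrt(2)")
print(simplify(a - b) == 0)   # True: the difference simplifies to zero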

def extract_final_answer(self, messages: list) -> str:
"""
Extract the final answer from a list of messages.

Args:
messages: List of messages from the agent conversation

Returns:
The extracted final answer
"""
final_answer = ""

if not messages:
return final_answer

last_msg = messages[-1]
if isinstance(last_msg, tuple) and len(last_msg) > 1:
final_answer = last_msg[1]
elif hasattr(last_msg, "content"):
final_answer = last_msg.content
elif isinstance(last_msg, dict) and "content" in last_msg:
final_answer = last_msg["content"]

return final_answer
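
The helper accepts the three message shapes common across agent frameworks: (role, content) tuples, objects exposing a .content attribute (e.g. LangChain messages), and plain dicts. Illustrative calls, assuming the parent class keeps this behavior:

ev = AIMEEvaluator()
assert ev.extract_final_answer([("assistant", "042")]) == "042"   # tuple form
assert ev.extract_final_answer([{"content": "042"}]) == "042"     # dict form
assert ev.extract_final_answer([]) == ""                          # no messages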

def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -245,14 +64,14 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[
"""
# Extract the final answer from messages
all_messages = run_result.get("messages", [])
final_answer = self.extract_final_answer(all_messages)
final_answer = super().extract_final_answer(all_messages)

if self.evaluate_type == 0:
# Use the parent class's simple_calculate_score method
score, extracted_answer = self.simple_calculate_score(problem["solution"], final_answer)
score, extracted_answer = super().simple_calculate_score(problem["solution"], final_answer)
else:
# Use the parent class's calculate_score method
score, extracted_answer = self.calculate_score(problem["solution"], final_answer)
score, extracted_answer = super().calculate_score(problem["solution"], final_answer)

# Return evaluation results
return {