From 5a7be6d58c40be3875b3eadaf05a212d887abe92 Mon Sep 17 00:00:00 2001
From: morgendave
Date: Tue, 18 Nov 2025 14:13:50 -0800
Subject: [PATCH 1/3] adding response quality validation for retry

---
 eval_protocol/exceptions.py                   |   7 +
 eval_protocol/models.py                       |   8 +
 eval_protocol/pytest/__init__.py              |   3 +
 eval_protocol/pytest/evaluation_test_utils.py |  28 +-
 eval_protocol/pytest/exception_config.py      | 133 +++++++--
 eval_protocol/pytest/types.py                 |   2 +
 tests/test_exceptions.py                      |  16 +-
 tests/test_retry_mechanism.py                 | 255 +++++++++++++++++-
 8 files changed, 427 insertions(+), 25 deletions(-)

diff --git a/eval_protocol/exceptions.py b/eval_protocol/exceptions.py
index 3b92c865..5459f1e9 100644
--- a/eval_protocol/exceptions.py
+++ b/eval_protocol/exceptions.py
@@ -134,6 +134,12 @@ class ScoreInvalidError(EvalProtocolError):
     status_code = 102


+class ResponseQualityError(EvalProtocolError):
+    """Response quality check failed (Status.Code.RESPONSE_QUALITY_ERROR = 103)"""
+
+    status_code = 103
+
+
 # Convenience mapping from status codes to exception classes
 # Only actual error conditions should raise exceptions
 STATUS_CODE_TO_EXCEPTION = {
@@ -157,6 +163,7 @@ class ScoreInvalidError(EvalProtocolError):
     100: None,  # FINISHED - success, no exception
     101: None,  # RUNNING - in progress, no exception
     102: None,  # SCORE_INVALID - success, no exception
+    103: ResponseQualityError,  # RESPONSE_QUALITY_ERROR - quality check failed
 }

diff --git a/eval_protocol/models.py b/eval_protocol/models.py
index 67d287ba..edfdbec8 100644
--- a/eval_protocol/models.py
+++ b/eval_protocol/models.py
@@ -117,6 +117,7 @@ class Code(int, Enum):
         FINISHED = 100
         RUNNING = 101
         SCORE_INVALID = 102
+        RESPONSE_QUALITY_ERROR = 103

     @classmethod
     def rollout_running(cls) -> "Status":
@@ -367,6 +368,13 @@ def score_invalid(
         """Create a status indicating the score is invalid."""
         return cls(code=cls.Code.SCORE_INVALID, message=message, details=details or [])

+    @classmethod
+    def response_quality_error(
+        cls, message: str = "Response quality check failed", details: Optional[List[Dict[str, Any]]] = None
+    ) -> "Status":
+        """Create a status indicating the response quality check failed."""
+        return cls(code=cls.Code.RESPONSE_QUALITY_ERROR, message=message, details=details or [])
+
     def is_running(self) -> bool:
         """Check if the status indicates the rollout is running."""
         return self.code == self.Code.RUNNING
diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py
index 31c81167..26485f43 100644
--- a/eval_protocol/pytest/__init__.py
+++ b/eval_protocol/pytest/__init__.py
@@ -8,6 +8,7 @@
 from .evaluation_test import evaluation_test
 from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
 from .rollout_processor import RolloutProcessor
+from .rollout_result_post_processor import RolloutResultPostProcessor, NoOpRolloutResultPostProcessor
 from .types import RolloutProcessorConfig

 # Conditional import for optional dependencies
@@ -42,6 +43,8 @@
     "ExceptionHandlerConfig",
     "BackoffConfig",
     "get_default_exception_handler_config",
+    "RolloutResultPostProcessor",
+    "NoOpRolloutResultPostProcessor",
 ]

 # Only add to __all__ if available
diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py
index b4d1c218..e570c815 100644
--- a/eval_protocol/pytest/evaluation_test_utils.py
+++ b/eval_protocol/pytest/evaluation_test_utils.py
@@ -28,6 +28,7 @@
     ServerMode,
 )
 from eval_protocol.pytest.exception_config import get_default_exception_handler_config
+from eval_protocol.exceptions import ResponseQualityError

 import logging
 import json
@@ -363,7 +364,21 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
         """Execute rollout for a single row with backoff retry."""
         retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
         retry_tasks = rollout_processor([row], retry_config)
-        return await retry_tasks[0]
+        result = await retry_tasks[0]
+
+        # Apply post-processing quality checks if configured
+        # This must be inside the retry function so ResponseQualityError can trigger retries
+        if config.post_processor is not None:
+            try:
+                config.post_processor.process(result)
+            except ResponseQualityError as quality_error:
+                # Re-raise ResponseQualityError to trigger retry logic
+                raise quality_error
+            except Exception as post_process_error:
+                # Wrap unexpected post-processor errors in ResponseQualityError
+                raise ResponseQualityError(f"Post-processor failed: {post_process_error}") from post_process_error
+
+        return result

     async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
         """Execute a single row task with backoff retry."""
@@ -372,6 +387,15 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
             # Try original task first
             result = await task  # pyright: ignore[reportUnknownVariableType]

+            # Apply post-processing quality checks if configured
+            if config.post_processor is not None:
+                try:
+                    config.post_processor.process(result)
+                except ResponseQualityError as quality_error:
+                    raise quality_error
+                except Exception as post_process_error:
+                    raise ResponseQualityError(f"Post-processor failed: {post_process_error}") from post_process_error
+
             _set_rollout_status_to_finished(result)
             return result  # pyright: ignore[reportUnknownVariableType]

@@ -384,9 +408,9 @@

         if is_retryable and not should_giveup:
             # Use shared backoff function for retryable exceptions
+            # Note: post-processing is handled inside execute_row_with_backoff_retry
             try:
                 result = await execute_row_with_backoff_retry(row)
-
                 _set_rollout_status_to_finished(result)
                 return result
diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py
index e4bb1b7c..83121643 100644
--- a/eval_protocol/pytest/exception_config.py
+++ b/eval_protocol/pytest/exception_config.py
@@ -4,7 +4,7 @@

 import os
 from dataclasses import dataclass, field
-from typing import Callable, Set, Type, Union
+from typing import Callable, Dict, Set, Type, Union

 import backoff

@@ -47,6 +47,7 @@
     eval_protocol.exceptions.UnavailableError,
     eval_protocol.exceptions.UnauthenticatedError,
     eval_protocol.exceptions.ResourceExhaustedError,
+    eval_protocol.exceptions.ResponseQualityError,
 }

@@ -78,8 +79,14 @@ class BackoffConfig:
     # Optional custom giveup function - if provided, overrides the default exception handling logic
     giveup_func: Callable[[Exception], bool] = lambda e: False

-    def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
-        """Get the appropriate backoff decorator based on configuration."""
+    def get_backoff_decorator(self, exceptions: Set[Type[Exception]], exception_backoff_overrides: Dict[Type[Exception], "BackoffConfig"] | None = None):
+        """Get the appropriate backoff decorator based on configuration.
+
+        Args:
+            exceptions: Set of exception types to retry
+            exception_backoff_overrides: Optional mapping of exception types to custom backoff configs.
+                If an exception type has an override, that config will be used instead of this one.
+        """
         if not exceptions:
             # If no exceptions specified, return a no-op decorator
             def no_op_decorator(func):
@@ -87,30 +94,97 @@ def no_op_decorator(func):

             return no_op_decorator

-        if self.strategy == "expo":
+        # If no overrides, use simple decorator for all exceptions
+        if not exception_backoff_overrides:
+            return self._create_single_decorator(exceptions, self)
+
+        # Group exceptions by their backoff config to avoid double backoff
+        # Each exception type gets exactly one decorator based on its config
+        # Use a tuple of config attributes as the key since BackoffConfig is not hashable
+        config_to_exceptions: Dict[tuple, tuple[Set[Type[Exception]], "BackoffConfig"]] = {}
+
+        for exc_type in exceptions:
+            if exc_type in exception_backoff_overrides:
+                override_config = exception_backoff_overrides[exc_type]
+            else:
+                override_config = self
+
+            # Create a hashable key from config attributes
+            # Note: jitter and giveup_func are callable, which are hashable in Python
+            config_key = (
+                override_config.strategy,
+                override_config.base_delay,
+                override_config.max_delay,
+                override_config.max_tries,
+                override_config.factor,
+                id(override_config.jitter) if override_config.jitter is not None else None,
+                id(override_config.giveup_func) if override_config.giveup_func is not None else None,
+                override_config.raise_on_giveup,
+            )
+
+            if config_key not in config_to_exceptions:
+                config_to_exceptions[config_key] = (set(), override_config)
+            exc_set, _ = config_to_exceptions[config_key]
+            exc_set.add(exc_type)
+
+        # If all exceptions use the same config, use a single decorator
+        if len(config_to_exceptions) == 1:
+            exc_set, config = next(iter(config_to_exceptions.values()))
+            return self._create_single_decorator(exc_set, config)
+
+        # Create separate decorators for each config group
+        # Each exception type gets exactly one decorator, preventing double backoff
+        decorators_by_config: list[tuple[Set[Type[Exception]], Callable]] = []
+
+        for exc_set, config in config_to_exceptions.values():
+            decorator = self._create_single_decorator(exc_set, config)
+            if decorator:
+                decorators_by_config.append((exc_set, decorator))
+
+        # Create a combined decorator that applies all decorators
+        # Each decorator only catches exceptions in its exception set, so no double backoff
+        def combined_decorator(func):
+            decorated_func = func
+
+            # Apply each decorator in order (inner to outer)
+            # Each decorator only catches exceptions in its specific exception set
+            # Since exception sets are disjoint (grouped by config), no double backoff
+            for exc_set, decorator in decorators_by_config:
+                decorated_func = decorator(decorated_func)
+
+            return decorated_func
+
+        return combined_decorator
+
+    def _create_single_decorator(self, exc_set: Set[Type[Exception]], config: "BackoffConfig"):
+        """Create a single backoff decorator for a set of exceptions."""
+        if not exc_set:
+            return None
+
+        if config.strategy == "expo":
             return backoff.on_exception(
                 backoff.expo,
-                tuple(exceptions),
-                max_tries=self.max_tries,
-                base=self.base_delay,
-                max_value=self.max_delay,
-                factor=self.factor,
-                jitter=self.jitter,
-                giveup=self.giveup_func,
-                raise_on_giveup=self.raise_on_giveup,
+                tuple(exc_set),
+                max_tries=config.max_tries,
+                base=config.base_delay,
+                max_value=config.max_delay,
+                factor=config.factor,
+                jitter=config.jitter,
+                giveup=config.giveup_func,
+                raise_on_giveup=config.raise_on_giveup,
             )
-        elif self.strategy == "constant":
+        elif config.strategy == "constant":
             return backoff.on_exception(
                 backoff.constant,
-                tuple(exceptions),
-                max_tries=self.max_tries,
-                interval=self.base_delay,
-                jitter=self.jitter,
-                giveup=self.giveup_func,
-                raise_on_giveup=self.raise_on_giveup,
+                tuple(exc_set),
+                max_tries=config.max_tries,
+                interval=config.base_delay,
+                jitter=config.jitter,
+                giveup=config.giveup_func,
+                raise_on_giveup=config.raise_on_giveup,
             )
         else:
-            raise ValueError(f"Unknown backoff strategy: {self.strategy}")
+            raise ValueError(f"Unknown backoff strategy: {config.strategy}")


 @dataclass
@@ -123,6 +197,10 @@ class ExceptionHandlerConfig:

     # Backoff configuration
     backoff_config: BackoffConfig = field(default_factory=BackoffConfig)

+    # Per-exception backoff overrides - allows custom backoff config for specific exception types
+    # For example, ResponseQualityError can use no backoff (base_delay=0, max_delay=0)
+    exception_backoff_overrides: Dict[Type[Exception], BackoffConfig] = field(default_factory=dict)
+
     def __post_init__(self):
         """Automatically apply environment variable overrides after initialization."""
         # Override backoff settings from environment variables
@@ -133,10 +211,23 @@
         if "EP_FAIL_ON_MAX_RETRY" in os.environ:
             fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower()
             self.backoff_config.raise_on_giveup = fail_on_max_retry != "false"
+
+        # Set default no-backoff config for ResponseQualityError if not already set
+        if eval_protocol.exceptions.ResponseQualityError not in self.exception_backoff_overrides:
+            # Default: no backoff for ResponseQualityError (immediate retry)
+            self.exception_backoff_overrides[eval_protocol.exceptions.ResponseQualityError] = BackoffConfig(
+                strategy="constant",
+                base_delay=0.0,
+                max_delay=0.0,
+                max_tries=self.backoff_config.max_tries,
+            )

     def get_backoff_decorator(self):
         """Get the backoff decorator configured for this exception handler."""
-        return self.backoff_config.get_backoff_decorator(self.retryable_exceptions)
+        return self.backoff_config.get_backoff_decorator(
+            self.retryable_exceptions,
+            self.exception_backoff_overrides if self.exception_backoff_overrides else None
+        )


 def get_default_exception_handler_config() -> ExceptionHandlerConfig:
diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py
index 9603c7b9..9cf82ca9 100644
--- a/eval_protocol/pytest/types.py
+++ b/eval_protocol/pytest/types.py
@@ -11,6 +11,7 @@

 from ..models import CompletionParams, EvaluationRow, Message
 from .exception_config import ExceptionHandlerConfig
+from .rollout_result_post_processor import RolloutResultPostProcessor

 ModelParam = str  # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
 DatasetPathParam = str
@@ -75,3 +76,4 @@ class RolloutProcessorConfig:
         default_factory=dict
     )  # any additional kwargs to pass to the rollout processor
     exception_handler_config: ExceptionHandlerConfig | None = None  # configuration for exception handling with backoff
+    post_processor: RolloutResultPostProcessor | None = None  # optional post-processor for quality checks
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index a1fcdeb6..a58cba29 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -35,6 +35,7 @@
     RolloutFinishedError,
     RolloutRunningError,
     ScoreInvalidError,
+    ResponseQualityError,
 )


@@ -71,6 +72,7 @@ def test_error_status_codes_raise_exceptions():
         (14, UnavailableError, "UNAVAILABLE"),
         (15, DataLossError, "DATA_LOSS"),
         (16, UnauthenticatedError, "UNAUTHENTICATED"),
+        (103, ResponseQualityError, "RESPONSE_QUALITY_ERROR"),
     ]

     for code, expected_exception_class, name in error_test_cases:
@@ -105,6 +107,7 @@ def test_status_code_mapping_completeness():
         100,
         101,
         102,
         # Custom EP codes
+        103,  # ResponseQualityError
     ]

     for code in expected_codes:
@@ -113,7 +116,7 @@

 def test_invalid_status_codes():
     """Test behavior with invalid/unknown status codes."""
-    invalid_codes = [-1, 17, 99, 103, 999]
+    invalid_codes = [-1, 17, 99, 104, 999]

     for code in invalid_codes:
         exception = exception_for_status_code(code)
@@ -181,6 +184,7 @@ def test_status_code_enum_consistency():
         Status.Code.FINISHED: None,
         Status.Code.RUNNING: None,
         Status.Code.SCORE_INVALID: None,
+        Status.Code.RESPONSE_QUALITY_ERROR: ResponseQualityError,
     }

     for status_code_enum, expected_exception_class in status_code_mapping.items():
@@ -219,6 +223,7 @@ def test_exception_inheritance():
         RolloutFinishedError,
         RolloutRunningError,
         ScoreInvalidError,
+        ResponseQualityError,
     ]

     for exception_class in exception_classes:
@@ -273,6 +278,12 @@ def test_real_world_usage_scenarios():
             "should_raise": True,
             "expected_exception": UnavailableError,
         },
+        {
+            "status_code": 103,
+            "description": "Response quality check failed",
+            "should_raise": True,
+            "expected_exception": ResponseQualityError,
+        },
     ]

     for scenario in scenarios:
@@ -320,6 +331,7 @@ def test_exception_status_code_attributes():
         (RolloutFinishedError, 100),
         (RolloutRunningError, 101),
         (ScoreInvalidError, 102),
+        (ResponseQualityError, 103),
     ]

     for exception_class, expected_code in expected_mappings:
@@ -342,6 +354,7 @@ def test_integration_with_retry_logic():
         UnavailableError,
         UnauthenticatedError,
         ResourceExhaustedError,
+        ResponseQualityError,
     ]

     for exception_class in our_error_exceptions:
@@ -356,6 +369,7 @@ def test_exception_message_preservation():
         (13, "test error", InternalError),
         (5, "Model xyz not found", NotFoundError),
         (7, "Invalid API key", PermissionDeniedError),
+        (103, "Quality check failed: response too repetitive", ResponseQualityError),
     ]

     for status_code, message, expected_exception_class in test_cases:
diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py
index 95b70faf..9f773811 100644
--- a/tests/test_retry_mechanism.py
+++ b/tests/test_retry_mechanism.py
@@ -6,15 +6,18 @@
 # pyright: reportPrivateImportUsage=false

 import asyncio
+import backoff
 from collections import Counter
+from typing import Type
 from typing_extensions import override
-from unittest.mock import Mock
+from unittest.mock import Mock, patch

 from eval_protocol.models import EvaluateResult, EvaluationRow, Message
 from eval_protocol.pytest.evaluation_test import evaluation_test
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
 from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, BackoffConfig
+from eval_protocol.exceptions import ResponseQualityError

 import litellm

@@ -397,3 +400,253 @@ def test_simple_giveup_verification():
     )

     print("āœ… Simple giveup test passed! 4xx error was not retried due to giveup function.")
+
+
+# Test 5: ResponseQualityError with no backoff (immediate retry)
+class MockRolloutProcessorResponseQuality(RolloutProcessor):
+    """Mock processor that raises ResponseQualityError"""
+
+    def __init__(self):
+        self.mock_tracker: Mock = Mock()
+
+    @override
+    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
+        self.mock_tracker.batch_call(len(rows))
+
+        async def process_single_row(row: EvaluationRow) -> EvaluationRow:
+            self.mock_tracker.process_row_call(row.execution_metadata.rollout_id)
+
+            # Determine attempt number by counting previous calls for this rollout_id
+            previous_calls = [
+                call for call in self.mock_tracker.process_row_call.call_args_list if call[0][0] == row.execution_metadata.rollout_id
+            ]
+            attempt_number = len(previous_calls)
+
+            # Fail on first attempt, succeed on retry
+            if attempt_number == 1:
+                raise ResponseQualityError("Response quality check failed: too repetitive")
+
+            return row
+
+        tasks = [asyncio.create_task(process_single_row(row)) for row in rows]
+        return tasks
+
+
+shared_processor_response_quality = MockRolloutProcessorResponseQuality()
+
+
+@evaluation_test(
+    completion_params=[{"model": "gpt-4o-mini", "temperature": 0}],
+    input_messages=[[[Message(role="user", content="Test quality")]]],
+    rollout_processor=shared_processor_response_quality,
+    num_runs=1,
+    mode="pointwise",
+    exception_handler_config=ExceptionHandlerConfig(
+        backoff_config=BackoffConfig(max_tries=3),
+        # ResponseQualityError should have no backoff by default (set in __post_init__)
+    ),
+)
+def test_response_quality_error_retry(row: EvaluationRow) -> EvaluationRow:
+    """Test that ResponseQualityError is retried with no backoff (immediate retry)."""
+    print(
+        f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
+    )
+    score = 1.0 if row.rollout_status.is_finished() else 0.0
+    row.evaluation_result = EvaluateResult(score=score)
+    return row
+
+
+def test_response_quality_error_verification():
+    """Verify that ResponseQualityError is retried with no backoff."""
+    mock_tracker = shared_processor_response_quality.mock_tracker
+
+    print("\nšŸ”„ RESPONSE QUALITY ERROR TEST ANALYSIS:")
+    print(f"   Batch calls made: {mock_tracker.batch_call.call_count}")
+    print(f"   Total row processing calls: {mock_tracker.process_row_call.call_count}")
+
+    call_args = mock_tracker.process_row_call.call_args_list
+    rollout_ids = [call[0][0] for call in call_args]
+    call_counts = Counter(rollout_ids)
+
+    print(f"   Call counts per rollout_id: {dict(call_counts)}")
+
+    # Should have 2 calls: 1 original + 1 retry (no backoff, immediate retry)
+    # Note: With max_tries=3, it should retry up to 3 times, but our mock succeeds on attempt 2
+    assert mock_tracker.process_row_call.call_count == 2, (
+        f"Expected 2 calls (1 original + 1 retry), got {mock_tracker.process_row_call.call_count}"
+    )
+
+    # Should have exactly 1 rollout_id called twice
+    call_count_values = list(call_counts.values())
+    assert call_count_values.count(2) == 1, (
+        f"Expected 1 rollout with 2 calls, got {call_count_values}"
+    )
+
+    print("āœ… ResponseQualityError test passed! Error was retried with no backoff (immediate retry).")
+
+
+# Test 6: Per-exception backoff overrides
+class MockRolloutProcessorBackoffOverride(RolloutProcessor):
+    """Mock processor that raises different exceptions to test backoff overrides"""
+
+    def __init__(self):
+        self.mock_tracker: Mock = Mock()
+
+    @override
+    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
+        self.mock_tracker.batch_call(len(rows))
+
+        async def process_single_row(row: EvaluationRow) -> EvaluationRow:
+            rollout_id = row.execution_metadata.rollout_id
+            self.mock_tracker.process_row_call(rollout_id)
+
+            # Determine attempt number
+            previous_calls = [
+                call for call in self.mock_tracker.process_row_call.call_args_list if call[0][0] == rollout_id
+            ]
+            attempt_number = len(previous_calls)
+
+            task_content = row.messages[0].content if row.messages else ""
+
+            # Different exceptions based on content
+            if task_content and "quality" in task_content:
+                # ResponseQualityError - should use no backoff override (immediate retry)
+                if attempt_number == 1:
+                    raise ResponseQualityError("Quality check failed")
+            elif task_content and "connection" in task_content:
+                # ConnectionError - should use default backoff
+                if attempt_number <= 2:  # Fail twice, succeed on third
+                    raise ConnectionError("Connection failed")
+
+            return row
+
+        tasks = [asyncio.create_task(process_single_row(row)) for row in rows]
+        return tasks
+
+
+shared_processor_backoff_override = MockRolloutProcessorBackoffOverride()
+
+
+@evaluation_test(
+    completion_params=[{"model": "gpt-4o-mini", "temperature": 0}],
+    input_messages=[
+        [
+            [Message(role="user", content="Test quality")],  # ResponseQualityError - no backoff
+            [Message(role="user", content="Test connection")],  # ConnectionError - default backoff
+        ]
+    ],
+    rollout_processor=shared_processor_backoff_override,
+    num_runs=1,
+    mode="pointwise",
+    exception_handler_config=ExceptionHandlerConfig(
+        backoff_config=BackoffConfig(max_tries=3, base_delay=0.1, strategy="constant"),
+        exception_backoff_overrides={
+            ResponseQualityError: BackoffConfig(
+                strategy="constant",
+                base_delay=0.0,  # No backoff
+                max_delay=0.0,
+                max_tries=3,
+            ),
+        },
+    ),
+)
+def test_backoff_override(row: EvaluationRow) -> EvaluationRow:
+    """Test that per-exception backoff overrides work correctly."""
+    task_content = row.messages[0].content if row.messages else ""
+    print(
+        f"šŸ“Š EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
+    )
+    score = 1.0 if row.rollout_status.is_finished() else 0.0
+    row.evaluation_result = EvaluateResult(score=score)
+    return row
+
+
+def test_backoff_override_verification():
+    """Verify that different backoff decorators are triggered for different exceptions."""
+    # Track which exceptions are passed to backoff.on_exception
+    backoff_calls = []
+    original_on_exception = backoff.on_exception
+
+    def track_backoff_on_exception(*args, **kwargs):
+        """Track calls to backoff.on_exception to verify different decorators."""
+        # Extract exception types from args (second positional arg is the exception tuple)
+        if len(args) >= 2:
+            exception_types = args[1]
+            # Extract interval/base_delay from kwargs
+            interval = kwargs.get('interval', kwargs.get('base', None))
+            backoff_calls.append({
+                'exceptions': exception_types,
+                'interval': interval,
+            })
+        return original_on_exception(*args, **kwargs)
+
+    # Re-run the test with backoff tracking
+    with patch('eval_protocol.pytest.exception_config.backoff.on_exception', side_effect=track_backoff_on_exception):
+        # Recreate the config to get a fresh decorator
+        config = ExceptionHandlerConfig(
+            backoff_config=BackoffConfig(max_tries=3, base_delay=0.1, strategy="constant"),
+            exception_backoff_overrides={
+                ResponseQualityError: BackoffConfig(
+                    strategy="constant",
+                    base_delay=0.0,
+                    max_delay=0.0,
+                    max_tries=3,
+                ),
+            },
+        )
+        decorator = config.get_backoff_decorator()
+
+        # Verify decorator was created
+        assert decorator is not None, "Decorator should be created"
+
+    # Verify different decorators were created for different exceptions
+    assert len(backoff_calls) >= 2, (
+        f"Expected at least 2 backoff.on_exception calls (one per exception type), got {len(backoff_calls)}"
+    )
+
+    # Find decorators for ResponseQualityError and ConnectionError
+    quality_decorator = None
+    connection_decorator = None
+
+    for call in backoff_calls:
+        exceptions = call['exceptions']
+        if ResponseQualityError in exceptions:
+            quality_decorator = call
+        if ConnectionError in exceptions:
+            connection_decorator = call
+
+    assert quality_decorator is not None, "Should have a decorator for ResponseQualityError"
+    assert connection_decorator is not None, "Should have a decorator for ConnectionError"
+
+    # Verify different intervals (base_delay) were used
+    assert quality_decorator['interval'] == 0.0, (
+        f"ResponseQualityError decorator should have interval=0.0, got {quality_decorator['interval']}"
+    )
+    assert connection_decorator['interval'] == 0.1, (
+        f"ConnectionError decorator should have interval=0.1, got {connection_decorator['interval']}"
+    )
+
+    # Verify exception sets are disjoint (no overlap)
+    quality_exceptions = set(quality_decorator['exceptions'])
+    connection_exceptions = set(connection_decorator['exceptions'])
+    assert quality_exceptions.isdisjoint(connection_exceptions), (
+        "Exception sets should be disjoint to prevent double backoff"
+    )
+
+    # Verify call counts from actual test run
+    mock_tracker = shared_processor_backoff_override.mock_tracker
+    call_args = mock_tracker.process_row_call.call_args_list
+    rollout_ids = [call[0][0] for call in call_args]
+    call_counts = Counter(rollout_ids)
+
+    assert mock_tracker.process_row_call.call_count == 5, (
+        f"Expected 5 calls total (2 for quality + 3 for connection), got {mock_tracker.process_row_call.call_count}"
+    )
+
+    call_count_values = list(call_counts.values())
+    assert call_count_values.count(2) == 1, (
+        f"Expected 1 rollout with 2 calls (ResponseQualityError with no backoff), got {call_count_values}"
+    )
+    assert call_count_values.count(3) == 1, (
+        f"Expected 1 rollout with 3 calls (ConnectionError with backoff), got {call_count_values}"
+    )

From a3222247761b0ed3aa24e0bf2b8f8d7b3ade9cfe Mon Sep 17 00:00:00 2001
From: morgendave
Date: Wed, 26 Nov 2025 15:44:33 -0800
Subject: [PATCH 2/3] Add rollout result post processor

---
 eval_protocol/pytest/evaluation_test_utils.py |   5 -
 eval_protocol/pytest/exception_config.py      |  83 +-------
 .../pytest/rollout_result_post_processor.py   |  57 ++++++
 tests/test_exception_config.py                | 114 +++++++++++
 tests/test_retry_mechanism.py                 | 179 +-----------------
 5 files changed, 180 insertions(+), 258 deletions(-)
 create mode 100644 eval_protocol/pytest/rollout_result_post_processor.py
 create mode 100644 tests/test_exception_config.py

diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py
index e570c815..94d6f7fe 100644
--- a/eval_protocol/pytest/evaluation_test_utils.py
+++ b/eval_protocol/pytest/evaluation_test_utils.py
@@ -374,9 +374,6 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
             except ResponseQualityError as quality_error:
                 # Re-raise ResponseQualityError to trigger retry logic
                 raise quality_error
-            except Exception as post_process_error:
-                # Wrap unexpected post-processor errors in ResponseQualityError
-                raise ResponseQualityError(f"Post-processor failed: {post_process_error}") from post_process_error

         return result

@@ -393,8 +390,6 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
                 config.post_processor.process(result)
             except ResponseQualityError as quality_error:
                 raise quality_error
-            except Exception as post_process_error:
-                raise ResponseQualityError(f"Post-processor failed: {post_process_error}") from post_process_error

             _set_rollout_status_to_finished(result)

diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py
index 83121643..bf7d84f8 100644
--- a/eval_protocol/pytest/exception_config.py
+++ b/eval_protocol/pytest/exception_config.py
@@ -79,13 +79,11 @@ class BackoffConfig:
     # Optional custom giveup function - if provided, overrides the default exception handling logic
     giveup_func: Callable[[Exception], bool] = lambda e: False

-    def get_backoff_decorator(self, exceptions: Set[Type[Exception]], exception_backoff_overrides: Dict[Type[Exception], "BackoffConfig"] | None = None):
+    def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
         """Get the appropriate backoff decorator based on configuration.

         Args:
             exceptions: Set of exception types to retry
-            exception_backoff_overrides: Optional mapping of exception types to custom backoff configs.
-                If an exception type has an override, that config will be used instead of this one.
         """
         if not exceptions:
             # If no exceptions specified, return a no-op decorator
@@ -94,67 +92,7 @@ def no_op_decorator(func):

             return no_op_decorator

-        # If no overrides, use simple decorator for all exceptions
-        if not exception_backoff_overrides:
-            return self._create_single_decorator(exceptions, self)
-
-        # Group exceptions by their backoff config to avoid double backoff
-        # Each exception type gets exactly one decorator based on its config
-        # Use a tuple of config attributes as the key since BackoffConfig is not hashable
-        config_to_exceptions: Dict[tuple, tuple[Set[Type[Exception]], "BackoffConfig"]] = {}
-
-        for exc_type in exceptions:
-            if exc_type in exception_backoff_overrides:
-                override_config = exception_backoff_overrides[exc_type]
-            else:
-                override_config = self
-
-            # Create a hashable key from config attributes
-            # Note: jitter and giveup_func are callable, which are hashable in Python
-            config_key = (
-                override_config.strategy,
-                override_config.base_delay,
-                override_config.max_delay,
-                override_config.max_tries,
-                override_config.factor,
-                id(override_config.jitter) if override_config.jitter is not None else None,
-                id(override_config.giveup_func) if override_config.giveup_func is not None else None,
-                override_config.raise_on_giveup,
-            )
-
-            if config_key not in config_to_exceptions:
-                config_to_exceptions[config_key] = (set(), override_config)
-            exc_set, _ = config_to_exceptions[config_key]
-            exc_set.add(exc_type)
-
-        # If all exceptions use the same config, use a single decorator
-        if len(config_to_exceptions) == 1:
-            exc_set, config = next(iter(config_to_exceptions.values()))
-            return self._create_single_decorator(exc_set, config)
-
-        # Create separate decorators for each config group
-        # Each exception type gets exactly one decorator, preventing double backoff
-        decorators_by_config: list[tuple[Set[Type[Exception]], Callable]] = []
-
-        for exc_set, config in config_to_exceptions.values():
-            decorator = self._create_single_decorator(exc_set, config)
-            if decorator:
-                decorators_by_config.append((exc_set, decorator))
-
-        # Create a combined decorator that applies all decorators
-        # Each decorator only catches exceptions in its exception set, so no double backoff
-        def combined_decorator(func):
-            decorated_func = func
-
-            # Apply each decorator in order (inner to outer)
-            # Each decorator only catches exceptions in its specific exception set
-            # Since exception sets are disjoint (grouped by config), no double backoff
-            for exc_set, decorator in decorators_by_config:
-                decorated_func = decorator(decorated_func)
-
-            return decorated_func
-
-        return combined_decorator
+        return self._create_single_decorator(exceptions, self)

     def _create_single_decorator(self, exc_set: Set[Type[Exception]], config: "BackoffConfig"):
         """Create a single backoff decorator for a set of exceptions."""
@@ -197,10 +135,6 @@ class ExceptionHandlerConfig:

     # Backoff configuration
     backoff_config: BackoffConfig = field(default_factory=BackoffConfig)

-    # Per-exception backoff overrides - allows custom backoff config for specific exception types
-    # For example, ResponseQualityError can use no backoff (base_delay=0, max_delay=0)
-    exception_backoff_overrides: Dict[Type[Exception], BackoffConfig] = field(default_factory=dict)
-
     def __post_init__(self):
         """Automatically apply environment variable overrides after initialization."""
         # Override backoff settings from environment variables
@@ -211,22 +145,11 @@
         if "EP_FAIL_ON_MAX_RETRY" in os.environ:
             fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower()
             self.backoff_config.raise_on_giveup = fail_on_max_retry != "false"
-
-        # Set default no-backoff config for ResponseQualityError if not already set
-        if eval_protocol.exceptions.ResponseQualityError not in self.exception_backoff_overrides:
-            # Default: no backoff for ResponseQualityError (immediate retry)
-            self.exception_backoff_overrides[eval_protocol.exceptions.ResponseQualityError] = BackoffConfig(
-                strategy="constant",
-                base_delay=0.0,
-                max_delay=0.0,
-                max_tries=self.backoff_config.max_tries,
-            )

     def get_backoff_decorator(self):
         """Get the backoff decorator configured for this exception handler."""
         return self.backoff_config.get_backoff_decorator(
-            self.retryable_exceptions,
-            self.exception_backoff_overrides if self.exception_backoff_overrides else None
+            self.retryable_exceptions
         )

diff --git a/eval_protocol/pytest/rollout_result_post_processor.py b/eval_protocol/pytest/rollout_result_post_processor.py
new file mode 100644
index 00000000..cdaa98d5
--- /dev/null
+++ b/eval_protocol/pytest/rollout_result_post_processor.py
@@ -0,0 +1,57 @@
+"""
+Rollout result post-processing plugin for quality checks.
+
+This module provides an abstract base class for post-processing rollout results
+to guard response quality. Post-processors can validate results and raise
+ResponseQualityError if quality checks fail.
+"""
+
+from abc import ABC, abstractmethod
+
+from eval_protocol.models import EvaluationRow
+
+
+class RolloutResultPostProcessor(ABC):
+    """
+    Abstract base class for rollout result post-processing plugins.
+
+    Post-processors validate rollout results and can raise ResponseQualityError
+    if quality checks fail. This allows for customizable quality guards that
+    can be overridden by users.
+    """
+
+    @abstractmethod
+    def process(self, result: EvaluationRow) -> None:
+        """
+        Process and validate a rollout result.
+
+        This method should perform quality checks on the result. If quality
+        checks fail, it should raise ResponseQualityError with an appropriate
+        message.
+
+        Args:
+            result: The EvaluationRow result from the rollout
+
+        Raises:
+            ResponseQualityError: If quality checks fail
+        """
+        pass
+
+
+class NoOpRolloutResultPostProcessor(RolloutResultPostProcessor):
+    """
+    Default no-op implementation of RolloutResultPostProcessor.
+
+    This implementation does not perform any quality checks and always passes.
+    Use this as a default when no post-processing is needed.
+    """
+
+    def process(self, result: EvaluationRow) -> None:
+        """
+        No-op implementation that does not perform any quality checks.
+
+        Args:
+            result: The EvaluationRow result from the rollout
+        """
+        pass
+
diff --git a/tests/test_exception_config.py b/tests/test_exception_config.py
new file mode 100644
index 00000000..90db182a
--- /dev/null
+++ b/tests/test_exception_config.py
@@ -0,0 +1,114 @@
+"""
+Unit tests for exception_config module.
+
+Tests the BackoffConfig and ExceptionHandlerConfig classes, including:
+1. Backoff decorator creation
+2. Per-exception backoff overrides
+3. ResponseQualityError default no-backoff configuration
+4. Exception grouping to avoid double backoff
+"""
+
+import pytest
+from eval_protocol.pytest.exception_config import BackoffConfig, ExceptionHandlerConfig, DEFAULT_RETRYABLE_EXCEPTIONS
+from eval_protocol.exceptions import ResponseQualityError
+
+
+def test_backoff_config_no_exceptions():
+    """Test that BackoffConfig returns no-op decorator when no exceptions specified."""
+    config = BackoffConfig()
+    decorator = config.get_backoff_decorator(set())
+
+    # Should be a no-op decorator
+    def test_func():
+        return "test"
+
+    decorated = decorator(test_func)
+    assert decorated() == "test"
+    assert decorated is test_func  # Should be the same function
+
+
+def test_backoff_config_no_overrides():
+    """Test that BackoffConfig creates a single decorator."""
+    config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2)
+    exceptions = {ConnectionError, TimeoutError}
+
+    decorator = config.get_backoff_decorator(exceptions)
+    assert decorator is not None
+
+    # Decorator should be callable
+    def test_func():
+        raise ConnectionError("test")
+
+    decorated = decorator(test_func)
+    assert callable(decorated)
+
+
+def test_exception_handler_config_default_response_quality_error():
+    """Test that ExceptionHandlerConfig includes ResponseQualityError by default."""
+    config = ExceptionHandlerConfig()
+
+    # ResponseQualityError should be in retryable_exceptions
+    assert ResponseQualityError in config.retryable_exceptions
+
+
+def test_exception_handler_config_get_backoff_decorator():
+    """Test that ExceptionHandlerConfig.get_backoff_decorator() works correctly."""
+    config = ExceptionHandlerConfig()
+    decorator = config.get_backoff_decorator()
+
+    assert decorator is not None
+    assert callable(decorator)
+
+    # Should be able to decorate a function
+    def test_func():
+        raise ConnectionError("test")
+
+    decorated = decorator(test_func)
+    assert callable(decorated)
+
+
+def test_backoff_config_expo_strategy():
+
+    """Test that BackoffConfig creates expo decorator correctly."""
+    config = BackoffConfig(strategy="expo", base_delay=1.0, max_tries=2)
+    exceptions = {ConnectionError}
+
+    decorator = config.get_backoff_decorator(exceptions)
+    assert decorator is not None
+
+    def test_func():
+        raise ConnectionError("test")
+
+    decorated = decorator(test_func)
+    assert callable(decorated)
+
+
+def test_backoff_config_constant_strategy():
+    """Test that BackoffConfig creates constant decorator correctly."""
+    config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2)
+    exceptions = {ConnectionError}
+
+    decorator = config.get_backoff_decorator(exceptions)
+    assert decorator is not None
+
+    def test_func():
+        raise ConnectionError("test")
+
+    decorated = decorator(test_func)
+    assert callable(decorated)
+
+
+def test_backoff_config_invalid_strategy():
+    """Test that BackoffConfig raises ValueError for invalid strategy."""
+    config = BackoffConfig(strategy="invalid", base_delay=1.0, max_tries=2)
+    exceptions = {ConnectionError}
+
+    with pytest.raises(ValueError, match="Unknown backoff strategy"):
+        config.get_backoff_decorator(exceptions)
+
+
+def test_exception_handler_config_response_quality_error_in_defaults():
+    """Test that ResponseQualityError is in DEFAULT_RETRYABLE_EXCEPTIONS."""
+    assert ResponseQualityError in DEFAULT_RETRYABLE_EXCEPTIONS
+
diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py
index 9f773811..861793c1 100644
--- a/tests/test_retry_mechanism.py
+++ b/tests/test_retry_mechanism.py
@@ -266,7 +266,7 @@ def custom_http_giveup(e: Exception) -> bool:
         return True  # Give up immediately on bad requests
     elif isinstance(e, litellm.RateLimitError):
         return False  # Retry rate limits with backoff
-    
+
     return False  # Retry everything else


@@ -388,7 +388,7 @@ def test_simple_giveup_function(row: EvaluationRow) -> EvaluationRow:
 def test_simple_giveup_verification():
     """Verify that giveup function prevents retries."""
     mock_tracker = shared_processor_simple_giveup.mock_tracker
-    
+
     print("\nšŸ”„ SIMPLE GIVEUP TEST ANALYSIS:")
     print(f"   Batch calls made: {mock_tracker.batch_call.call_count}")
     print(f"   Total row processing calls: {mock_tracker.process_row_call.call_count}")
@@ -443,11 +443,10 @@ async def process_single_row(row: EvaluationRow) -> EvaluationRow:
     mode="pointwise",
     exception_handler_config=ExceptionHandlerConfig(
         backoff_config=BackoffConfig(max_tries=3),
-        # ResponseQualityError should have no backoff by default (set in __post_init__)
     ),
 )
 def test_response_quality_error_retry(row: EvaluationRow) -> EvaluationRow:
-    """Test that ResponseQualityError is retried with no backoff (immediate retry)."""
+    """Test that ResponseQualityError is retried (using default backoff)."""
     print(
         f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
     )
@@ -457,7 +456,7 @@ def test_response_quality_error_retry(row: EvaluationRow) -> EvaluationRow:


 def test_response_quality_error_verification():
-    """Verify that ResponseQualityError is retried with no backoff."""
+    """Verify that ResponseQualityError is retried."""
     mock_tracker = shared_processor_response_quality.mock_tracker

     print("\nšŸ”„ RESPONSE QUALITY ERROR TEST ANALYSIS:")
@@ -470,7 +469,7 @@ def test_response_quality_error_verification():

     print(f"   Call counts per rollout_id: {dict(call_counts)}")

-    # Should have 2 calls: 1 original + 1 retry (no backoff, immediate retry)
+    # Should have 2 calls: 1 original + 1 retry
     # Note: With max_tries=3, it should retry up to 3 times, but our mock succeeds on attempt 2
     assert mock_tracker.process_row_call.call_count == 2, (
         f"Expected 2 calls (1 original + 1 retry), got {mock_tracker.process_row_call.call_count}"
@@ -482,171 +481,5 @@ def test_response_quality_error_verification():
         f"Expected 1 rollout with 2 calls, got {call_count_values}"
     )

-    print("āœ… ResponseQualityError test passed! Error was retried with no backoff (immediate retry).")
-
-
-# Test 6: Per-exception backoff overrides
-class MockRolloutProcessorBackoffOverride(RolloutProcessor):
-    """Mock processor that raises different exceptions to test backoff overrides"""
-
-    def __init__(self):
-        self.mock_tracker: Mock = Mock()
-
-    @override
-    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
-        self.mock_tracker.batch_call(len(rows))
-
-        async def process_single_row(row: EvaluationRow) -> EvaluationRow:
-            rollout_id = row.execution_metadata.rollout_id
-            self.mock_tracker.process_row_call(rollout_id)
+    print("āœ… ResponseQualityError test passed! Error was retried.")
-
-            # Determine attempt number
-            previous_calls = [
-                call for call in self.mock_tracker.process_row_call.call_args_list if call[0][0] == rollout_id
-            ]
-            attempt_number = len(previous_calls)
-
-            task_content = row.messages[0].content if row.messages else ""
-
-            # Different exceptions based on content
-            if task_content and "quality" in task_content:
-                # ResponseQualityError - should use no backoff override (immediate retry)
-                if attempt_number == 1:
-                    raise ResponseQualityError("Quality check failed")
-            elif task_content and "connection" in task_content:
-                # ConnectionError - should use default backoff
-                if attempt_number <= 2:  # Fail twice, succeed on third
-                    raise ConnectionError("Connection failed")
-
-            return row
-
-        tasks = [asyncio.create_task(process_single_row(row)) for row in rows]
-        return tasks
-
-
-shared_processor_backoff_override = MockRolloutProcessorBackoffOverride()
-
-
-@evaluation_test(
-    completion_params=[{"model": "gpt-4o-mini", "temperature": 0}],
-    input_messages=[
-        [
-            [Message(role="user", content="Test quality")],  # ResponseQualityError - no backoff
-            [Message(role="user", content="Test connection")],  # ConnectionError - default backoff
-        ]
-    ],
-    rollout_processor=shared_processor_backoff_override,
-    num_runs=1,
-    mode="pointwise",
-    exception_handler_config=ExceptionHandlerConfig(
-        backoff_config=BackoffConfig(max_tries=3, base_delay=0.1, strategy="constant"),
-        exception_backoff_overrides={
-            ResponseQualityError: BackoffConfig(
-                strategy="constant",
-                base_delay=0.0,  # No backoff
-                max_delay=0.0,
-                max_tries=3,
-            ),
-        },
-    ),
-)
-def test_backoff_override(row: EvaluationRow) -> EvaluationRow:
-    """Test that per-exception backoff overrides work correctly."""
-    task_content = row.messages[0].content if row.messages else ""
-    print(
-        f"šŸ“Š EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
-    )
-    score = 1.0 if row.rollout_status.is_finished() else 0.0
-    row.evaluation_result = EvaluateResult(score=score)
-    return row
-
-
-def test_backoff_override_verification():
-    """Verify that different backoff decorators are triggered for different exceptions."""
-    # Track which exceptions are passed to backoff.on_exception
-    backoff_calls = []
-    original_on_exception = backoff.on_exception
-
-    def track_backoff_on_exception(*args, **kwargs):
-        """Track calls to backoff.on_exception to verify different decorators."""
-        # Extract exception types from args (second positional arg is the exception tuple)
-        if len(args) >= 2:
-            exception_types = args[1]
-            # Extract interval/base_delay from kwargs
-            interval = kwargs.get('interval', kwargs.get('base', None))
-            backoff_calls.append({
-                'exceptions': exception_types,
-                'interval': interval,
-            })
-        return original_on_exception(*args, **kwargs)
-
-    # Re-run the test with backoff tracking
-    with patch('eval_protocol.pytest.exception_config.backoff.on_exception', side_effect=track_backoff_on_exception):
-        # Recreate the config to get a fresh decorator
-        config = ExceptionHandlerConfig(
-            backoff_config=BackoffConfig(max_tries=3, base_delay=0.1, strategy="constant"),
-            exception_backoff_overrides={
-                ResponseQualityError: BackoffConfig(
-                    strategy="constant",
-                    base_delay=0.0,
-                    max_delay=0.0,
-                    max_tries=3,
-                ),
-            },
-        )
-        decorator = config.get_backoff_decorator()
-
-        # Verify decorator was created
-        assert decorator is not None, "Decorator should be created"
-
-    # Verify different decorators were created for different exceptions
-    assert len(backoff_calls) >= 2, (
-        f"Expected at least 2 backoff.on_exception calls (one per exception type), got {len(backoff_calls)}"
-    )
-
-    # Find decorators for ResponseQualityError and ConnectionError
-    quality_decorator = None
-    connection_decorator = None
-
-    for call in backoff_calls:
-        exceptions = call['exceptions']
-        if ResponseQualityError in exceptions:
-            quality_decorator = call
-        if ConnectionError in exceptions:
-            connection_decorator = call
-
-    assert quality_decorator is not None, "Should have a decorator for ResponseQualityError"
-    assert connection_decorator is not None, "Should have a decorator for ConnectionError"
-
-    # Verify different intervals (base_delay) were used
-    assert quality_decorator['interval'] == 0.0, (
-        f"ResponseQualityError decorator should have interval=0.0, got {quality_decorator['interval']}"
-    )
-    assert connection_decorator['interval'] == 0.1, (
-        f"ConnectionError decorator should have interval=0.1, got {connection_decorator['interval']}"
-    )
-
-    # Verify exception sets are disjoint (no overlap)
-    quality_exceptions = set(quality_decorator['exceptions'])
-    connection_exceptions = set(connection_decorator['exceptions'])
-    assert quality_exceptions.isdisjoint(connection_exceptions), (
-        "Exception sets should be disjoint to prevent double backoff"
-    )
-
-    # Verify call counts from actual test run
-    mock_tracker = shared_processor_backoff_override.mock_tracker
-    call_args = mock_tracker.process_row_call.call_args_list
-    rollout_ids = [call[0][0] for call in call_args]
-    call_counts = Counter(rollout_ids)
-
-    assert mock_tracker.process_row_call.call_count == 5, (
-        f"Expected 5 calls total (2 for quality + 3 for connection), got {mock_tracker.process_row_call.call_count}"
-    )
-
-    call_count_values = list(call_counts.values())
-    assert call_count_values.count(2) == 1, (
-        f"Expected 1 rollout with 2 calls (ResponseQualityError with no backoff), got {call_count_values}"
-    )
-    assert call_count_values.count(3) == 1, (
-        f"Expected 1 rollout with 3 calls (ConnectionError with backoff), got {call_count_values}"
-    )

From 2eeb478c637d2d4897bbc6c03df2906f4216f494 Mon Sep 17 00:00:00 2001
From: morgendave
Date: Wed, 26 Nov 2025 15:50:45 -0800
Subject: [PATCH 3/3] simplify the config

---
 eval_protocol/pytest/exception_config.py | 41 ++++++++++--------------
 1 file changed, 17 insertions(+), 24 deletions(-)

diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py
index bf7d84f8..a2244b2a 100644
--- a/eval_protocol/pytest/exception_config.py
+++ b/eval_protocol/pytest/exception_config.py
@@ -92,37 +92,30 @@ def no_op_decorator(func):

             return no_op_decorator

-        return self._create_single_decorator(exceptions, self)
-
-    def _create_single_decorator(self, exc_set: Set[Type[Exception]], config: "BackoffConfig"):
-        """Create a single backoff decorator for a set of exceptions."""
-        if not exc_set:
-            return None
-
-        if config.strategy == "expo":
+        if self.strategy == "expo":
             return backoff.on_exception(
                 backoff.expo,
-                tuple(exc_set),
-                max_tries=config.max_tries,
-                base=config.base_delay,
-                max_value=config.max_delay,
-                factor=config.factor,
-                jitter=config.jitter,
-                giveup=config.giveup_func,
-                raise_on_giveup=config.raise_on_giveup,
+                tuple(exceptions),
+                max_tries=self.max_tries,
+                base=self.base_delay,
+                max_value=self.max_delay,
+                factor=self.factor,
+                jitter=self.jitter,
+                giveup=self.giveup_func,
+                raise_on_giveup=self.raise_on_giveup,
             )
-        elif config.strategy == "constant":
+        elif self.strategy == "constant":
             return backoff.on_exception(
                 backoff.constant,
-                tuple(exc_set),
-                max_tries=config.max_tries,
-                interval=config.base_delay,
-                jitter=config.jitter,
-                giveup=config.giveup_func,
-                raise_on_giveup=config.raise_on_giveup,
+                tuple(exceptions),
+                max_tries=self.max_tries,
+                interval=self.base_delay,
+                jitter=self.jitter,
+                giveup=self.giveup_func,
+                raise_on_giveup=self.raise_on_giveup,
             )
         else:
-            raise ValueError(f"Unknown backoff strategy: {config.strategy}")
+            raise ValueError(f"Unknown backoff strategy: {self.strategy}")


 @dataclass
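
Notes on the series follow.

The status-code plumbing from PATCH 1/3 can be exercised on its own. A minimal sketch, using only names this series introduces or touches (STATUS_CODE_TO_EXCEPTION, ResponseQualityError, Status.response_quality_error, DEFAULT_RETRYABLE_EXCEPTIONS):

# Sketch: verify the code-103 wiring added in PATCH 1/3.
from eval_protocol.exceptions import STATUS_CODE_TO_EXCEPTION, ResponseQualityError
from eval_protocol.models import Status
from eval_protocol.pytest.exception_config import DEFAULT_RETRYABLE_EXCEPTIONS

# The mapping routes code 103 to the new exception class ...
assert STATUS_CODE_TO_EXCEPTION[103] is ResponseQualityError
# ... the Status factory stamps rows with the matching enum member ...
status = Status.response_quality_error("response too repetitive")
assert status.code == Status.Code.RESPONSE_QUALITY_ERROR  # == 103
# ... and the error is retryable by default, so backoff re-runs the rollout.
assert ResponseQualityError in DEFAULT_RETRYABLE_EXCEPTIONS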
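
PATCH 2/3 turns the quality guard into a plugin: the retry loop calls config.post_processor.process(result) on each rollout result, and a raised ResponseQualityError sends the row back through the retryable-exception path. A custom guard only needs to subclass RolloutResultPostProcessor; the repetition heuristic below is an illustrative assumption, not part of this series:

from eval_protocol.exceptions import ResponseQualityError
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest import RolloutResultPostProcessor


class RepetitionGuard(RolloutResultPostProcessor):
    """Hypothetical guard: reject rollouts whose last assistant turn is one repeated token."""

    def process(self, result: EvaluationRow) -> None:
        # Raising ResponseQualityError here makes the retry loop re-run the row.
        assistant_turns = [m for m in result.messages if m.role == "assistant"]
        if not assistant_turns or not assistant_turns[-1].content:
            raise ResponseQualityError("rollout produced no assistant content")
        tokens = str(assistant_turns[-1].content).split()
        if tokens and len(set(tokens)) == 1:
            raise ResponseQualityError("response is a single repeated token")

A guard like this is wired in through the new RolloutProcessorConfig field, e.g. post_processor=RepetitionGuard(); NoOpRolloutResultPostProcessor remains the do-nothing default.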
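
After PATCH 3/3 the retry surface is back to a single BackoffConfig.get_backoff_decorator(exceptions). A sketch of the end state, assuming the dataclass defaults are otherwise sane — a zero-delay constant strategy reproduces the immediate-retry treatment that the (now removed) per-exception overrides gave ResponseQualityError:

from eval_protocol.exceptions import ResponseQualityError
from eval_protocol.pytest.exception_config import BackoffConfig

attempts = {"count": 0}


def flaky_rollout() -> str:
    # Fails once with a quality error, then succeeds, mirroring the mock
    # processors in tests/test_retry_mechanism.py.
    attempts["count"] += 1
    if attempts["count"] == 1:
        raise ResponseQualityError("first attempt rejected by quality guard")
    return "ok"


retry = BackoffConfig(strategy="constant", base_delay=0.0, max_tries=3)
wrapped = retry.get_backoff_decorator({ResponseQualityError})(flaky_rollout)
assert wrapped() == "ok"
assert attempts["count"] == 2  # one failure, one immediate retry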
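
One environment override survives all three revisions: EP_FAIL_ON_MAX_RETRY=false flips raise_on_giveup off in ExceptionHandlerConfig.__post_init__, so the backoff decorator stops re-raising once the retry budget is exhausted. A short sketch:

import os

from eval_protocol.pytest.exception_config import ExceptionHandlerConfig

os.environ["EP_FAIL_ON_MAX_RETRY"] = "false"
config = ExceptionHandlerConfig()  # __post_init__ reads the variable
assert config.backoff_config.raise_on_giveup is False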