diff --git a/eval_protocol/exceptions.py b/eval_protocol/exceptions.py
index 3b92c865..5459f1e9 100644
--- a/eval_protocol/exceptions.py
+++ b/eval_protocol/exceptions.py
@@ -134,6 +134,12 @@ class ScoreInvalidError(EvalProtocolError):
     status_code = 102
 
 
+class ResponseQualityError(EvalProtocolError):
+    """Response quality check failed (Status.Code.RESPONSE_QUALITY_ERROR = 103)"""
+
+    status_code = 103
+
+
 # Convenience mapping from status codes to exception classes
 # Only actual error conditions should raise exceptions
 STATUS_CODE_TO_EXCEPTION = {
@@ -157,6 +163,7 @@ class ScoreInvalidError(EvalProtocolError):
     100: None,  # FINISHED - success, no exception
     101: None,  # RUNNING - in progress, no exception
     102: None,  # SCORE_INVALID - success, no exception
+    103: ResponseQualityError,  # RESPONSE_QUALITY_ERROR - quality check failed
 }
diff --git a/eval_protocol/models.py b/eval_protocol/models.py
index 67d287ba..edfdbec8 100644
--- a/eval_protocol/models.py
+++ b/eval_protocol/models.py
@@ -117,6 +117,7 @@ class Code(int, Enum):
         FINISHED = 100
         RUNNING = 101
         SCORE_INVALID = 102
+        RESPONSE_QUALITY_ERROR = 103
 
     @classmethod
     def rollout_running(cls) -> "Status":
@@ -367,6 +368,13 @@ def score_invalid(
         """Create a status indicating the score is invalid."""
         return cls(code=cls.Code.SCORE_INVALID, message=message, details=details or [])
 
+    @classmethod
+    def response_quality_error(
+        cls, message: str = "Response quality check failed", details: Optional[List[Dict[str, Any]]] = None
+    ) -> "Status":
+        """Create a status indicating the response quality check failed."""
+        return cls(code=cls.Code.RESPONSE_QUALITY_ERROR, message=message, details=details or [])
+
     def is_running(self) -> bool:
         """Check if the status indicates the rollout is running."""
         return self.code == self.Code.RUNNING
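Note: for orientation, the two changes above compose as follows — a minimal sketch, assuming `Status` and `STATUS_CODE_TO_EXCEPTION` are importable from the modules shown in this diff:

```python
from eval_protocol.exceptions import STATUS_CODE_TO_EXCEPTION, ResponseQualityError
from eval_protocol.models import Status

# The new classmethod builds a status carrying the new code.
status = Status.response_quality_error("response too repetitive")
assert status.code == Status.Code.RESPONSE_QUALITY_ERROR  # int value 103

# The convenience mapping now resolves 103 to the new exception class.
assert STATUS_CODE_TO_EXCEPTION[103] is ResponseQualityError
```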
diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py
index 31c81167..26485f43 100644
--- a/eval_protocol/pytest/__init__.py
+++ b/eval_protocol/pytest/__init__.py
@@ -8,6 +8,7 @@
 from .evaluation_test import evaluation_test
 from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
 from .rollout_processor import RolloutProcessor
+from .rollout_result_post_processor import RolloutResultPostProcessor, NoOpRolloutResultPostProcessor
 from .types import RolloutProcessorConfig
 
 # Conditional import for optional dependencies
@@ -42,6 +43,8 @@
     "ExceptionHandlerConfig",
     "BackoffConfig",
     "get_default_exception_handler_config",
+    "RolloutResultPostProcessor",
+    "NoOpRolloutResultPostProcessor",
 ]
 
 # Only add to __all__ if available
diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py
index b4d1c218..94d6f7fe 100644
--- a/eval_protocol/pytest/evaluation_test_utils.py
+++ b/eval_protocol/pytest/evaluation_test_utils.py
@@ -28,6 +28,7 @@
     ServerMode,
 )
 from eval_protocol.pytest.exception_config import get_default_exception_handler_config
+from eval_protocol.exceptions import ResponseQualityError
 
 import logging
 import json
@@ -363,7 +364,15 @@
     async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
         """Execute rollout for a single row with backoff retry."""
         retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
         retry_tasks = rollout_processor([row], retry_config)
-        return await retry_tasks[0]
+        result = await retry_tasks[0]
+
+        # Apply post-processing quality checks if configured. This must run
+        # inside the retry function so that a ResponseQualityError raised by
+        # the post-processor propagates into the backoff retry logic.
+        if config.post_processor is not None:
+            config.post_processor.process(result)
+
+        return result
     async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
         """Execute a single row task with backoff retry."""
@@ -372,6 +381,11 @@
             # Try original task first
             result = await task  # pyright: ignore[reportUnknownVariableType]
 
+            # Apply post-processing quality checks if configured; a
+            # ResponseQualityError raised here is handled by the retry logic below.
+            if config.post_processor is not None:
+                config.post_processor.process(result)
+
             _set_rollout_status_to_finished(result)
             return result  # pyright: ignore[reportUnknownVariableType]
 
@@ -384,9 +398,9 @@
         if is_retryable and not should_giveup:
             # Use shared backoff function for retryable exceptions
+            # Note: post-processing is handled inside execute_row_with_backoff_retry
             try:
                 result = await execute_row_with_backoff_retry(row)
-                _set_rollout_status_to_finished(result)
                 return result
diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py
index e4bb1b7c..a2244b2a 100644
--- a/eval_protocol/pytest/exception_config.py
+++ b/eval_protocol/pytest/exception_config.py
@@ -47,6 +47,7 @@
     eval_protocol.exceptions.UnavailableError,
     eval_protocol.exceptions.UnauthenticatedError,
     eval_protocol.exceptions.ResourceExhaustedError,
+    eval_protocol.exceptions.ResponseQualityError,
 }
 
 
@@ -79,7 +80,11 @@ class BackoffConfig:
     giveup_func: Callable[[Exception], bool] = lambda e: False
 
     def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
-        """Get the appropriate backoff decorator based on configuration."""
+        """Get the appropriate backoff decorator based on configuration.
+
+        Args:
+            exceptions: Set of exception types to retry
+        """
         if not exceptions:
             # If no exceptions specified, return a no-op decorator
             def no_op_decorator(func):
diff --git a/eval_protocol/pytest/rollout_result_post_processor.py b/eval_protocol/pytest/rollout_result_post_processor.py
new file mode 100644
index 00000000..cdaa98d5
--- /dev/null
+++ b/eval_protocol/pytest/rollout_result_post_processor.py
@@ -0,0 +1,57 @@
+"""
+Rollout result post-processing plugin for quality checks.
+
+This module provides an abstract base class for post-processing rollout results
+to guard response quality. Post-processors can validate results and raise
+ResponseQualityError if quality checks fail.
+"""
+
+from abc import ABC, abstractmethod
+
+from eval_protocol.models import EvaluationRow
+
+
+class RolloutResultPostProcessor(ABC):
+    """
+    Abstract base class for rollout result post-processing plugins.
+
+    Post-processors validate rollout results and raise ResponseQualityError
+    when a quality check fails. This makes the quality guard a pluggable
+    component that users can replace with their own implementation.
+    """
+
+    @abstractmethod
+    def process(self, result: EvaluationRow) -> None:
+        """
+        Process and validate a rollout result.
+
+        This method should perform quality checks on the result. If quality
+        checks fail, it should raise ResponseQualityError with an appropriate
+        message.
+
+        Args:
+            result: The EvaluationRow result from the rollout
+
+        Raises:
+            ResponseQualityError: If quality checks fail
+        """
+        pass
+
+
+class NoOpRolloutResultPostProcessor(RolloutResultPostProcessor):
+    """
+    Default no-op implementation of RolloutResultPostProcessor.
+
+    This implementation does not perform any quality checks and always passes.
+    Use this as a default when no post-processing is needed.
+    """
+
+    def process(self, result: EvaluationRow) -> None:
+        """
+        No-op implementation that does not perform any quality checks.
+
+        Args:
+            result: The EvaluationRow result from the rollout
+        """
+        pass
+
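Note: a custom guard subclasses the new base. A hypothetical sketch follows — the `MinLengthPostProcessor` name, the threshold, and the `result.messages`/`Message.content` access are illustrative assumptions, not part of this diff:

```python
from eval_protocol.exceptions import ResponseQualityError
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest import RolloutResultPostProcessor


class MinLengthPostProcessor(RolloutResultPostProcessor):
    """Hypothetical guard: reject rollouts whose last message is too short."""

    def __init__(self, min_chars: int = 20):
        self.min_chars = min_chars

    def process(self, result: EvaluationRow) -> None:
        # Assumes the rollout transcript is exposed as result.messages.
        last = result.messages[-1] if result.messages else None
        if last is None or not last.content or len(last.content) < self.min_chars:
            # Raising ResponseQualityError routes the rollout back through
            # the backoff retry path added in evaluation_test_utils.py.
            raise ResponseQualityError(f"assistant reply shorter than {self.min_chars} chars")
```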
diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py
index 9603c7b9..9cf82ca9 100644
--- a/eval_protocol/pytest/types.py
+++ b/eval_protocol/pytest/types.py
@@ -11,6 +11,7 @@
 from ..models import CompletionParams, EvaluationRow, Message
 from .exception_config import ExceptionHandlerConfig
+from .rollout_result_post_processor import RolloutResultPostProcessor
 
 ModelParam = str  # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
 DatasetPathParam = str
@@ -75,3 +76,4 @@ class RolloutProcessorConfig:
         default_factory=dict
     )  # any additional kwargs to pass to the rollout processor
     exception_handler_config: ExceptionHandlerConfig | None = None  # configuration for exception handling with backoff
+    post_processor: RolloutResultPostProcessor | None = None  # optional post-processor for quality checks
diff --git a/tests/test_exception_config.py b/tests/test_exception_config.py
new file mode 100644
index 00000000..90db182a
--- /dev/null
+++ b/tests/test_exception_config.py
@@ -0,0 +1,113 @@
+"""
+Unit tests for the exception_config module.
+
+Tests the BackoffConfig and ExceptionHandlerConfig classes, including:
+1. Backoff decorator creation for the expo and constant strategies
+2. The no-op decorator returned when no exceptions are specified
+3. ResponseQualityError's presence in the default retryable exceptions
+4. Rejection of unknown backoff strategies
+"""
+
+import pytest
+from eval_protocol.pytest.exception_config import BackoffConfig, ExceptionHandlerConfig, DEFAULT_RETRYABLE_EXCEPTIONS
+from eval_protocol.exceptions import ResponseQualityError
+
+
+def test_backoff_config_no_exceptions():
+    """Test that BackoffConfig returns no-op decorator when no exceptions specified."""
+    config = BackoffConfig()
+    decorator = config.get_backoff_decorator(set())
+
+    # Should be a no-op decorator
+    def test_func():
+        return "test"
+
+    decorated = decorator(test_func)
+    assert decorated() == "test"
+    assert decorated is test_func  # Should be the same function
+
+
+def test_backoff_config_no_overrides():
+    """Test that BackoffConfig creates a single decorator."""
+    config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2)
+    exceptions = {ConnectionError, TimeoutError}
+
+    decorator = config.get_backoff_decorator(exceptions)
+    assert decorator is not None
+
+    # Decorator should be callable
+    def test_func():
+        raise ConnectionError("test")
+
+    decorated = decorator(test_func)
+    assert callable(decorated)
+
+
+def test_exception_handler_config_default_response_quality_error():
+    """Test that ExceptionHandlerConfig includes ResponseQualityError by default."""
+    config = ExceptionHandlerConfig()
+
+    # ResponseQualityError should be in retryable_exceptions
+    assert ResponseQualityError in config.retryable_exceptions
+
+
+def test_exception_handler_config_get_backoff_decorator():
+    """Test that ExceptionHandlerConfig.get_backoff_decorator() works correctly."""
+    config = ExceptionHandlerConfig()
+    decorator = config.get_backoff_decorator()
+
+    assert decorator is not None
+    assert callable(decorator)
+
+    # Should be able to decorate a function
+    def test_func():
+        raise ConnectionError("test")
+
+    decorated = decorator(test_func)
+    assert callable(decorated)
+
+
+def test_backoff_config_expo_strategy():
+    """Test that BackoffConfig creates expo decorator correctly."""
+    config = BackoffConfig(strategy="expo", base_delay=1.0, max_tries=2)
+    exceptions = {ConnectionError}
+
+    decorator = config.get_backoff_decorator(exceptions)
+    assert decorator is not None
+
+    def test_func():
+        raise ConnectionError("test")
+
+    decorated = decorator(test_func)
+    assert callable(decorated)
+
+
+def test_backoff_config_constant_strategy():
+    """Test that BackoffConfig creates constant decorator correctly."""
+    config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2)
+    exceptions = {ConnectionError}
+
+    decorator = config.get_backoff_decorator(exceptions)
+    assert decorator is not None
+
+    def test_func():
+        raise ConnectionError("test")
+
+    decorated = decorator(test_func)
+    assert callable(decorated)
+
+
+def test_backoff_config_invalid_strategy():
+    """Test that BackoffConfig raises ValueError for invalid strategy."""
+    config = BackoffConfig(strategy="invalid", base_delay=1.0, max_tries=2)
+    exceptions = {ConnectionError}
+
+    with pytest.raises(ValueError, match="Unknown backoff strategy"):
+        config.get_backoff_decorator(exceptions)
+
+
+def test_exception_handler_config_response_quality_error_in_defaults():
+    """Test that ResponseQualityError is in DEFAULT_RETRYABLE_EXCEPTIONS."""
+    assert ResponseQualityError in DEFAULT_RETRYABLE_EXCEPTIONS
+
+
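Note: a quick sanity sketch of the surface these tests exercise; like the tests, it assumes the returned decorator wraps plain synchronous callables, and it relies on standard `backoff` semantics where `max_tries` counts total attempts:

```python
from eval_protocol.pytest.exception_config import BackoffConfig

# Retry ConnectionError with a constant 0.1 s delay, up to 3 total attempts.
config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=3)
retrying = config.get_backoff_decorator({ConnectionError})

attempts = {"n": 0}

@retrying
def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

assert flaky() == "ok"
assert attempts["n"] == 3  # two failures, then success within max_tries=3
```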
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index a1fcdeb6..a58cba29 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -35,6 +35,7 @@
     RolloutFinishedError,
     RolloutRunningError,
     ScoreInvalidError,
+    ResponseQualityError,
 )
@@ -71,6 +72,7 @@ def test_error_status_codes_raise_exceptions():
         (14, UnavailableError, "UNAVAILABLE"),
         (15, DataLossError, "DATA_LOSS"),
         (16, UnauthenticatedError, "UNAUTHENTICATED"),
+        (103, ResponseQualityError, "RESPONSE_QUALITY_ERROR"),
     ]
 
     for code, expected_exception_class, name in error_test_cases:
@@ -105,6 +107,7 @@ def test_status_code_mapping_completeness():
         100,
         101,
         102,  # Custom EP codes
+        103,  # ResponseQualityError
     ]
 
     for code in expected_codes:
@@ -113,7 +116,7 @@ def test_invalid_status_codes():
     """Test behavior with invalid/unknown status codes."""
-    invalid_codes = [-1, 17, 99, 103, 999]
+    invalid_codes = [-1, 17, 99, 104, 999]
 
     for code in invalid_codes:
         exception = exception_for_status_code(code)
@@ -181,6 +184,7 @@ def test_status_code_enum_consistency():
         Status.Code.FINISHED: None,
         Status.Code.RUNNING: None,
         Status.Code.SCORE_INVALID: None,
+        Status.Code.RESPONSE_QUALITY_ERROR: ResponseQualityError,
     }
 
     for status_code_enum, expected_exception_class in status_code_mapping.items():
@@ -219,6 +223,7 @@ def test_exception_inheritance():
         RolloutFinishedError,
         RolloutRunningError,
         ScoreInvalidError,
+        ResponseQualityError,
     ]
 
     for exception_class in exception_classes:
@@ -273,6 +278,12 @@ def test_real_world_usage_scenarios():
             "should_raise": True,
             "expected_exception": UnavailableError,
         },
+        {
+            "status_code": 103,
+            "description": "Response quality check failed",
+            "should_raise": True,
+            "expected_exception": ResponseQualityError,
+        },
     ]
 
     for scenario in scenarios:
@@ -320,6 +331,7 @@ def test_exception_status_code_attributes():
         (RolloutFinishedError, 100),
         (RolloutRunningError, 101),
         (ScoreInvalidError, 102),
+        (ResponseQualityError, 103),
     ]
 
     for exception_class, expected_code in expected_mappings:
@@ -342,6 +354,7 @@ def test_integration_with_retry_logic():
         UnavailableError,
         UnauthenticatedError,
         ResourceExhaustedError,
+        ResponseQualityError,
     ]
 
     for exception_class in our_error_exceptions:
@@ -356,6 +369,7 @@ def test_exception_message_preservation():
         (13, "test error", InternalError),
         (5, "Model xyz not found", NotFoundError),
         (7, "Invalid API key", PermissionDeniedError),
+        (103, "Quality check failed: response too repetitive", ResponseQualityError),
     ]
 
     for status_code, message, expected_exception_class in test_cases:
diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py
index 95b70faf..861793c1 100644
--- a/tests/test_retry_mechanism.py
+++ b/tests/test_retry_mechanism.py
@@ -6,15 +6,16 @@
 
 # pyright: reportPrivateImportUsage=false
 
 import asyncio
 from collections import Counter
 from typing_extensions import override
 from unittest.mock import Mock
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message
 from eval_protocol.pytest.evaluation_test import evaluation_test
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
 from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, BackoffConfig
+from eval_protocol.exceptions import ResponseQualityError
 
 import litellm
@@ -263,7 +264,7 @@ def custom_http_giveup(e: Exception) -> bool:
         return True  # Give up immediately on bad requests
     elif isinstance(e, litellm.RateLimitError):
         return False  # Retry rate limits with backoff
-    
+
     return False  # Retry everything else
@@ -385,7 +386,7 @@
 def test_simple_giveup_verification():
     """Verify that giveup function prevents retries."""
     mock_tracker = shared_processor_simple_giveup.mock_tracker
-    
+
     print("\nšŸ”„ SIMPLE GIVEUP TEST ANALYSIS:")
     print(f"   Batch calls made: {mock_tracker.batch_call.call_count}")
     print(f"   Total row processing calls: {mock_tracker.process_row_call.call_count}")
@@ -397,3 +398,86 @@
     )
 
     print("āœ… Simple giveup test passed! 4xx error was not retried due to giveup function.")
+
+
+# Test 5: ResponseQualityError is retried under the default backoff configuration
+class MockRolloutProcessorResponseQuality(RolloutProcessor):
+    """Mock processor that raises ResponseQualityError on the first attempt only"""
+
+    def __init__(self):
+        self.mock_tracker: Mock = Mock()
+
+    @override
+    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
+        self.mock_tracker.batch_call(len(rows))
+
+        async def process_single_row(row: EvaluationRow) -> EvaluationRow:
+            self.mock_tracker.process_row_call(row.execution_metadata.rollout_id)
+
+            # Count the calls recorded for this rollout_id (including the current one)
+            calls_for_rollout = [
+                call for call in self.mock_tracker.process_row_call.call_args_list if call[0][0] == row.execution_metadata.rollout_id
+            ]
+            attempt_number = len(calls_for_rollout)
+
+            # Fail on the first attempt, succeed on the retry
+            if attempt_number == 1:
+                raise ResponseQualityError("Response quality check failed: too repetitive")
+
+            return row
+
+        tasks = [asyncio.create_task(process_single_row(row)) for row in rows]
+        return tasks
+
+
+shared_processor_response_quality = MockRolloutProcessorResponseQuality()
+
+
+@evaluation_test(
+    completion_params=[{"model": "gpt-4o-mini", "temperature": 0}],
+    input_messages=[[[Message(role="user", content="Test quality")]]],
+    rollout_processor=shared_processor_response_quality,
+    num_runs=1,
+    mode="pointwise",
+    exception_handler_config=ExceptionHandlerConfig(
+        backoff_config=BackoffConfig(max_tries=3),
+    ),
+)
+def test_response_quality_error_retry(row: EvaluationRow) -> EvaluationRow:
+    """Test that ResponseQualityError is retried (using the default backoff)."""
+    print(
+        f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
+    )
+    score = 1.0 if row.rollout_status.is_finished() else 0.0
+    row.evaluation_result = EvaluateResult(score=score)
+    return row
+
+
+def test_response_quality_error_verification():
+    """Verify that ResponseQualityError is retried."""
+    mock_tracker = shared_processor_response_quality.mock_tracker
+
+    print("\nšŸ”„ RESPONSE QUALITY ERROR TEST ANALYSIS:")
+    print(f"   Batch calls made: {mock_tracker.batch_call.call_count}")
+    print(f"   Total row processing calls: {mock_tracker.process_row_call.call_count}")
+
+    call_args = mock_tracker.process_row_call.call_args_list
+    rollout_ids = [call[0][0] for call in call_args]
+    call_counts = Counter(rollout_ids)
+
+    print(f"   Call counts per rollout_id: {dict(call_counts)}")
+
+    # Should have 2 calls: 1 original + 1 retry
+    # Note: max_tries=3 allows up to two retries, but the mock succeeds on attempt 2
+    assert mock_tracker.process_row_call.call_count == 2, (
+        f"Expected 2 calls (1 original + 1 retry), got {mock_tracker.process_row_call.call_count}"
+    )
+
+    # Should have exactly 1 rollout_id called twice
+    call_count_values = list(call_counts.values())
+    assert call_count_values.count(2) == 1, (
+        f"Expected 1 rollout with 2 calls, got {call_count_values}"
+    )
+
+    print("āœ… ResponseQualityError test passed! Error was retried.")
+
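Note: the control flow this test pins down reduces to roughly the following condensed sketch (not verbatim library code; the helper name is illustrative):

```python
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.types import RolloutProcessorConfig


async def run_with_quality_guard(
    row: EvaluationRow, config: RolloutProcessorConfig, rollout_processor
) -> EvaluationRow:
    # Run the rollout, then apply the optional quality guard. Because
    # ResponseQualityError is in DEFAULT_RETRYABLE_EXCEPTIONS, a failed check
    # re-enters the same backoff retry path as transient infrastructure errors.
    result = await rollout_processor([row], config)[0]
    if config.post_processor is not None:
        config.post_processor.process(result)  # may raise ResponseQualityError
    return result
```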