From bce982dc56341ea3023370756c2e5fdd98dca037 Mon Sep 17 00:00:00 2001 From: Yifan Xiao Date: Sun, 14 Dec 2025 14:49:37 -0800 Subject: [PATCH 1/2] commit for trail proxy setup --- eval_protocol/proxy/proxy_core/app.py | 98 ++++++++++++++++++++++ eval_protocol/proxy/proxy_core/langfuse.py | 67 ++++++++------- eval_protocol/proxy/proxy_core/litellm.py | 31 ++++++- eval_protocol/proxy/proxy_core/models.py | 3 + 4 files changed, 166 insertions(+), 33 deletions(-) diff --git a/eval_protocol/proxy/proxy_core/app.py b/eval_protocol/proxy/proxy_core/app.py index 528d467e..054d4c6c 100644 --- a/eval_protocol/proxy/proxy_core/app.py +++ b/eval_protocol/proxy/proxy_core/app.py @@ -173,6 +173,39 @@ async def require_auth(request: Request) -> None: # ===================== # Chat completion routes # ===================== + + # ============ Trail Management System Routes (New) ============ + @app.post("/trails/{trail_id}/chat/completions") + @app.post("/v1/trails/{trail_id}/chat/completions") + @app.post("/trails/{trail_id}/v1/chat/completions") + @app.post("/project_id/{project_id}/trails/{trail_id}/chat/completions") + @app.post("/v1/project_id/{project_id}/trails/{trail_id}/chat/completions") + async def trail_chat_completion( + trail_id: str, + request: Request, + project_id: Optional[str] = None, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ): + """ + Trail Management System endpoint for LLM inference tracking. + + Automatically injects trail_id and insertion_id as Langfuse tags for tracing. + All requests under the same trail_id can be queried together for analysis and training. + """ + params = ChatParams( + project_id=project_id, + trail_id=trail_id, + ) + return await handle_chat_completion( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) + + # ============ Legacy Evaluation System Routes ============ @app.post( "/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" ) @@ -246,6 +279,71 @@ async def chat_completion_with_project_only( # =============== # Traces routes # =============== + + # ============ Trail Traces Routes (New) ============ + @app.get("/trails/{trail_id}/traces", response_model=LangfuseTracesResponse) + @app.get("/v1/trails/{trail_id}/traces", response_model=LangfuseTracesResponse) + @app.get("/project_id/{project_id}/trails/{trail_id}/traces", response_model=LangfuseTracesResponse) + @app.get("/v1/project_id/{project_id}/trails/{trail_id}/traces", response_model=LangfuseTracesResponse) + async def get_trail_traces( + trail_id: str, + request: Request, + params: TracesParams = Depends(get_traces_params), + project_id: Optional[str] = None, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ) -> LangfuseTracesResponse: + """ + Fetch all Langfuse traces for a specific trail. + + Waits for all expected insertion_ids to complete before returning traces. + """ + if project_id is not None: + params.project_id = project_id + # Inject trail_id tag into query parameters + if params.tags is None: + params.tags = [] + params.tags.append(f"trail_id:{trail_id}") + return await fetch_langfuse_traces( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) + + @app.get("/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/v1/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/project_id/{project_id}/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/v1/project_id/{project_id}/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse) + async def get_trail_pointwise_trace( + trail_id: str, + request: Request, + params: TracesParams = Depends(get_traces_params), + project_id: Optional[str] = None, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ) -> LangfuseTracesResponse: + """ + Fetch the latest trace for a trail (UUID7 time-ordered). + + Returns only the most recent trace, useful for real-time monitoring. + """ + if project_id is not None: + params.project_id = project_id + # Inject trail_id tag into query parameters + if params.tags is None: + params.tags = [] + params.tags.append(f"trail_id:{trail_id}") + return await pointwise_fetch_langfuse_trace( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) + + # ============ Legacy Traces Routes ============ @app.get("/traces", response_model=LangfuseTracesResponse) @app.get("/v1/traces", response_model=LangfuseTracesResponse) @app.get("/project_id/{project_id}/traces", response_model=LangfuseTracesResponse) diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index d91da681..5520f1fa 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -76,12 +76,16 @@ async def _fetch_trace_list_with_retry( ) -> Any: """Fetch trace list with rate limit retry logic.""" list_retries = 0 - rollout_id: Optional[str] = None + tracking_key: Optional[str] = None # Could be rollout_id or trail_id if tags: for t in tags: - if isinstance(t, str) and t.startswith("rollout_id:"): - rollout_id = t.split(":", 1)[1] if ":" in t else t - break + if isinstance(t, str): + if t.startswith("rollout_id:"): + tracking_key = t.split(":", 1)[1] if ":" in t else t + break + elif t.startswith("trail_id:"): + tracking_key = t.split(":", 1)[1] if ":" in t else t + break while list_retries < max_retries: try: traces = langfuse_client.api.trace.list( @@ -124,9 +128,9 @@ async def _fetch_trace_list_with_retry( # Return 404 if we've retried max_retries # TODO: write some tests around proxy exception handling logger.error( - "Failed to fetch trace list after %d retries (rollout_id=%s): %s", + "Failed to fetch trace list after %d retries (tracking_key=%s): %s", max_retries, - rollout_id, + tracking_key, e, ) raise HTTPException( @@ -134,7 +138,7 @@ async def _fetch_trace_list_with_retry( ) else: # Catch all other exceptions - logger.error("Failed to fetch trace list (rollout_id=%s): %s", rollout_id, e) + logger.error("Failed to fetch trace list (tracking_key=%s): %s", tracking_key, e) raise HTTPException(status_code=500, detail=f"Failed to fetch traces: {str(e)}") @@ -247,16 +251,16 @@ async def fetch_langfuse_traces( # Get expected insertion_ids from Redis for completeness checking expected_ids: Set[str] = set() - if rollout_id: - expected_ids = get_insertion_ids(redis_client, rollout_id) - logger.info(f"Fetching traces for rollout_id '{rollout_id}', expecting {len(expected_ids)} insertion_ids") + if tracking_key: + expected_ids = get_insertion_ids(redis_client, tracking_key) + logger.info(f"Fetching traces for {tracking_label} '{tracking_key}', expecting {len(expected_ids)} insertion_ids") if not expected_ids: logger.warning( - f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces." + f"No expected insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Returning empty traces." ) raise HTTPException( status_code=500, - detail=f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces.", + detail=f"No expected insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Returning empty traces.", ) # Track all traces we've collected across retry attempts @@ -265,7 +269,7 @@ async def fetch_langfuse_traces( insertion_ids: Set[str] = set() # Insertion IDs extracted from traces (for completeness check) for retry in range(max_retries): - # On first attempt, use rollout_id tag. On retries, target missing insertion_ids + # On first attempt, use tracking tag. On retries, target missing insertion_ids if retry == 0: fetch_tags = tags else: @@ -273,7 +277,7 @@ async def fetch_langfuse_traces( missing_ids = expected_ids - insertion_ids fetch_tags = [f"insertion_id:{id}" for id in missing_ids] logger.info( - f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for rollout '{rollout_id}' (last5): {[id[-5:] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}" + f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for {tracking_label} '{tracking_key}' (last5): {[id[-5:] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}" ) current_page = 1 @@ -329,7 +333,7 @@ async def fetch_langfuse_traces( insertion_id = _extract_tag_value(trace_dict.get("tags", []), "insertion_id:") if insertion_id: insertion_ids.add(insertion_id) - logger.debug(f"Found insertion_id '{insertion_id}' for rollout '{rollout_id}'") + logger.debug(f"Found insertion_id '{insertion_id}' for {tracking_label} '{tracking_key}'") except Exception as e: logger.warning("Failed to serialize trace %s: %s", trace_info.id, e) @@ -349,7 +353,7 @@ async def fetch_langfuse_traces( # If we have all expected completions or more, return traces. At least once is ok. if expected_ids <= insertion_ids: logger.info( - f"Traces complete for rollout '{rollout_id}': {len(insertion_ids)}/{len(expected_ids)} insertion_ids found, returning {len(all_traces)} traces" + f"Traces complete for {tracking_label} '{tracking_key}': {len(insertion_ids)}/{len(expected_ids)} insertion_ids found, returning {len(all_traces)} traces" ) if sample_size is not None and len(all_traces) > sample_size: all_traces = random.sample(all_traces, sample_size) @@ -366,16 +370,16 @@ async def fetch_langfuse_traces( wait_time = 2**retry still_missing = expected_ids - insertion_ids logger.info( - f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for rollout '{rollout_id}'. Still missing (last5): {[id[-5:] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..." + f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for {tracking_label} '{tracking_key}'. Still missing (last5): {[id[-5:] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..." ) await asyncio.sleep(wait_time) logger.error( - f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions." + f"Incomplete traces for {tracking_label} '{tracking_key}': Found {len(insertion_ids)}/{len(expected_ids)} completions." ) raise HTTPException( status_code=404, - detail=f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions.", + detail=f"Incomplete traces for {tracking_label} '{tracking_key}': Found {len(insertion_ids)}/{len(expected_ids)} completions.", ) except ImportError: @@ -431,8 +435,11 @@ async def pointwise_fetch_langfuse_trace( detail=f"Project ID '{project_id}' not found. Available projects: {list(config.langfuse_keys.keys())}", ) - # Extract rollout_id from tags for Redis lookup + # Extract tracking key (rollout_id or trail_id) from tags for Redis lookup rollout_id = _extract_tag_value(tags, "rollout_id:") + trail_id = _extract_tag_value(tags, "trail_id:") + tracking_key = trail_id if trail_id else rollout_id + tracking_label = "trail_id" if trail_id else "rollout_id" try: # Import the Langfuse adapter @@ -461,23 +468,23 @@ async def pointwise_fetch_langfuse_trace( # Get insertion_ids from Redis to find the latest one expected_ids: Set[str] = set() - if rollout_id: - expected_ids = get_insertion_ids(redis_client, rollout_id) + if tracking_key: + expected_ids = get_insertion_ids(redis_client, tracking_key) logger.info( - f"Pointwise fetch for rollout_id '{rollout_id}', found {len(expected_ids)} insertion_ids in Redis" + f"Pointwise fetch for {tracking_label} '{tracking_key}', found {len(expected_ids)} insertion_ids in Redis" ) if not expected_ids: logger.warning( - f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace." + f"No insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Cannot determine latest trace." ) raise HTTPException( status_code=500, - detail=f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace.", + detail=f"No insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Cannot determine latest trace.", ) # Get the latest (last) insertion_id since UUID v7 is time-ordered latest_insertion_id = max(expected_ids) # UUID v7 max = newest - logger.info(f"Targeting latest insertion_id: {latest_insertion_id} for rollout '{rollout_id}'") + logger.info(f"Targeting latest insertion_id: {latest_insertion_id} for {tracking_label} '{tracking_key}'") for retry in range(max_retries): # Fetch trace list targeting the latest insertion_id @@ -513,7 +520,7 @@ async def pointwise_fetch_langfuse_trace( if trace_full: trace_dict = _serialize_trace_to_dict(trace_full) logger.info( - f"Successfully fetched latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id}" + f"Successfully fetched latest trace for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id}" ) return LangfuseTracesResponse( project_id=project_id, @@ -525,17 +532,17 @@ async def pointwise_fetch_langfuse_trace( if retry < max_retries - 1: wait_time = 2**retry logger.info( - f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for rollout '{rollout_id}', insertion_id: {latest_insertion_id}. Retrying in {wait_time}s..." + f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id}. Retrying in {wait_time}s..." ) await asyncio.sleep(wait_time) # After all retries failed logger.error( - f"Failed to fetch latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id} after {max_retries} retries" + f"Failed to fetch latest trace for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id} after {max_retries} retries" ) raise HTTPException( status_code=404, - detail=f"Failed to fetch latest trace for rollout '{rollout_id}' after {max_retries} retries", + detail=f"Failed to fetch latest trace for {tracking_label} '{tracking_key}' after {max_retries} retries", ) except ImportError: diff --git a/eval_protocol/proxy/proxy_core/litellm.py b/eval_protocol/proxy/proxy_core/litellm.py index cdd2383b..755dfc3a 100644 --- a/eval_protocol/proxy/proxy_core/litellm.py +++ b/eval_protocol/proxy/proxy_core/litellm.py @@ -36,6 +36,7 @@ async def handle_chat_completion( data, params = config.preprocess_chat_request(data, request, params) project_id = params.project_id + trail_id = params.trail_id rollout_id = params.rollout_id invocation_id = params.invocation_id experiment_id = params.experiment_id @@ -70,8 +71,31 @@ async def handle_chat_completion( # If metadata IDs are provided, add them as tags insertion_id = None - if rollout_id is not None: + tracking_key = None # Key for Redis tracking (trail_id or rollout_id) + + if trail_id is not None: + # Trail Management System: Simple tagging with just trail_id + insertion_id = str(uuid7()) + tracking_key = trail_id + + if "metadata" not in data: + data["metadata"] = {} + if "tags" not in data["metadata"]: + data["metadata"]["tags"] = [] + + # Add trail metadata as tags + data["metadata"]["tags"].extend( + [ + f"trail_id:{trail_id}", + f"insertion_id:{insertion_id}", + ] + ) + logger.debug(f"Trail request: trail_id={trail_id}, insertion_id={insertion_id}") + + elif rollout_id is not None: + # Legacy evaluation system: Complex tagging with multiple IDs insertion_id = str(uuid7()) + tracking_key = rollout_id if "metadata" not in data: data["metadata"] = {} @@ -89,6 +113,7 @@ async def handle_chat_completion( f"row_id:{row_id}", ] ) + logger.debug(f"Rollout request: rollout_id={rollout_id}, insertion_id={insertion_id}") # Add Langfuse configuration data["langfuse_public_key"] = config.langfuse_keys[project_id]["public_key"] @@ -115,8 +140,8 @@ async def handle_chat_completion( ) # Register insertion_id in Redis only on successful response - if response.status_code == 200 and insertion_id is not None and rollout_id is not None: - register_insertion_id(redis_client, rollout_id, insertion_id) + if response.status_code == 200 and insertion_id is not None and tracking_key is not None: + register_insertion_id(redis_client, tracking_key, insertion_id) # Return the response return Response( diff --git a/eval_protocol/proxy/proxy_core/models.py b/eval_protocol/proxy/proxy_core/models.py index f3b5e614..e68ca033 100644 --- a/eval_protocol/proxy/proxy_core/models.py +++ b/eval_protocol/proxy/proxy_core/models.py @@ -21,6 +21,9 @@ class ChatParams(BaseModel): """Typed container for chat completion URL path parameters.""" project_id: Optional[str] = None + # Trail Management System (simpler path) + trail_id: Optional[str] = None + # Legacy evaluation system (complex path) rollout_id: Optional[str] = None invocation_id: Optional[str] = None experiment_id: Optional[str] = None From 911dacaefa7aaca810dd167368d543b43282d18c Mon Sep 17 00:00:00 2001 From: Yifan Xiao Date: Sun, 14 Dec 2025 14:52:23 -0800 Subject: [PATCH 2/2] testing script --- eval_protocol/proxy/test_trail.py | 229 ++++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 eval_protocol/proxy/test_trail.py diff --git a/eval_protocol/proxy/test_trail.py b/eval_protocol/proxy/test_trail.py new file mode 100644 index 00000000..af904514 --- /dev/null +++ b/eval_protocol/proxy/test_trail.py @@ -0,0 +1,229 @@ +""" +Tests for Trail Management System proxy implementation. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +from fastapi.testclient import TestClient +import redis + +from proxy_core.models import ChatParams, ProxyConfig +from proxy_core.app import create_app +from proxy_core.auth import NoAuthProvider + + +@pytest.fixture +def mock_config(): + """Mock ProxyConfig.""" + return ProxyConfig( + litellm_url="http://mock-litellm:8000", + langfuse_host="https://mock-langfuse.com", + langfuse_keys={ + "test-project": { + "public_key": "pk-test", + "secret_key": "sk-test" + } + }, + default_project_id="test-project", + request_timeout=300.0 + ) + + +@pytest.fixture +def mock_redis(): + """Mock Redis client.""" + mock = Mock(spec=redis.Redis) + mock.ping.return_value = True + mock.close.return_value = None + mock.sadd = Mock() + return mock + + +@pytest.fixture +def app(mock_config, mock_redis): + """Create test app.""" + app = create_app(auth_provider=NoAuthProvider()) + app.state.config = mock_config + app.state.redis = mock_redis + return app + + +@pytest.fixture +def client(app): + """Create test client.""" + return TestClient(app) + + +class TestTrailModels: + """Test data models.""" + + def test_chat_params_trail_id(self): + """ChatParams accepts trail_id.""" + params = ChatParams(trail_id="test-trail-123", project_id="my-project") + assert params.trail_id == "test-trail-123" + assert params.project_id == "my-project" + assert params.rollout_id is None + + def test_chat_params_backward_compatibility(self): + """ChatParams still works with rollout_id.""" + params = ChatParams( + rollout_id="rollout-123", + invocation_id="inv-1", + experiment_id="exp-1", + run_id="run-1", + row_id="row-1" + ) + assert params.rollout_id == "rollout-123" + assert params.trail_id is None + + +class TestTrailRoutes: + """Test trail routes.""" + + def test_trail_chat_routes_registered(self, client): + """Trail chat completion routes exist.""" + routes = [route.path for route in client.app.routes] + assert "/trails/{trail_id}/chat/completions" in routes + assert "/v1/trails/{trail_id}/chat/completions" in routes + assert "/project_id/{project_id}/trails/{trail_id}/chat/completions" in routes + + def test_trail_traces_routes_registered(self, client): + """Trail traces routes exist.""" + routes = [route.path for route in client.app.routes] + assert "/trails/{trail_id}/traces" in routes + assert "/v1/trails/{trail_id}/traces" in routes + assert "/trails/{trail_id}/traces/pointwise" in routes + + def test_legacy_routes_preserved(self, client): + """Legacy rollout routes still exist.""" + routes_str = " ".join([route.path for route in client.app.routes]) + assert "rollout_id" in routes_str + assert "invocation_id" in routes_str + + def test_health_endpoint(self, client): + """Health endpoint works.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +class TestTrailTagInjection: + """Test tag injection logic.""" + + @pytest.mark.asyncio + async def test_trail_simple_tags(self, mock_config, mock_redis): + """Trail requests inject simple tags (2 tags).""" + from proxy_core.litellm import handle_chat_completion + from fastapi import Request + + mock_request = Mock(spec=Request) + mock_request.headers = {"authorization": "Bearer test-key"} + mock_request.body = AsyncMock(return_value=b'{"model": "test", "messages": []}') + + params = ChatParams(trail_id="test-trail-123") + + with patch('proxy_core.litellm.httpx.AsyncClient') as mock_client: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b'{"choices": []}' + mock_response.headers = {} + + mock_post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value.post = mock_post + + await handle_chat_completion(mock_config, mock_redis, mock_request, params) + + sent_data = mock_post.call_args.kwargs['json'] + tags = sent_data['metadata']['tags'] + + # Trail system: only 2 tags + assert len(tags) == 2 + trail_tags = [t for t in tags if t.startswith('trail_id:')] + assert len(trail_tags) == 1 + assert trail_tags[0] == 'trail_id:test-trail-123' + + insertion_tags = [t for t in tags if t.startswith('insertion_id:')] + assert len(insertion_tags) == 1 + + @pytest.mark.asyncio + async def test_rollout_complex_tags(self, mock_config, mock_redis): + """Rollout requests inject complex tags (6 tags) - backward compat.""" + from proxy_core.litellm import handle_chat_completion + from fastapi import Request + + mock_request = Mock(spec=Request) + mock_request.headers = {"authorization": "Bearer test-key"} + mock_request.body = AsyncMock(return_value=b'{"model": "test", "messages": []}') + + params = ChatParams( + rollout_id="rollout-123", + invocation_id="inv-1", + experiment_id="exp-1", + run_id="run-1", + row_id="row-1" + ) + + with patch('proxy_core.litellm.httpx.AsyncClient') as mock_client: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b'{"choices": []}' + mock_response.headers = {} + + mock_post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value.post = mock_post + + await handle_chat_completion(mock_config, mock_redis, mock_request, params) + + sent_data = mock_post.call_args.kwargs['json'] + tags = sent_data['metadata']['tags'] + + # Legacy system: 6 tags + assert len(tags) == 6 + tag_prefixes = [t.split(':')[0] for t in tags] + assert 'rollout_id' in tag_prefixes + assert 'invocation_id' in tag_prefixes + assert 'experiment_id' in tag_prefixes + + +class TestRedisTracking: + """Test Redis tracking.""" + + @pytest.mark.asyncio + async def test_redis_uses_trail_id_as_key(self, mock_config, mock_redis): + """Redis uses trail_id as key.""" + from proxy_core.litellm import handle_chat_completion + from fastapi import Request + + mock_request = Mock(spec=Request) + mock_request.headers = {"authorization": "Bearer test-key"} + mock_request.body = AsyncMock(return_value=b'{"model": "test", "messages": []}') + + params = ChatParams(trail_id="my-trail-456") + + with patch('proxy_core.litellm.httpx.AsyncClient') as mock_client: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b'{"choices": []}' + mock_response.headers = {} + + mock_post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value.post = mock_post + + await handle_chat_completion(mock_config, mock_redis, mock_request, params) + + # Verify Redis sadd was called with trail_id + assert mock_redis.sadd.called + call_args = mock_redis.sadd.call_args[0] + assert call_args[0] == "my-trail-456" + + # Second arg should be insertion_id + insertion_id = call_args[1] + assert isinstance(insertion_id, str) + assert len(insertion_id) > 0 + + + + + +