Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions eval_protocol/proxy/proxy_core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,39 @@ async def require_auth(request: Request) -> None:
# =====================
# Chat completion routes
# =====================

# ============ Trail Management System Routes (New) ============
@app.post("/trails/{trail_id}/chat/completions")
@app.post("/v1/trails/{trail_id}/chat/completions")
@app.post("/trails/{trail_id}/v1/chat/completions")
@app.post("/project_id/{project_id}/trails/{trail_id}/chat/completions")
@app.post("/v1/project_id/{project_id}/trails/{trail_id}/chat/completions")
async def trail_chat_completion(
    trail_id: str,
    request: Request,
    project_id: Optional[str] = None,
    config: ProxyConfig = Depends(get_config),
    redis_client: redis.Redis = Depends(get_redis),
    _: None = Depends(require_auth),
):
    """
    Chat completion entry point for the Trail Management System.

    The trail_id (and project_id, when the route carries one) from the URL
    path is packed into a ChatParams object and the request is delegated to
    handle_chat_completion, which performs the tagging and the upstream call.

    Args:
        trail_id: Trail identifier taken from the URL path.
        request: Incoming HTTP request, forwarded unchanged.
        project_id: Optional project identifier; None on routes without a
            ``project_id`` path segment.
        config: Proxy configuration resolved by the get_config dependency.
        redis_client: Redis client resolved by the get_redis dependency.
        _: Placeholder for the require_auth dependency; only its side
            effect (rejecting unauthenticated requests) matters.

    Returns:
        Whatever handle_chat_completion returns for this request.
    """
    # Routes without a project segment leave project_id as None; it is
    # passed through as-is and resolved downstream.
    return await handle_chat_completion(
        config=config,
        redis_client=redis_client,
        request=request,
        params=ChatParams(project_id=project_id, trail_id=trail_id),
    )

# ============ Legacy Evaluation System Routes ============
@app.post(
"/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions"
)
Expand Down Expand Up @@ -246,6 +279,71 @@ async def chat_completion_with_project_only(
# ===============
# Traces routes
# ===============

# ============ Trail Traces Routes (New) ============
@app.get("/trails/{trail_id}/traces", response_model=LangfuseTracesResponse)
@app.get("/v1/trails/{trail_id}/traces", response_model=LangfuseTracesResponse)
@app.get("/project_id/{project_id}/trails/{trail_id}/traces", response_model=LangfuseTracesResponse)
@app.get("/v1/project_id/{project_id}/trails/{trail_id}/traces", response_model=LangfuseTracesResponse)
async def get_trail_traces(
    trail_id: str,
    request: Request,
    params: TracesParams = Depends(get_traces_params),
    project_id: Optional[str] = None,
    config: ProxyConfig = Depends(get_config),
    redis_client: redis.Redis = Depends(get_redis),
    _: None = Depends(require_auth),
) -> LangfuseTracesResponse:
    """
    Return the Langfuse traces recorded for one trail.

    The query parameters are narrowed to the trail by adding a
    ``trail_id:<trail_id>`` tag (and by overriding the project when the
    route carries a ``project_id`` segment), then the lookup is delegated to
    fetch_langfuse_traces.

    Args:
        trail_id: Trail identifier taken from the URL path.
        request: Incoming HTTP request, forwarded unchanged.
        params: Trace query parameters built by get_traces_params; mutated
            in place before being forwarded.
        project_id: Optional project identifier from the URL path.
        config: Proxy configuration resolved by the get_config dependency.
        redis_client: Redis client resolved by the get_redis dependency.
        _: Placeholder for the require_auth dependency.

    Returns:
        The LangfuseTracesResponse produced by fetch_langfuse_traces.
    """
    # A project given in the path wins over whatever the query supplied.
    if project_id is not None:
        params.project_id = project_id

    # Scope the query to this trail, keeping any tags the caller sent.
    trail_tag = f"trail_id:{trail_id}"
    if params.tags is None:
        params.tags = [trail_tag]
    else:
        params.tags.append(trail_tag)

    return await fetch_langfuse_traces(
        config=config,
        redis_client=redis_client,
        request=request,
        params=params,
    )

@app.get("/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse)
@app.get("/v1/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse)
@app.get("/project_id/{project_id}/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse)
@app.get("/v1/project_id/{project_id}/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse)
async def get_trail_pointwise_trace(
    trail_id: str,
    request: Request,
    params: TracesParams = Depends(get_traces_params),
    project_id: Optional[str] = None,
    config: ProxyConfig = Depends(get_config),
    redis_client: redis.Redis = Depends(get_redis),
    _: None = Depends(require_auth),
) -> LangfuseTracesResponse:
    """
    Return only the most recent trace of a trail.

    The query parameters are narrowed to the trail by adding a
    ``trail_id:<trail_id>`` tag (and by overriding the project when the
    route carries a ``project_id`` segment), then the lookup is delegated to
    pointwise_fetch_langfuse_trace, which selects the latest trace.

    Args:
        trail_id: Trail identifier taken from the URL path.
        request: Incoming HTTP request, forwarded unchanged.
        params: Trace query parameters built by get_traces_params; mutated
            in place before being forwarded.
        project_id: Optional project identifier from the URL path.
        config: Proxy configuration resolved by the get_config dependency.
        redis_client: Redis client resolved by the get_redis dependency.
        _: Placeholder for the require_auth dependency.

    Returns:
        The LangfuseTracesResponse produced by pointwise_fetch_langfuse_trace.
    """
    # A project given in the path wins over whatever the query supplied.
    if project_id is not None:
        params.project_id = project_id

    # Scope the query to this trail, keeping any tags the caller sent.
    trail_tag = f"trail_id:{trail_id}"
    if params.tags is None:
        params.tags = [trail_tag]
    else:
        params.tags.append(trail_tag)

    return await pointwise_fetch_langfuse_trace(
        config=config,
        redis_client=redis_client,
        request=request,
        params=params,
    )

# ============ Legacy Traces Routes ============
@app.get("/traces", response_model=LangfuseTracesResponse)
@app.get("/v1/traces", response_model=LangfuseTracesResponse)
@app.get("/project_id/{project_id}/traces", response_model=LangfuseTracesResponse)
Expand Down
67 changes: 37 additions & 30 deletions eval_protocol/proxy/proxy_core/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,16 @@ async def _fetch_trace_list_with_retry(
) -> Any:
"""Fetch trace list with rate limit retry logic."""
list_retries = 0
rollout_id: Optional[str] = None
tracking_key: Optional[str] = None # Could be rollout_id or trail_id
if tags:
for t in tags:
if isinstance(t, str) and t.startswith("rollout_id:"):
rollout_id = t.split(":", 1)[1] if ":" in t else t
break
if isinstance(t, str):
if t.startswith("rollout_id:"):
tracking_key = t.split(":", 1)[1] if ":" in t else t
break
elif t.startswith("trail_id:"):
tracking_key = t.split(":", 1)[1] if ":" in t else t
break
while list_retries < max_retries:
try:
traces = langfuse_client.api.trace.list(
Expand Down Expand Up @@ -124,17 +128,17 @@ async def _fetch_trace_list_with_retry(
# Return 404 if we've retried max_retries
# TODO: write some tests around proxy exception handling
logger.error(
"Failed to fetch trace list after %d retries (rollout_id=%s): %s",
"Failed to fetch trace list after %d retries (tracking_key=%s): %s",
max_retries,
rollout_id,
tracking_key,
e,
)
raise HTTPException(
status_code=404, detail=f"Failed to fetch traces after {max_retries} retries: {str(e)}"
)
else:
# Catch all other exceptions
logger.error("Failed to fetch trace list (rollout_id=%s): %s", rollout_id, e)
logger.error("Failed to fetch trace list (tracking_key=%s): %s", tracking_key, e)
raise HTTPException(status_code=500, detail=f"Failed to fetch traces: {str(e)}")


Expand Down Expand Up @@ -247,16 +251,16 @@ async def fetch_langfuse_traces(

# Get expected insertion_ids from Redis for completeness checking
expected_ids: Set[str] = set()
if rollout_id:
expected_ids = get_insertion_ids(redis_client, rollout_id)
logger.info(f"Fetching traces for rollout_id '{rollout_id}', expecting {len(expected_ids)} insertion_ids")
if tracking_key:
expected_ids = get_insertion_ids(redis_client, tracking_key)
logger.info(f"Fetching traces for {tracking_label} '{tracking_key}', expecting {len(expected_ids)} insertion_ids")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Missing variable definitions in fetch_langfuse_traces function

The fetch_langfuse_traces function uses tracking_key and tracking_label variables (at lines 254, 256, 259, 263, 280, 336, 356, 373, 378, 382) but these are never defined in this function. The function only extracts rollout_id at line 226. The corresponding pointwise_fetch_langfuse_trace function was correctly updated to define both tracking_key and tracking_label from trail_id and rollout_id, but this definition block was not added to fetch_langfuse_traces. This will cause a NameError at runtime when accessing the trail traces endpoints.

Fix in Cursor Fix in Web

if not expected_ids:
logger.warning(
f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces."
f"No expected insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Returning empty traces."
)
raise HTTPException(
status_code=500,
detail=f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces.",
detail=f"No expected insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Returning empty traces.",
)

# Track all traces we've collected across retry attempts
Expand All @@ -265,15 +269,15 @@ async def fetch_langfuse_traces(
insertion_ids: Set[str] = set() # Insertion IDs extracted from traces (for completeness check)

for retry in range(max_retries):
# On first attempt, use rollout_id tag. On retries, target missing insertion_ids
# On first attempt, use tracking tag. On retries, target missing insertion_ids
if retry == 0:
fetch_tags = tags
else:
# Build targeted tags for missing insertion_ids
missing_ids = expected_ids - insertion_ids
fetch_tags = [f"insertion_id:{id}" for id in missing_ids]
logger.info(
f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for rollout '{rollout_id}' (last5): {[id[-5:] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}"
f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for {tracking_label} '{tracking_key}' (last5): {[id[-5:] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}"
)

current_page = 1
Expand Down Expand Up @@ -329,7 +333,7 @@ async def fetch_langfuse_traces(
insertion_id = _extract_tag_value(trace_dict.get("tags", []), "insertion_id:")
if insertion_id:
insertion_ids.add(insertion_id)
logger.debug(f"Found insertion_id '{insertion_id}' for rollout '{rollout_id}'")
logger.debug(f"Found insertion_id '{insertion_id}' for {tracking_label} '{tracking_key}'")

except Exception as e:
logger.warning("Failed to serialize trace %s: %s", trace_info.id, e)
Expand All @@ -349,7 +353,7 @@ async def fetch_langfuse_traces(
# If we have all expected completions or more, return traces. At least once is ok.
if expected_ids <= insertion_ids:
logger.info(
f"Traces complete for rollout '{rollout_id}': {len(insertion_ids)}/{len(expected_ids)} insertion_ids found, returning {len(all_traces)} traces"
f"Traces complete for {tracking_label} '{tracking_key}': {len(insertion_ids)}/{len(expected_ids)} insertion_ids found, returning {len(all_traces)} traces"
)
if sample_size is not None and len(all_traces) > sample_size:
all_traces = random.sample(all_traces, sample_size)
Expand All @@ -366,16 +370,16 @@ async def fetch_langfuse_traces(
wait_time = 2**retry
still_missing = expected_ids - insertion_ids
logger.info(
f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for rollout '{rollout_id}'. Still missing (last5): {[id[-5:] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..."
f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for {tracking_label} '{tracking_key}'. Still missing (last5): {[id[-5:] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..."
)
await asyncio.sleep(wait_time)

logger.error(
f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions."
f"Incomplete traces for {tracking_label} '{tracking_key}': Found {len(insertion_ids)}/{len(expected_ids)} completions."
)
raise HTTPException(
status_code=404,
detail=f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions.",
detail=f"Incomplete traces for {tracking_label} '{tracking_key}': Found {len(insertion_ids)}/{len(expected_ids)} completions.",
)

except ImportError:
Expand Down Expand Up @@ -431,8 +435,11 @@ async def pointwise_fetch_langfuse_trace(
detail=f"Project ID '{project_id}' not found. Available projects: {list(config.langfuse_keys.keys())}",
)

# Extract rollout_id from tags for Redis lookup
# Extract tracking key (rollout_id or trail_id) from tags for Redis lookup
rollout_id = _extract_tag_value(tags, "rollout_id:")
trail_id = _extract_tag_value(tags, "trail_id:")
tracking_key = trail_id if trail_id else rollout_id
tracking_label = "trail_id" if trail_id else "rollout_id"

try:
# Import the Langfuse adapter
Expand Down Expand Up @@ -461,23 +468,23 @@ async def pointwise_fetch_langfuse_trace(

# Get insertion_ids from Redis to find the latest one
expected_ids: Set[str] = set()
if rollout_id:
expected_ids = get_insertion_ids(redis_client, rollout_id)
if tracking_key:
expected_ids = get_insertion_ids(redis_client, tracking_key)
logger.info(
f"Pointwise fetch for rollout_id '{rollout_id}', found {len(expected_ids)} insertion_ids in Redis"
f"Pointwise fetch for {tracking_label} '{tracking_key}', found {len(expected_ids)} insertion_ids in Redis"
)
if not expected_ids:
logger.warning(
f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace."
f"No insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Cannot determine latest trace."
)
raise HTTPException(
status_code=500,
detail=f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace.",
detail=f"No insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Cannot determine latest trace.",
)

# Get the latest (last) insertion_id since UUID v7 is time-ordered
latest_insertion_id = max(expected_ids) # UUID v7 max = newest
logger.info(f"Targeting latest insertion_id: {latest_insertion_id} for rollout '{rollout_id}'")
logger.info(f"Targeting latest insertion_id: {latest_insertion_id} for {tracking_label} '{tracking_key}'")

for retry in range(max_retries):
# Fetch trace list targeting the latest insertion_id
Expand Down Expand Up @@ -513,7 +520,7 @@ async def pointwise_fetch_langfuse_trace(
if trace_full:
trace_dict = _serialize_trace_to_dict(trace_full)
logger.info(
f"Successfully fetched latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id}"
f"Successfully fetched latest trace for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id}"
)
return LangfuseTracesResponse(
project_id=project_id,
Expand All @@ -525,17 +532,17 @@ async def pointwise_fetch_langfuse_trace(
if retry < max_retries - 1:
wait_time = 2**retry
logger.info(
f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for rollout '{rollout_id}', insertion_id: {latest_insertion_id}. Retrying in {wait_time}s..."
f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)

# After all retries failed
logger.error(
f"Failed to fetch latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id} after {max_retries} retries"
f"Failed to fetch latest trace for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id} after {max_retries} retries"
)
raise HTTPException(
status_code=404,
detail=f"Failed to fetch latest trace for rollout '{rollout_id}' after {max_retries} retries",
detail=f"Failed to fetch latest trace for {tracking_label} '{tracking_key}' after {max_retries} retries",
)

except ImportError:
Expand Down
31 changes: 28 additions & 3 deletions eval_protocol/proxy/proxy_core/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def handle_chat_completion(
data, params = config.preprocess_chat_request(data, request, params)

project_id = params.project_id
trail_id = params.trail_id
rollout_id = params.rollout_id
invocation_id = params.invocation_id
experiment_id = params.experiment_id
Expand Down Expand Up @@ -70,8 +71,31 @@ async def handle_chat_completion(

# If metadata IDs are provided, add them as tags
insertion_id = None
if rollout_id is not None:
tracking_key = None # Key for Redis tracking (trail_id or rollout_id)

if trail_id is not None:
# Trail Management System: Simple tagging with just trail_id
insertion_id = str(uuid7())
tracking_key = trail_id

if "metadata" not in data:
data["metadata"] = {}
if "tags" not in data["metadata"]:
data["metadata"]["tags"] = []

# Add trail metadata as tags
data["metadata"]["tags"].extend(
[
f"trail_id:{trail_id}",
f"insertion_id:{insertion_id}",
]
)
logger.debug(f"Trail request: trail_id={trail_id}, insertion_id={insertion_id}")

elif rollout_id is not None:
# Legacy evaluation system: Complex tagging with multiple IDs
insertion_id = str(uuid7())
tracking_key = rollout_id

if "metadata" not in data:
data["metadata"] = {}
Expand All @@ -89,6 +113,7 @@ async def handle_chat_completion(
f"row_id:{row_id}",
]
)
logger.debug(f"Rollout request: rollout_id={rollout_id}, insertion_id={insertion_id}")

# Add Langfuse configuration
data["langfuse_public_key"] = config.langfuse_keys[project_id]["public_key"]
Expand All @@ -115,8 +140,8 @@ async def handle_chat_completion(
)

# Register insertion_id in Redis only on successful response
if response.status_code == 200 and insertion_id is not None and rollout_id is not None:
register_insertion_id(redis_client, rollout_id, insertion_id)
if response.status_code == 200 and insertion_id is not None and tracking_key is not None:
register_insertion_id(redis_client, tracking_key, insertion_id)

# Return the response
return Response(
Expand Down
3 changes: 3 additions & 0 deletions eval_protocol/proxy/proxy_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class ChatParams(BaseModel):
"""Typed container for chat completion URL path parameters."""

project_id: Optional[str] = None
# Trail Management System (simpler path)
trail_id: Optional[str] = None
# Legacy evaluation system (complex path)
rollout_id: Optional[str] = None
invocation_id: Optional[str] = None
experiment_id: Optional[str] = None
Expand Down
Loading
Loading