From 12fcae3b891978b88feb657b12430c8df86f36aa Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Tue, 25 Nov 2025 14:53:00 -0800 Subject: [PATCH] support for tokenids logprobs --- .../default_single_turn_rollout_process.py | 47 ++++++++++++++++ tests/pytest/test_pytest_input_messages.py | 56 +++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index df40f01e..7633600a 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -170,6 +170,53 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: ) ) + # Best-effort, synchronous extraction of token_ids, routing_matrix, and the logprobs object from the first choice of the provider response. Items may be dicts or attribute-style objects. Any exception raised here is logged as a warning rather than propagated. + try: + token_ids = [] + routing_matrix = [] + logprobs_obj = getattr(response.choices[0], "logprobs", None) + + if logprobs_obj is not None: + if isinstance(logprobs_obj, dict): + content = logprobs_obj.get("content", []) + else: + content = getattr(logprobs_obj, "content", []) + + if isinstance(content, list): + for item in content: + if isinstance(item, dict): + tid = item.get("token_id") + rm = item.get("routing_matrix") + else: + tid = getattr(item, "token_id", None) + rm = getattr(item, "routing_matrix", None) + + if tid is not None: + token_ids.append(tid) + if rm is not None: + routing_matrix.append(rm) + + logger.info( + "[SingleTurnRolloutProcessor] Extracted %d token_ids and %d routing_matrix entries from logprobs", + len(token_ids), + len(routing_matrix), + ) + + # Store as flat 1D lists in execution_metadata.extra for SingleTurn (no step dimension needed). Empty lists are not stored, and the logprobs object is stored exactly as received from the response. + if token_ids or routing_matrix or logprobs_obj is not None: + if not row.execution_metadata.extra: + row.execution_metadata.extra = {} + if token_ids: + row.execution_metadata.extra["token_ids"] = token_ids + if routing_matrix: + row.execution_metadata.extra["routing_matrix"] = routing_matrix + if logprobs_obj is 
not None: + row.execution_metadata.extra["logprobs"] = logprobs_obj + except Exception as e: + logger.warning( + "[SingleTurnRolloutProcessor] Failed to extract token_ids/routing_matrix/logprobs: %s", e + ) + row.messages = messages row.execution_metadata.duration_seconds = time.perf_counter() - start_time diff --git a/tests/pytest/test_pytest_input_messages.py b/tests/pytest/test_pytest_input_messages.py index f545f0f5..7773ae8f 100644 --- a/tests/pytest/test_pytest_input_messages.py +++ b/tests/pytest/test_pytest_input_messages.py @@ -22,3 +22,59 @@ def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[Evaluati for row in rows: row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result") return rows + + +@pytest.mark.parametrize( + "completion_params", + [ + { + "model": "fireworks_ai/accounts/fireworks/models/qwen3-30b-a3b", + "logprobs": True, + # "include_routing_matrix": True, # Requires --enable-moe-stats on server + "temperature": 0.6, + "max_tokens": 256, + } + ], +) +@evaluation_test( + input_messages=[ + [ + [ + Message(role="user", content="What is 2+2?"), + ] + ] + ], + rollout_processor=SingleTurnRolloutProcessor(), + mode="all", +) +def test_single_turn_with_logprobs_and_routing_matrix(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """Smoke-test SingleTurnRolloutProcessor with logprobs enabled. Prints whether token_ids, routing_matrix, and logprobs are present in execution_metadata.extra; makes no assertions and always assigns score 1.0.""" + for row in rows: + # Report (print only, no assertion) which extracted fields are present in execution_metadata.extra + extra = row.execution_metadata.extra + print("\n=== DEBUG: execution_metadata.extra ===") + print(f"extra type: {type(extra)}") + print(f"extra keys: {extra.keys() if isinstance(extra, dict) else 'N/A'}") + + if isinstance(extra, dict): + if "token_ids" in extra: + token_ids = extra["token_ids"] + print(f"token_ids: found, len={len(token_ids)}, first 10 ids={token_ids[:10]}") + else: + print("token_ids: NOT FOUND") + + if "routing_matrix" in extra: + routing_matrix = extra["routing_matrix"] + print(f"routing_matrix: found, 
len={len(routing_matrix)}") + else: + print("routing_matrix: NOT FOUND") + + if "logprobs" in extra: + print("logprobs: found") + else: + print("logprobs: NOT FOUND") + + print("=" * 50) + + row.evaluation_result = EvaluateResult(score=1.0, reason="Test passed") + return rows