From 9f6aa7b61dea90fdba1ba164115fd51ed587841a Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 13:28:34 -0800 Subject: [PATCH 01/13] added model quality gha --- .github/workflows/streaming_compliance.yml | 117 ++++++ .../test_glm_streaming_compliance.py | 341 ++++++++++++++++++ eval_protocol/models.py | 10 + .../default_single_turn_rollout_process.py | 7 + 4 files changed, 475 insertions(+) create mode 100644 .github/workflows/streaming_compliance.yml create mode 100644 eval_protocol/benchmarks/test_glm_streaming_compliance.py diff --git a/.github/workflows/streaming_compliance.yml b/.github/workflows/streaming_compliance.yml new file mode 100644 index 00000000..c317dcdc --- /dev/null +++ b/.github/workflows/streaming_compliance.yml @@ -0,0 +1,117 @@ +name: Streaming Compliance Benchmark + +on: + workflow_dispatch: + inputs: + model: + description: "Model id" + required: true + default: "fireworks_ai/accounts/fireworks/models/glm-4p6" + max_tokens: + description: "Override max_tokens (integer)" + required: false + default: "" + reasoning_effort: + description: "Reasoning effort (low|medium|high|none)" + required: false + default: "" + max_rows: + description: "Max rows for smoke vs full run (integer or 'all')" + required: false + default: "" + temperature: + description: "Temperature (float)" + required: false + default: "" + stream: + description: "Enable streaming (true or empty)" + required: false + default: "true" + max_concurrency: + description: "Max concurrency (integer)" + required: false + default: "" + num_runs: + description: "Number of runs (integer)" + required: false + default: "" + max_retry: + description: "Max retry (integer)" + required: false + default: "" + success_threshold: + description: "Minimum test score needed to pass (float)" + required: false + default: "" + +jobs: + streaming-compliance: + runs-on: 8-core-32gb-ubuntu + timeout-minutes: 180 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Setup uv and .venv + run: | + python -m pip install --upgrade pip + pip install uv + uv venv + . .venv/bin/activate + uv pip install --upgrade pip + + - name: Install python-sdk package + run: | + . .venv/bin/activate + uv pip install . + + - name: Run streaming compliance benchmark (pytest) + env: + FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} + FIREWORKS_ACCOUNT_ID: ${{ vars.FIREWORKS_ACCOUNT_ID }} + run: | + . 
.venv/bin/activate + mkdir -p artifacts + + MODEL="${{ github.event.inputs.model }}" + MAX_TOKENS="${{ github.event.inputs.max_tokens }}" + REASONING="${{ github.event.inputs.reasoning_effort }}" + MAX_ROWS="${{ github.event.inputs.max_rows }}" + TEMPERATURE="${{ github.event.inputs.temperature }}" + STREAM="${{ github.event.inputs.stream }}" + NUM_RUNS="${{ github.event.inputs.num_runs }}" + MAX_CONC="${{ github.event.inputs.max_concurrency }}" + MAX_RETRY="${{ github.event.inputs.max_retry }}" + SUCCESS_THRESHOLD="${{ github.event.inputs.success_threshold }}" + + echo "Running streaming compliance with reasoning_effort=${REASONING:-} max_rows=${MAX_ROWS:-} model=${MODEL:-} max_tokens=${MAX_TOKENS:-} temperature=${TEMPERATURE:-} stream=${STREAM:-} num_runs=${NUM_RUNS:-} max_concurrency=${MAX_CONC:-} max_retry=${MAX_RETRY:-} success_threshold=${SUCCESS_THRESHOLD:-}" + + PYTEST_TARGET=eval_protocol.benchmarks.test_glm_streaming_compliance + PYTEST_ARGS="--pyargs $PYTEST_TARGET -q -s --ep-print-summary --ep-summary-json artifacts/streaming_compliance.json" + [ -n "$MAX_ROWS" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-max-rows=$MAX_ROWS" + [ -n "$REASONING" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-reasoning-effort=$REASONING" + [ -n "$MODEL" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-input-param model=$MODEL" + [ -n "$MAX_TOKENS" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-input-param max_tokens=$MAX_TOKENS" + [ -n "$TEMPERATURE" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-input-param temperature=$TEMPERATURE" + [ -n "$STREAM" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-input-param stream=$STREAM" + [ -n "$NUM_RUNS" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-num-runs=$NUM_RUNS" + [ -n "$MAX_CONC" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-max-concurrent-rollouts=$MAX_CONC" + [ -n "$MAX_RETRY" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-max-retry=$MAX_RETRY" + [ -n "$SUCCESS_THRESHOLD" ] && PYTEST_ARGS="$PYTEST_ARGS --ep-success-threshold=$SUCCESS_THRESHOLD" + echo "Running: pytest $PYTEST_ARGS" + pytest $PYTEST_ARGS + + - name: Upload JSON artifact(s) + if: always() + uses: actions/upload-artifact@v4 + with: + name: streaming_compliance_json + path: artifacts/*.json + if-no-files-found: warn + retention-days: 14 diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py new file mode 100644 index 00000000..e11f6053 --- /dev/null +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -0,0 +1,341 @@ +"""Benchmarks for GLM streaming regressions (structured output + tool calls).""" + +from __future__ import annotations + +import json +from typing import Any + +from eval_protocol.models import ( + EvaluateResult, + EvaluationRow, + Message, + MetricResult, + ChatCompletionContentPartTextParam, +) +from eval_protocol.pytest.default_single_turn_rollout_process import ( + SingleTurnRolloutProcessor, +) +from eval_protocol.pytest.evaluation_test import evaluation_test + + +DEFAULT_MODEL_ID = "fireworks_ai/accounts/fireworks/models/glm-4p6" +DEFAULT_MAX_TOKENS = 1024 + + +def _coerce_content_to_str( + content: str | list[ChatCompletionContentPartTextParam] | None, +) -> str: + if isinstance(content, list): + texts: list[str] = [] + for part in content: + text_val = getattr(part, "text", None) + if text_val: + texts.append(text_val) + return "".join(texts) + if content is None: + return "" + return str(content) + + +def _safe_json_loads(payload: str) -> Any | None: + if not payload: + return None + try: + return json.loads(payload) + except Exception: + return None + + 
+STRUCTURED_SYSTEM_PROMPT = "You are a weather assistant. Respond with a JSON object matching the provided schema." + +STRUCTURED_RESPONSE_FORMAT = { + "type": "json_object", + "schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City or location name", + "enum": ["London", "New York"], + }, + "temperature": { + "type": "number", + "description": "Temperature in Celsius", + }, + "conditions": { + "type": "string", + "description": "Weather conditions description", + }, + }, + "required": ["location", "temperature", "conditions"], + }, +} + +STRUCTURED_OUTPUT_ROW = EvaluationRow( + messages=[ + Message(role="system", content=STRUCTURED_SYSTEM_PROMPT), + Message(role="user", content="What is the weather like in London?"), + ] +) +STRUCTURED_OUTPUT_ROW.input_metadata.dataset_info = { + "case": "glm-structured-output-streaming", +} + + +TOOL_SYSTEM_PROMPT = ( + "You are a weather assistant. If tools are available, always call them to gather data before responding." +) + +WEATHER_TOOL_DEFINITION = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "additionalProperties": False, + "required": ["location", "unit"], + }, + }, + } +] + +TOOL_CALL_ROW = EvaluationRow( + messages=[ + Message(role="system", content=TOOL_SYSTEM_PROMPT), + Message(role="user", content="What is the weather like in Boston in fahrenheit?"), + ], + tools=WEATHER_TOOL_DEFINITION, +) +TOOL_CALL_ROW.input_metadata.dataset_info = { + "case": "glm-tool-call-streaming", +} + + +@evaluation_test( + input_rows=[[STRUCTURED_OUTPUT_ROW]], + completion_params=[ + { + "model": DEFAULT_MODEL_ID, + "stream": True, + "temperature": 1.0, + "top_p": 1.0, + "max_tokens": DEFAULT_MAX_TOKENS, + "response_format": STRUCTURED_RESPONSE_FORMAT, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=1.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: + """Ensure structured output arrives in assistant content when streaming.""" + + assistant_msg = row.last_assistant_message() + if assistant_msg is None: + row.evaluation_result = EvaluateResult( + score=0.0, + is_score_valid=False, + reason="No assistant message produced", + metrics={}, + ) + return row + + content_str = _coerce_content_to_str(assistant_msg.content) + reasoning_str = assistant_msg.reasoning_content or "" + parsed_content = _safe_json_loads(content_str) + parsed_reasoning = _safe_json_loads(reasoning_str) if reasoning_str else None + + required_fields = {"location", "temperature", "conditions"} + + content_is_json = parsed_content is not None + required_keys_present = content_is_json and required_fields <= set(parsed_content.keys()) + temperature_is_number = content_is_json and isinstance(parsed_content.get("temperature"), (int, float)) + location_valid = content_is_json and parsed_content.get("location") in {"London", "New York"} + reasoning_contains_payload = parsed_reasoning is not None + finish_reason = row.execution_metadata.finish_reason + finish_reason_expected = finish_reason == "stop" + + all_checks_passed = ( + content_str.strip() + and content_is_json + and 
required_keys_present + and temperature_is_number + and location_valid + and not reasoning_contains_payload + and finish_reason_expected + ) + + metrics = { + "content_is_json": MetricResult( + score=1.0 if content_is_json else 0.0, + is_score_valid=True, + reason="Assistant content parsed as JSON" if content_is_json else "Failed to parse JSON", + data={"content": content_str}, + ), + "required_keys_present": MetricResult( + score=1.0 if required_keys_present else 0.0, + is_score_valid=content_is_json, + reason=("All required keys present" if required_keys_present else "Missing required keys"), + data={"parsed_content": parsed_content}, + ), + "temperature_is_number": MetricResult( + score=1.0 if temperature_is_number else 0.0, + is_score_valid=content_is_json, + reason="Temperature is numeric" if temperature_is_number else "Temperature not numeric", + data={"temperature": parsed_content.get("temperature") if parsed_content else None}, + ), + "reasoning_contains_payload": MetricResult( + score=0.0 if reasoning_contains_payload else 1.0, + is_score_valid=True, + reason="Reasoning is empty" if not reasoning_contains_payload else "Payload leaked to reasoning", + data={"reasoning": reasoning_str}, + ), + "finish_reason_stop": MetricResult( + score=1.0 if finish_reason_expected else 0.0, + is_score_valid=True, + reason=( + "finish_reason is stop" if finish_reason_expected else f"Unexpected finish_reason: {finish_reason}" + ), + data={"finish_reason": finish_reason}, + ), + } + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason=( + "Structured output returned in assistant content" + if all_checks_passed + else "Structured output missing or malformed" + ), + metrics=metrics, + ) + return row + + +@evaluation_test( + input_rows=[[TOOL_CALL_ROW]], + completion_params=[ + { + "model": DEFAULT_MODEL_ID, + "stream": True, + "temperature": 1.0, + "top_p": 1.0, + "max_tokens": DEFAULT_MAX_TOKENS, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=1.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: + """Ensure streaming tool calls settle with finish_reason=tool_calls and a single call.""" + + assistant_msg = row.last_assistant_message() + if assistant_msg is None: + row.evaluation_result = EvaluateResult( + score=0.0, + is_score_valid=False, + reason="No assistant message produced", + metrics={}, + ) + return row + + tool_calls = assistant_msg.tool_calls or [] + finish_reason = row.execution_metadata.finish_reason + tool_call_count = row.execution_metadata.tool_call_count + + has_tool_call = len(tool_calls) > 0 + exactly_one_tool_call = len(tool_calls) == 1 + finish_reason_tool_calls = finish_reason == "tool_calls" + tool_call_count_matches = tool_call_count == len(tool_calls) + + tool_call_arguments_valid = False + if exactly_one_tool_call: + tool_call = tool_calls[0] + function_block = tool_call.get("function", {}) if isinstance(tool_call, dict) else {} + arguments_payload = function_block.get("arguments") + parsed_arguments = _safe_json_loads(arguments_payload) if isinstance(arguments_payload, str) else None + tool_call_arguments_valid = ( + isinstance(parsed_arguments, dict) + and parsed_arguments.get("location", "").lower() == "boston" + and parsed_arguments.get("unit") == "fahrenheit" + ) + else: + parsed_arguments = None + + all_checks_passed = ( + has_tool_call + and exactly_one_tool_call + and 
finish_reason_tool_calls + and tool_call_arguments_valid + and tool_call_count_matches + ) + + metrics = { + "has_tool_call": MetricResult( + score=1.0 if has_tool_call else 0.0, + is_score_valid=True, + reason="Assistant produced at least one tool call" if has_tool_call else "No tool calls returned", + data={"tool_call_count": len(tool_calls)}, + ), + "single_tool_call": MetricResult( + score=1.0 if exactly_one_tool_call else 0.0, + is_score_valid=has_tool_call, + reason=("Exactly one tool call" if exactly_one_tool_call else "Unexpected number of tool calls"), + data={"tool_calls": tool_calls}, + ), + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason_tool_calls else 0.0, + is_score_valid=True, + reason=( + "finish_reason is tool_calls" + if finish_reason_tool_calls + else f"Unexpected finish_reason: {finish_reason}" + ), + data={"finish_reason": finish_reason}, + ), + "tool_call_arguments_valid": MetricResult( + score=1.0 if tool_call_arguments_valid else 0.0, + is_score_valid=exactly_one_tool_call, + reason=("Tool call arguments valid" if tool_call_arguments_valid else "Tool call arguments invalid"), + data={"arguments": parsed_arguments}, + ), + "tool_call_count_matches": MetricResult( + score=1.0 if tool_call_count_matches else 0.0, + is_score_valid=True, + reason=( + "tool_call_count matches returned calls" + if tool_call_count_matches + else f"tool_call_count mismatch (metadata={tool_call_count}, actual={len(tool_calls)})" + ), + data={"metadata_tool_call_count": tool_call_count}, + ), + } + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason=( + "Streaming tool call completed correctly" if all_checks_passed else "Streaming tool call behaviour invalid" + ), + metrics=metrics, + ) + return row diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 7180ed72..aacab1e8 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -776,6 +776,16 @@ class ExecutionMetadata(BaseModel): description="Processing duration in seconds for an entire experiment. 
Note that includes time it took for retries.", ) + finish_reason: Optional[str] = Field( + default=None, + description="finish_reason reported by the completion response for this row.", + ) + + tool_call_count: Optional[int] = Field( + default=None, + description="Number of tool calls returned in the assistant message for this row.", + ) + class EvaluationRow(BaseModel): """ diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 27fe2559..9b582d15 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -96,6 +96,8 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: assert isinstance(response, ModelResponse), "Response should be ModelResponse" assert isinstance(response.choices[0], Choices), "Response choice should be a Choices" + finish_reason = getattr(response.choices[0], "finish_reason", None) + assistant_content = response.choices[0].message.content or "" tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None @@ -137,6 +139,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: tool_calls=converted_tool_calls, ) ] + + row.execution_metadata.finish_reason = str(finish_reason) if finish_reason is not None else None + row.execution_metadata.tool_call_count = ( + len(converted_tool_calls) if converted_tool_calls is not None else 0 + ) row.execution_metadata.usage = ( CompletionUsage( # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field prompt_tokens=response.usage.prompt_tokens, # pyright: ignore[reportAttributeAccessIssue] From 7d6d9050a7c8991d0f22b9053589b8998182fe78 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 15:30:02 -0800 Subject: [PATCH 02/13] fixes --- .github/workflows/streaming_compliance.yml | 1 + .../test_glm_streaming_compliance.py | 24 +++++++++++++++++-- .../default_single_turn_rollout_process.py | 9 +++++-- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/.github/workflows/streaming_compliance.yml b/.github/workflows/streaming_compliance.yml index c317dcdc..96beb2b8 100644 --- a/.github/workflows/streaming_compliance.yml +++ b/.github/workflows/streaming_compliance.yml @@ -1,6 +1,7 @@ name: Streaming Compliance Benchmark on: + push: workflow_dispatch: inputs: model: diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index e11f6053..01a69845 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -1,7 +1,5 @@ """Benchmarks for GLM streaming regressions (structured output + tool calls).""" -from __future__ import annotations - import json from typing import Any @@ -145,6 +143,17 @@ def _safe_json_loads(payload: str) -> Any | None: def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: """Ensure structured output arrives in assistant content when streaming.""" + import os + + print( + "DEBUG completion params", + os.environ.get("EP_COMPLETION_PARAMS_JSON"), + "key len", + len(os.environ.get("FIREWORKS_API_KEY", "")), + "acct", + os.environ.get("FIREWORKS_ACCOUNT_ID"), + ) + assistant_msg = row.last_assistant_message() if assistant_msg is None: row.evaluation_result = EvaluateResult( @@ -248,6 +257,17 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: def 
test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: """Ensure streaming tool calls settle with finish_reason=tool_calls and a single call.""" + import os + + print( + "DEBUG completion params", + os.environ.get("EP_COMPLETION_PARAMS_JSON"), + "key len", + len(os.environ.get("FIREWORKS_API_KEY", "")), + "acct", + os.environ.get("FIREWORKS_ACCOUNT_ID"), + ) + assistant_msg = row.last_assistant_message() if assistant_msg is None: row.evaluation_result = EvaluateResult( diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 9b582d15..c99fae6c 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -98,8 +98,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: finish_reason = getattr(response.choices[0], "finish_reason", None) - assistant_content = response.choices[0].message.content or "" - tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None + assistant_message = response.choices[0].message + assistant_content = assistant_message.content or "" + reasoning_content = getattr(assistant_message, "reasoning_content", None) + if reasoning_content is None: + reasoning_content = getattr(assistant_message, "reasoning", None) + tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None converted_tool_calls = None if tool_calls: @@ -136,6 +140,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: Message( role="assistant", content=assistant_content, + reasoning_content=reasoning_content, tool_calls=converted_tool_calls, ) ] From c6995324ae04b65f08ebfa6ad1d7708a41f451cc Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 15:36:06 -0800 Subject: [PATCH 03/13] streaming --- .github/workflows/streaming_compliance.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/streaming_compliance.yml b/.github/workflows/streaming_compliance.yml index 96beb2b8..5b5723fb 100644 --- a/.github/workflows/streaming_compliance.yml +++ b/.github/workflows/streaming_compliance.yml @@ -47,7 +47,7 @@ on: jobs: streaming-compliance: - runs-on: 8-core-32gb-ubuntu + runs-on: ubuntu-latest timeout-minutes: 180 steps: From bb743b782bebf1d706fa9d28a491de8baaef73b5 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 15:43:46 -0800 Subject: [PATCH 04/13] fixes --- .github/workflows/streaming_compliance.yml | 1 + .../test_glm_streaming_compliance.py | 22 ------------------- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/.github/workflows/streaming_compliance.yml b/.github/workflows/streaming_compliance.yml index 5b5723fb..3f5c2549 100644 --- a/.github/workflows/streaming_compliance.yml +++ b/.github/workflows/streaming_compliance.yml @@ -76,6 +76,7 @@ jobs: env: FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} FIREWORKS_ACCOUNT_ID: ${{ vars.FIREWORKS_ACCOUNT_ID }} + EP_DISABLE_DATASET_LOGGER: "1" run: | . 
.venv/bin/activate mkdir -p artifacts diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index 01a69845..07863efc 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -143,17 +143,6 @@ def _safe_json_loads(payload: str) -> Any | None: def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: """Ensure structured output arrives in assistant content when streaming.""" - import os - - print( - "DEBUG completion params", - os.environ.get("EP_COMPLETION_PARAMS_JSON"), - "key len", - len(os.environ.get("FIREWORKS_API_KEY", "")), - "acct", - os.environ.get("FIREWORKS_ACCOUNT_ID"), - ) - assistant_msg = row.last_assistant_message() if assistant_msg is None: row.evaluation_result = EvaluateResult( @@ -257,17 +246,6 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: """Ensure streaming tool calls settle with finish_reason=tool_calls and a single call.""" - import os - - print( - "DEBUG completion params", - os.environ.get("EP_COMPLETION_PARAMS_JSON"), - "key len", - len(os.environ.get("FIREWORKS_API_KEY", "")), - "acct", - os.environ.get("FIREWORKS_ACCOUNT_ID"), - ) - assistant_msg = row.last_assistant_message() if assistant_msg is None: row.evaluation_result = EvaluateResult( From 275f9924d8890607676eb361dc1753158f5d3626 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 15:54:12 -0800 Subject: [PATCH 05/13] fix --- .github/workflows/streaming_compliance.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/streaming_compliance.yml b/.github/workflows/streaming_compliance.yml index 3f5c2549..797ccf70 100644 --- a/.github/workflows/streaming_compliance.yml +++ b/.github/workflows/streaming_compliance.yml @@ -76,7 +76,7 @@ jobs: env: FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} FIREWORKS_ACCOUNT_ID: ${{ vars.FIREWORKS_ACCOUNT_ID }} - EP_DISABLE_DATASET_LOGGER: "1" + DISABLE_EP_SQLITE_LOG: "1" run: | . .venv/bin/activate mkdir -p artifacts From 4315f1ef0b1dad4156b5e8a8e6358cdd7389bc37 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 16:01:37 -0800 Subject: [PATCH 06/13] fix --- .github/workflows/streaming_compliance.yml | 1 - eval_protocol/benchmarks/test_glm_streaming_compliance.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/streaming_compliance.yml b/.github/workflows/streaming_compliance.yml index 797ccf70..5b5723fb 100644 --- a/.github/workflows/streaming_compliance.yml +++ b/.github/workflows/streaming_compliance.yml @@ -76,7 +76,6 @@ jobs: env: FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} FIREWORKS_ACCOUNT_ID: ${{ vars.FIREWORKS_ACCOUNT_ID }} - DISABLE_EP_SQLITE_LOG: "1" run: | . 
.venv/bin/activate mkdir -p artifacts diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index 07863efc..7f8d321e 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -239,7 +239,7 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: ], rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", - passed_threshold=1.0, + passed_threshold=0.0, num_runs=1, mode="pointwise", ) From fcc7e10ec8a07df90f077c26d7ac66803f2af711 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 16:04:20 -0800 Subject: [PATCH 07/13] fix --- .github/workflows/streaming_compliance.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/streaming_compliance.yml b/.github/workflows/streaming_compliance.yml index 5b5723fb..797ccf70 100644 --- a/.github/workflows/streaming_compliance.yml +++ b/.github/workflows/streaming_compliance.yml @@ -76,6 +76,7 @@ jobs: env: FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} FIREWORKS_ACCOUNT_ID: ${{ vars.FIREWORKS_ACCOUNT_ID }} + DISABLE_EP_SQLITE_LOG: "1" run: | . .venv/bin/activate mkdir -p artifacts From a1a2046b3f8241aa24565d71930354f9becc531d Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 16:09:53 -0800 Subject: [PATCH 08/13] fix --- eval_protocol/benchmarks/test_glm_streaming_compliance.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index 7f8d321e..5a9dfec9 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -336,4 +336,9 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: ), metrics=metrics, ) + try: + row.model_dump(exclude_none=True, mode="json") + except Exception as exc: # pragma: no cover - debug helper + print("DEBUG model_dump failure", exc, row.messages) + return row From cf0ab9d98da1d2b9a4bbfb06088240ce68a637d9 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 16:20:54 -0800 Subject: [PATCH 09/13] fix --- .github/workflows/streaming_compliance.yml | 1 - eval_protocol/benchmarks/test_glm_streaming_compliance.py | 5 ----- eval_protocol/pytest/default_single_turn_rollout_process.py | 6 ++++++ 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/streaming_compliance.yml b/.github/workflows/streaming_compliance.yml index 797ccf70..5b5723fb 100644 --- a/.github/workflows/streaming_compliance.yml +++ b/.github/workflows/streaming_compliance.yml @@ -76,7 +76,6 @@ jobs: env: FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} FIREWORKS_ACCOUNT_ID: ${{ vars.FIREWORKS_ACCOUNT_ID }} - DISABLE_EP_SQLITE_LOG: "1" run: | . 
.venv/bin/activate mkdir -p artifacts diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index 5a9dfec9..7f8d321e 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -336,9 +336,4 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: ), metrics=metrics, ) - try: - row.model_dump(exclude_none=True, mode="json") - except Exception as exc: # pragma: no cover - debug helper - print("DEBUG model_dump failure", exc, row.messages) - return row diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index c99fae6c..ad78046d 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -1,4 +1,5 @@ import asyncio +import json import logging import os import time @@ -103,6 +104,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: reasoning_content = getattr(assistant_message, "reasoning_content", None) if reasoning_content is None: reasoning_content = getattr(assistant_message, "reasoning", None) + if reasoning_content is not None and not isinstance(reasoning_content, str): + try: + reasoning_content = json.dumps(reasoning_content) + except Exception: + reasoning_content = str(reasoning_content) tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None converted_tool_calls = None From 67c26198741d61ae09e9508ee9ec865c53cbc8f3 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Thu, 6 Nov 2025 16:27:37 -0800 Subject: [PATCH 10/13] df --- .../benchmarks/test_glm_streaming_compliance.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index 7f8d321e..6fdc94b0 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -257,6 +257,17 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: return row tool_calls = assistant_msg.tool_calls or [] + tool_calls_for_metrics: list[Any] = [] + for tc in tool_calls: + if hasattr(tc, "model_dump"): + try: + tool_calls_for_metrics.append(tc.model_dump(exclude_none=True)) + except Exception: + tool_calls_for_metrics.append(str(tc)) + elif isinstance(tc, dict): + tool_calls_for_metrics.append(tc) + else: + tool_calls_for_metrics.append(str(tc)) finish_reason = row.execution_metadata.finish_reason tool_call_count = row.execution_metadata.tool_call_count @@ -298,7 +309,7 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: score=1.0 if exactly_one_tool_call else 0.0, is_score_valid=has_tool_call, reason=("Exactly one tool call" if exactly_one_tool_call else "Unexpected number of tool calls"), - data={"tool_calls": tool_calls}, + data={"tool_calls": tool_calls_for_metrics}, ), "finish_reason_tool_calls": MetricResult( score=1.0 if finish_reason_tool_calls else 0.0, From 2900b87e6e3f8e99f43a01e259f376900404c1cb Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Mon, 17 Nov 2025 13:05:56 -0800 Subject: [PATCH 11/13] streaming ouput --- .../test_glm_streaming_compliance.py | 1363 ++++++++++++++++- 1 file changed, 1346 insertions(+), 17 deletions(-) diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py 
b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index 6fdc94b0..d71b066b 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -1,6 +1,8 @@ """Benchmarks for GLM streaming regressions (structured output + tool calls).""" import json +import os +import re from typing import Any from eval_protocol.models import ( @@ -121,6 +123,543 @@ def _safe_json_loads(payload: str) -> Any | None: "case": "glm-tool-call-streaming", } +PEER_SIMPLE_STREAM_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": "Write a short explanation of how binary search works. Keep it under 200 words.", + } + ], + "temperature": 0.3, + "top_p": 1, + "max_tokens": 400, +} + +PEER_JSON_STREAM_PAYLOAD = { + "messages": [{"role": "user", "content": "What is the weather like in London?"}], + "max_tokens": 25344, + "temperature": 1, + "top_p": 1, + "response_format": { + "type": "json_object", + "schema": STRUCTURED_RESPONSE_FORMAT["schema"], + }, +} + +PEER_TOOL_BRACE_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": "Call test_brace_bug with param1='test_value', param2=42, and param3=true", + } + ], + "tools": WEATHER_TOOL_DEFINITION, + "temperature": 0.1, + "top_p": 1, +} + +PEER_TOOL_MULTI_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": "Process data {'users': [{'name': 'John', 'age': 30}], 'total': 1} with count 5 and enabled true", + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "process_data", + "description": "Process complex data structures", + "parameters": { + "type": "object", + "properties": { + "data": {"type": "object"}, + "filters": {"type": "array"}, + "count": {"type": "integer"}, + "enabled": {"type": "boolean"}, + }, + "required": ["data"], + }, + }, + } + ], + "stream": True, + "temperature": 0.1, +} + +PEER_KIMI_MULTI_TOOL_CALLS_PAYLOAD = { + "messages": [ + { + "role": "system", + "content": ( + "You are Kimi. When the task requires multiple tools or repeated calls, you MUST emit multiple " + "tool calls in a single assistant turn. Emit ONLY tool calls (no natural language content). " + "Use one <|tool_call_begin|>...<|tool_call_end|> block per call, all wrapped inside a single " + "<|tool_calls_section_begin|>...<|tool_calls_section_end|> section. Ensure arguments are strictly " + "valid JSON and exactly match the provided tool schemas. Do not wait for tool results; just emit " + "the tool calls required by the user request." + ), + }, + { + "role": "user", + "content": ( + "For Boston and San Francisco: get the current weather in Fahrenheit for each city, and also get " + "the air quality for Boston. Emit all tool calls in a single assistant turn and nothing else." + ), + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location", "unit"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_air_quality", + "description": "Get the current air quality in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
Boston, MA", + } + }, + "required": ["location"], + "additionalProperties": False, + }, + }, + }, + ], + "tool_choice": "required", + "temperature": 0.2, + "top_p": 1, + "stream": True, +} + +PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": "View the file at /tmp/test.txt with view_range [160, 210]", + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "view", + "description": "View a file or directory", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file or directory to view", + }, + "type": { + "type": "string", + "enum": ["file", "directory"], + "description": "Type of the path (file or directory)", + }, + "view_range": { + "type": "array", + "items": {"type": "integer"}, + "minItems": 2, + "maxItems": 2, + "description": "Line range to view [start, end]", + }, + }, + "required": ["path", "type"], + "additionalProperties": False, + }, + }, + } + ], + "tool_choice": "required", + "temperature": 0.1, + "stream": True, +} + +PEER_TOOL_STRING_INSTEAD_ARRAY_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": "View the file at /tmp/test.txt with view_range [160, 210]", + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "view", + "description": "View a file or directory", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file or directory to view", + }, + "type": { + "type": "string", + "enum": ["file", "directory"], + "description": "Type of the path (file or directory)", + }, + "view_range": { + "type": "array", + "items": {"type": "integer"}, + "minItems": 2, + "maxItems": 2, + "description": "Line range to view [start, end]", + }, + }, + "required": ["path", "type", "view_range"], + "additionalProperties": False, + }, + }, + } + ], + "tool_choice": "required", + "temperature": 0.1, + "stream": True, +} + +PEER_TOOL_NAMING_ERROR_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": "View the file at /tmp/test.txt", + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "view", + "description": "View a file or directory", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file or directory to view", + }, + "type": { + "type": "string", + "enum": ["file", "directory"], + "description": "Type of the path (file or directory)", + }, + }, + "required": ["path", "type"], + "additionalProperties": False, + }, + }, + } + ], + "tool_choice": "required", + "temperature": 0.1, + "stream": True, +} + +PEER_TOOL_COMMAND_EXECUTION_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": "Launch a process to run 'ls -la' in the /tmp directory", + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "launch_process", + "description": "Launch a shell process", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Shell command to execute", + }, + "cwd": { + "type": "string", + "description": "Working directory for the command", + }, + "env": { + "type": "object", + "additionalProperties": {"type": "string"}, + "description": "Environment variables", + }, + }, + "required": ["command", "cwd"], + "additionalProperties": False, + }, + }, + } + ], + "tool_choice": "required", + "temperature": 0.1, + "stream": True, +} + 
+PEER_TOOL_PARAMETER_FORMAT_ERRORS_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": "Process data with count 60, enabled true, and timeout 30", + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "process_data", + "description": "Process data with various parameters", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "description": "Number of items to process", + }, + "enabled": { + "type": "boolean", + "description": "Whether processing is enabled", + }, + "timeout": { + "type": "number", + "description": "Timeout in seconds", + }, + "options": { + "type": "object", + "properties": { + "retry": {"type": "boolean"}, + "max_attempts": {"type": "integer"}, + }, + "description": "Processing options", + }, + }, + "required": ["count", "enabled"], + "additionalProperties": False, + }, + }, + } + ], + "tool_choice": "required", + "temperature": 0.1, + "stream": True, +} + +PEER_TOOL_RECOVERY_FAILURE_PAYLOAD = { + "messages": [ + { + "role": "user", + "content": ( + "View the file at /tmp/test.txt. If that fails, try again with the correct parameters. " + "Keep retrying until it works." + ), + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "view", + "description": "View a file or directory", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file or directory to view", + }, + "type": { + "type": "string", + "enum": ["file", "directory"], + "description": "Type of the path (file or directory)", + }, + }, + "required": ["path", "type"], + "additionalProperties": False, + }, + }, + } + ], + "tool_choice": "required", + "temperature": 0.1, + "max_tokens": 4000, + "stream": True, +} + + +def _build_row_from_payload(case: str, payload: dict[str, Any]) -> EvaluationRow: + messages = [ + Message(role=message["role"], content=message.get("content", "")) for message in payload.get("messages", []) + ] + row = EvaluationRow(messages=messages, tools=payload.get("tools")) + row.input_metadata.dataset_info = {"case": case} + return row + + +def _build_completion_params_from_payload(payload: dict[str, Any]) -> dict[str, Any]: + params: dict[str, Any] = { + "model": DEFAULT_MODEL_ID, + "stream": True, + "return_reasoning_with_separate_field": True, + } + passthrough_keys = {"temperature", "top_p", "max_tokens", "response_format"} + for key in passthrough_keys: + if key in payload: + params[key] = payload[key] + return params + + +def _normalize_tool_call(tc: Any) -> tuple[str | None, dict[str, Any] | None]: + """Convert LiteLLM tool call objects/dicts into (name, arguments dict).""" + + record: dict[str, Any] + if hasattr(tc, "model_dump"): + try: + record = tc.model_dump(exclude_none=True) + except Exception: + return (None, None) + elif isinstance(tc, dict): + record = tc + else: + return (None, None) + + fn = record.get("function") or {} + name = fn.get("name") + args_raw = fn.get("arguments") + if isinstance(args_raw, str): + args = _safe_json_loads(args_raw) + else: + args = args_raw if isinstance(args_raw, dict) else None + return (name, args if isinstance(args, dict) else None) + + +def _collect_tool_calls(tool_calls: list[Any] | None) -> list[tuple[str | None, dict[str, Any] | None]]: + return [_normalize_tool_call(tc) for tc in (tool_calls or [])] + + +XML_TAG_PATTERN = re.compile(r"<\s*/?\s*[A-Za-z][^>]*>") +FORBIDDEN_TAG_KEYWORDS = ("think", "tool_call", "tool_calls", "tool_call_section") 
+DEBUG_RESPONSES = os.getenv("EP_DEBUG_LOG_RESPONSES") == "1" + + +def _unique_preserving(items: list[str]) -> list[str]: + seen: dict[str, None] = {} + for item in items: + if item not in seen: + seen[item] = None + return list(seen.keys()) + + +def _scan_xml_tags(text: str) -> list[str]: + if not text: + return [] + return _unique_preserving([match.group(0) for match in XML_TAG_PATTERN.finditer(text)]) + + +def _scan_forbidden_tags(content: str, reasoning: str) -> tuple[list[str], list[str]]: + combined = "\n".join(part for part in (content, reasoning) if part) + xml_tags = _scan_xml_tags(combined) + forbidden = [tag for tag in xml_tags if any(keyword in tag.lower() for keyword in FORBIDDEN_TAG_KEYWORDS)] + return forbidden, xml_tags + + +def _augment_metrics_with_common_checks( + metrics: dict[str, MetricResult], + finish_reason: Any, + content: str, + reasoning: str, +) -> tuple[bool, bool, bool]: + finish_reason_str = "" + if finish_reason is not None: + finish_reason_str = str(finish_reason).strip() + finish_reason_present = finish_reason_str != "" + + forbidden_tags, xml_tags = _scan_forbidden_tags(content, reasoning) + no_forbidden_tags = not forbidden_tags + no_xml_tags = not xml_tags + + metrics["finish_reason_not_null"] = MetricResult( + score=1.0 if finish_reason_present else 0.0, + is_score_valid=True, + reason="finish_reason present" if finish_reason_present else "finish_reason missing or empty", + data={"finish_reason": finish_reason}, + ) + metrics["no_forbidden_tags"] = MetricResult( + score=1.0 if no_forbidden_tags else 0.0, + is_score_valid=True, + reason="No forbidden tags detected" if no_forbidden_tags else "Forbidden tags detected", + data={"matches": forbidden_tags, "count": len(forbidden_tags)}, + ) + metrics["no_xml_tags"] = MetricResult( + score=1.0 if no_xml_tags else 0.0, + is_score_valid=True, + reason="No XML-like tags detected" if no_xml_tags else "XML-like tags detected", + data={"matches": xml_tags, "count": len(xml_tags)}, + ) + + return finish_reason_present, no_forbidden_tags, no_xml_tags + + +def _debug_log_assistant_message(test_name: str, assistant_message: Message | None, finish_reason: Any) -> None: + if not DEBUG_RESPONSES: + return + print(f"[EP][DEBUG] test={test_name} finish_reason={finish_reason!r}") + if assistant_message is None: + print(" assistant_message: ") + return + content = _coerce_content_to_str(assistant_message.content) + reasoning = (assistant_message.reasoning_content or "").strip() if assistant_message.reasoning_content else "" + tool_calls = assistant_message.tool_calls or [] + print(f" content: {content[:400]!r}") + print(f" reasoning: {reasoning[:400]!r}") + if tool_calls: + try: + serialized = [] + for tc in tool_calls: + if hasattr(tc, "model_dump"): + serialized.append(tc.model_dump(exclude_none=True)) + elif isinstance(tc, dict): + serialized.append(tc) + else: + serialized.append(str(tc)) + print(f" tool_calls: {serialized}") + except Exception as exc: # pragma: no cover - debug helper + print(f" tool_calls: {exc!r}") + else: + print(" tool_calls: []") + @evaluation_test( input_rows=[[STRUCTURED_OUTPUT_ROW]], @@ -152,6 +691,8 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: metrics={}, ) return row + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("structured_output", assistant_msg, finish_reason) content_str = _coerce_content_to_str(assistant_msg.content) reasoning_str = assistant_msg.reasoning_content or "" @@ -160,24 +701,14 @@ def 
test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: required_fields = {"location", "temperature", "conditions"} + content_present = bool(content_str.strip()) content_is_json = parsed_content is not None required_keys_present = content_is_json and required_fields <= set(parsed_content.keys()) temperature_is_number = content_is_json and isinstance(parsed_content.get("temperature"), (int, float)) location_valid = content_is_json and parsed_content.get("location") in {"London", "New York"} reasoning_contains_payload = parsed_reasoning is not None - finish_reason = row.execution_metadata.finish_reason finish_reason_expected = finish_reason == "stop" - all_checks_passed = ( - content_str.strip() - and content_is_json - and required_keys_present - and temperature_is_number - and location_valid - and not reasoning_contains_payload - and finish_reason_expected - ) - metrics = { "content_is_json": MetricResult( score=1.0 if content_is_json else 0.0, @@ -213,6 +744,23 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: ), } + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + content_present + and content_is_json + and required_keys_present + and temperature_is_number + and location_valid + and not reasoning_contains_payload + and finish_reason_expected + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + ) + row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, is_score_valid=True, @@ -226,6 +774,158 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: return row +_SIMPLE_STREAM_ROW = _build_row_from_payload("peer-simple-stream", PEER_SIMPLE_STREAM_PAYLOAD) + + +@evaluation_test( + input_rows=[[_SIMPLE_STREAM_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_SIMPLE_STREAM_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_simple_completion(row: EvaluationRow) -> EvaluationRow: + """Ensure plain streaming completion returns content without leaking reasoning.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("simple_completion", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + + has_content = bool(content_str.strip()) + finish_reason_stop = finish_reason == "stop" + reasoning_empty = reasoning_str == "" + has_tool_calls = bool(assistant_msg and assistant_msg.tool_calls) + + metrics = { + "has_content": MetricResult( + score=1.0 if has_content else 0.0, + is_score_valid=True, + reason="Assistant content present" if has_content else "Assistant content empty", + data={"preview": content_str[:120]}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected finish_reason: {finish_reason}", + ), + "reasoning_empty": MetricResult( + score=1.0 if reasoning_empty else 0.0, + is_score_valid=True, + reason="Reasoning is empty" if reasoning_empty else "Unexpected reasoning output", + data={"reasoning": reasoning_str}, + ), + "tool_calls_absent": MetricResult( 
+ score=1.0 if not has_tool_calls else 0.0, + is_score_valid=True, + reason="No tool calls emitted" if not has_tool_calls else "Unexpected tool calls emitted", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + has_content + and finish_reason_stop + and reasoning_empty + and not has_tool_calls + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Simple streaming completion compliant" + if all_checks_passed + else "Simple completion failed compliance checks", + metrics=metrics, + ) + return row + + +_PEER_JSON_ROW = _build_row_from_payload("peer-json-stream", PEER_JSON_STREAM_PAYLOAD) + + +@evaluation_test( + input_rows=[[_PEER_JSON_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_JSON_STREAM_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_peer_json(row: EvaluationRow) -> EvaluationRow: + """Validate peer JSON streaming payload keeps structure in assistant content.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("peer_json", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + parsed_content = _safe_json_loads(content_str) + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + + content_is_json = parsed_content is not None + includes_required = content_is_json and set(parsed_content.keys()) >= {"location", "temperature", "conditions"} + reasoning_empty = reasoning_str == "" + finish_reason_stop = finish_reason == "stop" + + metrics = { + "content_is_json": MetricResult( + score=1.0 if content_is_json else 0.0, + is_score_valid=True, + reason="Assistant content parsed as JSON" if content_is_json else "Failed to parse JSON", + data={"content": content_str}, + ), + "required_fields": MetricResult( + score=1.0 if includes_required else 0.0, + is_score_valid=content_is_json, + reason="All required fields present" if includes_required else "Missing required fields", + data={"parsed_content": parsed_content}, + ), + "reasoning_empty": MetricResult( + score=1.0 if reasoning_empty else 0.0, + is_score_valid=True, + reason="Reasoning is empty" if reasoning_empty else "Unexpected reasoning output", + data={"reasoning": reasoning_str}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + content_is_json + and includes_required + and reasoning_empty + and finish_reason_stop + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="JSON structure preserved" if all_checks_passed else "JSON structure missing or response malformed", + metrics=metrics, + ) + return row + + @evaluation_test( input_rows=[[TOOL_CALL_ROW]], 
completion_params=[ @@ -256,7 +956,11 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: ) return row - tool_calls = assistant_msg.tool_calls or [] + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_call", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) + reasoning_str = (assistant_msg.reasoning_content or "").strip() + tool_calls = assistant_msg.tool_calls or [] tool_calls_for_metrics: list[Any] = [] for tc in tool_calls: if hasattr(tc, "model_dump"): @@ -268,7 +972,6 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: tool_calls_for_metrics.append(tc) else: tool_calls_for_metrics.append(str(tc)) - finish_reason = row.execution_metadata.finish_reason tool_call_count = row.execution_metadata.tool_call_count has_tool_call = len(tool_calls) > 0 @@ -284,13 +987,13 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: parsed_arguments = _safe_json_loads(arguments_payload) if isinstance(arguments_payload, str) else None tool_call_arguments_valid = ( isinstance(parsed_arguments, dict) - and parsed_arguments.get("location", "").lower() == "boston" + and ("boston" in (parsed_arguments.get("location") or "").lower()) and parsed_arguments.get("unit") == "fahrenheit" ) else: parsed_arguments = None - all_checks_passed = ( + base_checks_passed = ( has_tool_call and exactly_one_tool_call and finish_reason_tool_calls @@ -339,12 +1042,638 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: ), } + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = base_checks_passed and finish_reason_present and no_forbidden_tags and no_xml_tags + row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, is_score_valid=True, - reason=( - "Streaming tool call completed correctly" if all_checks_passed else "Streaming tool call behaviour invalid" + reason="Streaming tool call completed correctly" + if all_checks_passed + else "Streaming tool call behaviour invalid", + metrics=metrics, + ) + return row + + +_PEER_TOOL_BRACE_ROW = _build_row_from_payload("peer-tool-brace-bug", PEER_TOOL_BRACE_PAYLOAD) + + +@evaluation_test( + input_rows=[[_PEER_TOOL_BRACE_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_TOOL_BRACE_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_tool_brace_arguments(row: EvaluationRow) -> EvaluationRow: + """Ensure streamed tool arguments remain valid JSON (no truncated braces).""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_brace_arguments", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + tool_calls = assistant_msg.tool_calls or [] if assistant_msg else [] + + parsed_arguments: list[Any] = [] + valid_arguments = True + for tool_call in tool_calls: + function_block = tool_call.get("function", {}) if isinstance(tool_call, dict) else {} + arguments_payload = function_block.get("arguments") + try: + parsed = _safe_json_loads(arguments_payload) if isinstance(arguments_payload, str) else None + 
parsed_arguments.append(parsed) + valid_arguments = valid_arguments and isinstance(parsed, dict) + except Exception: # pragma: no cover - defensive + parsed_arguments.append(None) + valid_arguments = False + + finish_reason_ok = finish_reason in {"tool_calls", "stop"} + + metrics = { + "tool_call_present": MetricResult( + score=1.0 if tool_calls else 0.0, + is_score_valid=True, + reason="Tool calls emitted" if tool_calls else "No tool calls emitted", + data={"tool_calls": tool_calls}, + ), + "arguments_json": MetricResult( + score=1.0 if (tool_calls and valid_arguments) else 0.0, + is_score_valid=bool(tool_calls), + reason="Arguments parsed as JSON" if valid_arguments else "Arguments not valid JSON", + data={"parsed_arguments": parsed_arguments}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason_ok else 0.0, + is_score_valid=True, + reason="finish_reason acceptable" if finish_reason_ok else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + tool_calls + and valid_arguments + and finish_reason_ok + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Tool arguments preserved as JSON" if all_checks_passed else "Tool argument JSON malformed", + metrics=metrics, + ) + return row + + +_PEER_TOOL_MULTI_ROW = _build_row_from_payload("peer-tool-multi", PEER_TOOL_MULTI_PAYLOAD) + + +@evaluation_test( + input_rows=[[_PEER_TOOL_MULTI_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_TOOL_MULTI_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_tool_multi(row: EvaluationRow) -> EvaluationRow: + """Validate complex tool arguments are preserved when streaming.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_multi", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + valid_call = None + for name, args in calls: + if name == "process_data" and isinstance(args, dict): + valid_call = args + break + + finish_reason_ok = finish_reason in {"tool_calls", "stop"} + + metrics = { + "tool_calls_count": MetricResult( + score=1.0 if calls else 0.0, + is_score_valid=True, + reason="Tool calls emitted" if calls else "No tool calls emitted", + data={"tool_call_count": len(calls)}, + ), + "process_data_arguments_valid": MetricResult( + score=1.0 if isinstance(valid_call, dict) else 0.0, + is_score_valid=bool(calls), + reason="process_data arguments parsed" if valid_call else "process_data arguments missing/invalid", + data={"arguments": valid_call}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason_ok else 0.0, + is_score_valid=True, + reason="finish_reason acceptable" if finish_reason_ok else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str 
+ ) + + all_checks_passed = valid_call and finish_reason_ok and finish_reason_present and no_forbidden_tags and no_xml_tags + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="process_data call preserved" if all_checks_passed else "process_data call missing or invalid", + metrics=metrics, + ) + return row + + +_PEER_KIMI_MULTI_ROW = _build_row_from_payload("peer-kimi-multi-tool-calls", PEER_KIMI_MULTI_TOOL_CALLS_PAYLOAD) + + +@evaluation_test( + input_rows=[[_PEER_KIMI_MULTI_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_KIMI_MULTI_TOOL_CALLS_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_kimi_multi_tool_calls(row: EvaluationRow) -> EvaluationRow: + """Ensure multiple tool calls survive a single streamed assistant turn.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("kimi_multi_tool_calls", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + names = [name for name, _ in calls if name] + has_weather = names.count("get_current_weather") >= 2 + has_air_quality = "get_air_quality" in names + + metrics = { + "tool_calls_count": MetricResult( + score=1.0 if len(calls) >= 3 else 0.0, + is_score_valid=True, + reason="Three or more tool calls" if len(calls) >= 3 else "Fewer than three tool calls", + data={"tool_calls": names}, ), + "includes_weather_calls": MetricResult( + score=1.0 if has_weather else 0.0, + is_score_valid=True, + reason="Weather tool called for multiple cities" if has_weather else "Insufficient weather tool calls", + ), + "includes_air_quality": MetricResult( + score=1.0 if has_air_quality else 0.0, + is_score_valid=True, + reason="Air quality tool called" if has_air_quality else "Air quality tool missing", + ), + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason == "tool_calls" else 0.0, + is_score_valid=True, + reason=( + "finish_reason is tool_calls" + if finish_reason == "tool_calls" + else f"Unexpected finish_reason: {finish_reason}" + ), + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + len(calls) >= 3 + and has_weather + and has_air_quality + and finish_reason == "tool_calls" + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Multiple tool calls emitted" + if all_checks_passed + else "Expected multi-tool response missing or invalid", + metrics=metrics, + ) + return row + + +_PEER_TOOL_MISSING_REQUIRED_ROW = _build_row_from_payload( + "peer-tool-missing-required-param", PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD +) + + +@evaluation_test( + input_rows=[[_PEER_TOOL_MISSING_REQUIRED_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def 
test_glm_streaming_tool_missing_required_param(row: EvaluationRow) -> EvaluationRow: + """Detect whether required parameters are omitted during streaming.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_missing_required_param", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + missing_required = False + arguments = None + for _, args in calls: + if args: + arguments = args + missing_required = "type" not in args or args.get("type") not in {"file", "directory"} + + metrics = { + "tool_call_emitted": MetricResult( + score=1.0 if calls else 0.0, + is_score_valid=True, + reason="Tool call emitted" if calls else "No tool call emitted", + ), + "missing_required_param": MetricResult( + score=1.0 if missing_required else 0.0, + is_score_valid=bool(calls), + reason="Required parameter missing or invalid" if missing_required else "All required parameters present", + data={"arguments": arguments}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0, + is_score_valid=True, + reason="finish_reason acceptable" + if finish_reason in {"tool_calls", "stop"} + else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = missing_required and finish_reason_present and no_forbidden_tags and no_xml_tags + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Detected missing required parameter" + if all_checks_passed + else "Required parameters satisfied or response invalid", + metrics=metrics, + ) + return row + + +_PEER_TOOL_STRING_ARRAY_ROW = _build_row_from_payload( + "peer-tool-string-instead-of-array", PEER_TOOL_STRING_INSTEAD_ARRAY_PAYLOAD +) + + +@evaluation_test( + input_rows=[[_PEER_TOOL_STRING_ARRAY_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_TOOL_STRING_INSTEAD_ARRAY_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_tool_view_range_format(row: EvaluationRow) -> EvaluationRow: + """Check streamed arguments keep view_range as an integer array.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_view_range_format", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + view_range_valid = False + view_range_value = None + for _, args in calls: + if args and "view_range" in args: + view_range_value = args["view_range"] + if ( + isinstance(view_range_value, list) + and len(view_range_value) == 2 + and all(isinstance(item, int) for item in view_range_value) + ): + view_range_valid = True + + metrics = { + "tool_call_emitted": MetricResult( + score=1.0 if calls else 0.0, + is_score_valid=True, + reason="Tool call 
emitted" if calls else "No tool call emitted", + ), + "view_range_valid": MetricResult( + score=1.0 if view_range_valid else 0.0, + is_score_valid=bool(calls), + reason="view_range is integer array" if view_range_valid else "view_range malformed", + data={"view_range": view_range_value}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0, + is_score_valid=True, + reason="finish_reason acceptable" + if finish_reason in {"tool_calls", "stop"} + else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = view_range_valid and finish_reason_present and no_forbidden_tags and no_xml_tags + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="view_range array preserved" if all_checks_passed else "view_range malformed", + metrics=metrics, + ) + return row + + +_PEER_TOOL_NAMING_ROW = _build_row_from_payload("peer-tool-naming-error", PEER_TOOL_NAMING_ERROR_PAYLOAD) + + +@evaluation_test( + input_rows=[[_PEER_TOOL_NAMING_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_TOOL_NAMING_ERROR_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_tool_naming(row: EvaluationRow) -> EvaluationRow: + """Confirm tool arguments include required naming fields.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_naming", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + args_valid = False + arguments = None + for _, args in calls: + if args: + arguments = args + args_valid = "type" in args and args.get("type") in {"file", "directory"} + + metrics = { + "tool_call_emitted": MetricResult( + score=1.0 if calls else 0.0, + is_score_valid=True, + reason="Tool call emitted" if calls else "No tool call emitted", + ), + "naming_valid": MetricResult( + score=1.0 if args_valid else 0.0, + is_score_valid=bool(calls), + reason="Tool arguments include type" if args_valid else "Tool arguments missing type", + data={"arguments": arguments}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0, + is_score_valid=True, + reason="finish_reason acceptable" + if finish_reason in {"tool_calls", "stop"} + else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = args_valid and finish_reason_present and no_forbidden_tags and no_xml_tags + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Tool naming fields intact" if all_checks_passed else "Tool naming fields missing or response invalid", + metrics=metrics, + ) + return row + + +_PEER_TOOL_COMMAND_ROW = _build_row_from_payload("peer-tool-command-execution", PEER_TOOL_COMMAND_EXECUTION_PAYLOAD) + + +@evaluation_test( + 
input_rows=[[_PEER_TOOL_COMMAND_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_TOOL_COMMAND_EXECUTION_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_tool_command(row: EvaluationRow) -> EvaluationRow: + """Validate command execution tool receives the correct parameters.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_command", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + command_valid = False + arguments = None + for name, args in calls: + if name == "launch_process" and args: + arguments = args + command_valid = args.get("command") == "ls -la" and args.get("cwd") == "/tmp" + + metrics = { + "launch_process_call": MetricResult( + score=1.0 if arguments else 0.0, + is_score_valid=True, + reason="launch_process called" if arguments else "launch_process missing", + data={"arguments": arguments}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0, + is_score_valid=True, + reason="finish_reason acceptable" + if finish_reason in {"tool_calls", "stop"} + else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = command_valid and finish_reason_present and no_forbidden_tags and no_xml_tags + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Command execution arguments correct" + if all_checks_passed + else "Command execution arguments incorrect or response invalid", + metrics=metrics, + ) + return row + + +_PEER_TOOL_PARAMETER_ROW = _build_row_from_payload( + "peer-tool-parameter-format-errors", PEER_TOOL_PARAMETER_FORMAT_ERRORS_PAYLOAD +) + + +@evaluation_test( + input_rows=[[_PEER_TOOL_PARAMETER_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_TOOL_PARAMETER_FORMAT_ERRORS_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_tool_parameter_formats(row: EvaluationRow) -> EvaluationRow: + """Ensure streamed parameters respect expected JSON types.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_parameter_formats", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + args = None + types_valid = False + for name, payload_args in calls: + if name == "process_data" and isinstance(payload_args, dict): + args = payload_args + types_valid = isinstance(payload_args.get("count"), int) and isinstance(payload_args.get("enabled"), bool) + + metrics = { + "process_data_call": MetricResult( + score=1.0 if args else 0.0, + is_score_valid=True, + reason="process_data 
call present" if args else "process_data call missing", + ), + "types_valid": MetricResult( + score=1.0 if types_valid else 0.0, + is_score_valid=bool(args), + reason="Numeric/boolean fields have correct types" if types_valid else "Type mismatch in arguments", + data={"arguments": args}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0, + is_score_valid=True, + reason="finish_reason acceptable" + if finish_reason in {"tool_calls", "stop"} + else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = types_valid and finish_reason_present and no_forbidden_tags and no_xml_tags + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Parameter types maintained" if all_checks_passed else "Parameter types incorrect or response invalid", + metrics=metrics, + ) + return row + + +_PEER_TOOL_RECOVERY_ROW = _build_row_from_payload("peer-tool-recovery-failure", PEER_TOOL_RECOVERY_FAILURE_PAYLOAD) + + +@evaluation_test( + input_rows=[[_PEER_TOOL_RECOVERY_ROW]], + completion_params=[_build_completion_params_from_payload(PEER_TOOL_RECOVERY_FAILURE_PAYLOAD)], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_glm_streaming_tool_recovery(row: EvaluationRow) -> EvaluationRow: + """Check whether the assistant retries tool calls when instructed to recover.""" + + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_recovery", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + reasoning = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + + multiple_attempts = len(calls) >= 2 + metrics = { + "tool_call_attempts": MetricResult( + score=1.0 if multiple_attempts else 0.0, + is_score_valid=True, + reason="Multiple tool call attempts" if multiple_attempts else "Single/no tool call attempt", + data={"tool_call_count": len(calls)}, + ), + "reasoning_present": MetricResult( + score=1.0 if reasoning else 0.0, + is_score_valid=True, + reason="Reasoning present" if reasoning else "No reasoning provided", + data={"reasoning": reasoning[:160]}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0, + is_score_valid=True, + reason="finish_reason acceptable" + if finish_reason in {"tool_calls", "stop"} + else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning + ) + + all_checks_passed = multiple_attempts and finish_reason_present and no_forbidden_tags and no_xml_tags + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Multiple recovery attempts observed" + if all_checks_passed + else "Recovery attempts missing or response invalid", metrics=metrics, ) return row From 406ed5ba9b78176319dfe450673d73b13a8e9052 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Tue, 18 Nov 2025 20:31:28 -0800 Subject: [PATCH 12/13] changes --- 
.../test_glm_streaming_compliance.py | 2022 ++++++++++++++++- .../default_single_turn_rollout_process.py | 7 +- 2 files changed, 1953 insertions(+), 76 deletions(-) diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index d71b066b..b570cfa4 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -1,4 +1,4 @@ -"""Benchmarks for GLM streaming regressions (structured output + tool calls).""" +"""Benchmarks for output streaming compliance (structured output + tool calls).""" import json import os @@ -19,7 +19,7 @@ DEFAULT_MODEL_ID = "fireworks_ai/accounts/fireworks/models/glm-4p6" -DEFAULT_MAX_TOKENS = 1024 +DEFAULT_MAX_TOKENS = 10000 def _coerce_content_to_str( @@ -188,24 +188,16 @@ def _safe_json_loads(payload: str) -> Any | None: "temperature": 0.1, } -PEER_KIMI_MULTI_TOOL_CALLS_PAYLOAD = { +MULTI_TOOL_CALLS_PAYLOAD = { "messages": [ { "role": "system", - "content": ( - "You are Kimi. When the task requires multiple tools or repeated calls, you MUST emit multiple " - "tool calls in a single assistant turn. Emit ONLY tool calls (no natural language content). " - "Use one <|tool_call_begin|>...<|tool_call_end|> block per call, all wrapped inside a single " - "<|tool_calls_section_begin|>...<|tool_calls_section_end|> section. Ensure arguments are strictly " - "valid JSON and exactly match the provided tool schemas. Do not wait for tool results; just emit " - "the tool calls required by the user request." - ), + "content": "You are a helpful assistant. When multiple tools are needed, call them all in one response.", }, { "role": "user", "content": ( - "For Boston and San Francisco: get the current weather in Fahrenheit for each city, and also get " - "the air quality for Boston. Emit all tool calls in a single assistant turn and nothing else." + "What's the weather in Boston and San Francisco (in Fahrenheit)? Also check the air quality in Boston." ), }, ], @@ -533,8 +525,9 @@ def _build_completion_params_from_payload(payload: dict[str, Any]) -> dict[str, "model": DEFAULT_MODEL_ID, "stream": True, "return_reasoning_with_separate_field": True, + "reasoning_effort": "none", # Default: no reasoning unless explicitly requested } - passthrough_keys = {"temperature", "top_p", "max_tokens", "response_format"} + passthrough_keys = {"temperature", "top_p", "max_tokens", "response_format", "reasoning_effort"} for key in passthrough_keys: if key in payload: params[key] = payload[key] @@ -595,12 +588,47 @@ def _scan_forbidden_tags(content: str, reasoning: str) -> tuple[list[str], list[ return forbidden, xml_tags +def _detect_reasoning_leakage(content: str, reasoning: str) -> list[str]: + """ + Detect thinking/reasoning phrases in content when reasoning_content exists. + + Returns list of detected thinking patterns that should be in reasoning, not content. + Only checks for very clear reasoning indicators, not common phrases. + """ + if not reasoning: + # No reasoning present, so no leakage possible + return [] + + if not content: + return [] + + # Only check for VERY clear reasoning/thinking indicators + # Avoid common phrases that might appear in normal responses + thinking_patterns = [ + r"", # XML thinking tags + r"", + r"\bStep \d+:", # "Step 1:", "Step 2:", etc. 
(clear reasoning structure) + r"\bThinking:", # Explicit thinking label + r"\bReasoning:", # Explicit reasoning label + r"\bMy thought process:", # Explicit meta-reasoning + r"\bLet me think", # Explicit thinking phrase (not just "Let me") + r"\bI need to think", # Explicit thinking + ] + + detected = [] + for pattern in thinking_patterns: + if re.search(pattern, content, re.IGNORECASE): + detected.append(pattern) + + return detected + + def _augment_metrics_with_common_checks( metrics: dict[str, MetricResult], finish_reason: Any, content: str, reasoning: str, -) -> tuple[bool, bool, bool]: +) -> tuple[bool, bool, bool, bool]: finish_reason_str = "" if finish_reason is not None: finish_reason_str = str(finish_reason).strip() @@ -610,6 +638,10 @@ def _augment_metrics_with_common_checks( no_forbidden_tags = not forbidden_tags no_xml_tags = not xml_tags + # Check for reasoning leakage into content + reasoning_leakage = _detect_reasoning_leakage(content, reasoning) + no_reasoning_leakage = not reasoning_leakage + metrics["finish_reason_not_null"] = MetricResult( score=1.0 if finish_reason_present else 0.0, is_score_valid=True, @@ -628,8 +660,16 @@ def _augment_metrics_with_common_checks( reason="No XML-like tags detected" if no_xml_tags else "XML-like tags detected", data={"matches": xml_tags, "count": len(xml_tags)}, ) + metrics["no_reasoning_leakage"] = MetricResult( + score=1.0 if no_reasoning_leakage else 0.0, + is_score_valid=True, # Always valid - if no reasoning, then no leakage possible (score=1.0) + reason="No thinking phrases in content" + if no_reasoning_leakage + else f"Thinking phrases detected in content: {reasoning_leakage}", + data={"detected_patterns": reasoning_leakage, "has_reasoning": bool(reasoning)}, + ) - return finish_reason_present, no_forbidden_tags, no_xml_tags + return finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage def _debug_log_assistant_message(test_name: str, assistant_message: Message | None, finish_reason: Any) -> None: @@ -671,6 +711,7 @@ def _debug_log_assistant_message(test_name: str, assistant_message: Message | No "top_p": 1.0, "max_tokens": DEFAULT_MAX_TOKENS, "response_format": STRUCTURED_RESPONSE_FORMAT, + "reasoning_effort": "none", # No reasoning expected for structured output } ], rollout_processor=SingleTurnRolloutProcessor(), @@ -679,7 +720,7 @@ def _debug_log_assistant_message(test_name: str, assistant_message: Message | No num_runs=1, mode="pointwise", ) -def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: +def test_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: """Ensure structured output arrives in assistant content when streaming.""" assistant_msg = row.last_assistant_message() @@ -744,7 +785,7 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) @@ -759,6 +800,7 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: and finish_reason_present and no_forbidden_tags and no_xml_tags + and no_reasoning_leakage ) row.evaluation_result = EvaluateResult( @@ -786,7 +828,7 @@ def test_glm_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: num_runs=1, mode="pointwise", ) -def test_glm_streaming_simple_completion(row: EvaluationRow) -> 
EvaluationRow: +def test_streaming_simple_completion(row: EvaluationRow) -> EvaluationRow: """Ensure plain streaming completion returns content without leaking reasoning.""" assistant_msg = row.last_assistant_message() @@ -825,7 +867,7 @@ def test_glm_streaming_simple_completion(row: EvaluationRow) -> EvaluationRow: ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) @@ -837,6 +879,7 @@ def test_glm_streaming_simple_completion(row: EvaluationRow) -> EvaluationRow: and finish_reason_present and no_forbidden_tags and no_xml_tags + and no_reasoning_leakage ) row.evaluation_result = EvaluateResult( @@ -862,7 +905,7 @@ def test_glm_streaming_simple_completion(row: EvaluationRow) -> EvaluationRow: num_runs=1, mode="pointwise", ) -def test_glm_streaming_peer_json(row: EvaluationRow) -> EvaluationRow: +def test_streaming_json_preservation(row: EvaluationRow) -> EvaluationRow: """Validate peer JSON streaming payload keeps structure in assistant content.""" assistant_msg = row.last_assistant_message() @@ -903,7 +946,7 @@ def test_glm_streaming_peer_json(row: EvaluationRow) -> EvaluationRow: ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) @@ -915,6 +958,7 @@ def test_glm_streaming_peer_json(row: EvaluationRow) -> EvaluationRow: and finish_reason_present and no_forbidden_tags and no_xml_tags + and no_reasoning_leakage ) row.evaluation_result = EvaluateResult( @@ -935,6 +979,7 @@ def test_glm_streaming_peer_json(row: EvaluationRow) -> EvaluationRow: "temperature": 1.0, "top_p": 1.0, "max_tokens": DEFAULT_MAX_TOKENS, + "reasoning_effort": "none", # No reasoning expected for tool calls } ], rollout_processor=SingleTurnRolloutProcessor(), @@ -943,7 +988,7 @@ def test_glm_streaming_peer_json(row: EvaluationRow) -> EvaluationRow: num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: +def test_streaming_single_tool_call(row: EvaluationRow) -> EvaluationRow: """Ensure streaming tool calls settle with finish_reason=tool_calls and a single call.""" assistant_msg = row.last_assistant_message() @@ -980,18 +1025,15 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: tool_call_count_matches = tool_call_count == len(tool_calls) tool_call_arguments_valid = False + parsed_arguments = None if exactly_one_tool_call: - tool_call = tool_calls[0] - function_block = tool_call.get("function", {}) if isinstance(tool_call, dict) else {} - arguments_payload = function_block.get("arguments") - parsed_arguments = _safe_json_loads(arguments_payload) if isinstance(arguments_payload, str) else None + # Use helper to normalize tool call (handles both dict and Pydantic objects) + name, parsed_arguments = _normalize_tool_call(tool_calls[0]) tool_call_arguments_valid = ( isinstance(parsed_arguments, dict) and ("boston" in (parsed_arguments.get("location") or "").lower()) and parsed_arguments.get("unit") == "fahrenheit" ) - else: - parsed_arguments = None base_checks_passed = ( has_tool_call @@ -1042,11 +1084,13 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: ), } - finish_reason_present, no_forbidden_tags, 
no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) - all_checks_passed = base_checks_passed and finish_reason_present and no_forbidden_tags and no_xml_tags + all_checks_passed = ( + base_checks_passed and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + ) row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, @@ -1071,7 +1115,7 @@ def test_glm_streaming_tool_call(row: EvaluationRow) -> EvaluationRow: num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_brace_arguments(row: EvaluationRow) -> EvaluationRow: +def test_streaming_tool_json_validity(row: EvaluationRow) -> EvaluationRow: """Ensure streamed tool arguments remain valid JSON (no truncated braces).""" assistant_msg = row.last_assistant_message() @@ -1081,18 +1125,10 @@ def test_glm_streaming_tool_brace_arguments(row: EvaluationRow) -> EvaluationRow reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" tool_calls = assistant_msg.tool_calls or [] if assistant_msg else [] - parsed_arguments: list[Any] = [] - valid_arguments = True - for tool_call in tool_calls: - function_block = tool_call.get("function", {}) if isinstance(tool_call, dict) else {} - arguments_payload = function_block.get("arguments") - try: - parsed = _safe_json_loads(arguments_payload) if isinstance(arguments_payload, str) else None - parsed_arguments.append(parsed) - valid_arguments = valid_arguments and isinstance(parsed, dict) - except Exception: # pragma: no cover - defensive - parsed_arguments.append(None) - valid_arguments = False + # Use helper to normalize tool calls (handles both dict and Pydantic objects) + normalized_calls = _collect_tool_calls(tool_calls) + parsed_arguments: list[Any] = [args for name, args in normalized_calls] + valid_arguments = all(isinstance(args, dict) for args in parsed_arguments) finish_reason_ok = finish_reason in {"tool_calls", "stop"} @@ -1116,7 +1152,7 @@ def test_glm_streaming_tool_brace_arguments(row: EvaluationRow) -> EvaluationRow ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) @@ -1127,6 +1163,7 @@ def test_glm_streaming_tool_brace_arguments(row: EvaluationRow) -> EvaluationRow and finish_reason_present and no_forbidden_tags and no_xml_tags + and no_reasoning_leakage ) row.evaluation_result = EvaluateResult( @@ -1150,7 +1187,7 @@ def test_glm_streaming_tool_brace_arguments(row: EvaluationRow) -> EvaluationRow num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_multi(row: EvaluationRow) -> EvaluationRow: +def test_streaming_tool_complex_arguments(row: EvaluationRow) -> EvaluationRow: """Validate complex tool arguments are preserved when streaming.""" assistant_msg = row.last_assistant_message() @@ -1166,7 +1203,7 @@ def test_glm_streaming_tool_multi(row: EvaluationRow) -> EvaluationRow: valid_call = args break - finish_reason_ok = finish_reason in {"tool_calls", "stop"} + finish_reason_tool_calls = finish_reason == "tool_calls" metrics = { "tool_calls_count": MetricResult( @@ -1181,18 +1218,27 @@ def test_glm_streaming_tool_multi(row: EvaluationRow) -> EvaluationRow: reason="process_data arguments parsed" if valid_call else 
"process_data arguments missing/invalid", data={"arguments": valid_call}, ), - "finish_reason": MetricResult( - score=1.0 if finish_reason_ok else 0.0, + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason_tool_calls else 0.0, is_score_valid=True, - reason="finish_reason acceptable" if finish_reason_ok else f"Unexpected finish_reason: {finish_reason}", + reason="finish_reason is tool_calls" + if finish_reason_tool_calls + else f"Unexpected finish_reason: {finish_reason}", ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) - all_checks_passed = valid_call and finish_reason_ok and finish_reason_present and no_forbidden_tags and no_xml_tags + all_checks_passed = ( + valid_call + and finish_reason_tool_calls + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, @@ -1203,19 +1249,19 @@ def test_glm_streaming_tool_multi(row: EvaluationRow) -> EvaluationRow: return row -_PEER_KIMI_MULTI_ROW = _build_row_from_payload("peer-kimi-multi-tool-calls", PEER_KIMI_MULTI_TOOL_CALLS_PAYLOAD) +_MULTI_TOOL_CALLS_ROW = _build_row_from_payload("multi-tool-calls", MULTI_TOOL_CALLS_PAYLOAD) @evaluation_test( - input_rows=[[_PEER_KIMI_MULTI_ROW]], - completion_params=[_build_completion_params_from_payload(PEER_KIMI_MULTI_TOOL_CALLS_PAYLOAD)], + input_rows=[[_MULTI_TOOL_CALLS_ROW]], + completion_params=[_build_completion_params_from_payload(MULTI_TOOL_CALLS_PAYLOAD)], rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=0.0, num_runs=1, mode="pointwise", ) -def test_glm_streaming_kimi_multi_tool_calls(row: EvaluationRow) -> EvaluationRow: +def test_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow: """Ensure multiple tool calls survive a single streamed assistant turn.""" assistant_msg = row.last_assistant_message() @@ -1257,7 +1303,7 @@ def test_glm_streaming_kimi_multi_tool_calls(row: EvaluationRow) -> EvaluationRo ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) @@ -1269,6 +1315,7 @@ def test_glm_streaming_kimi_multi_tool_calls(row: EvaluationRow) -> EvaluationRo and finish_reason_present and no_forbidden_tags and no_xml_tags + and no_reasoning_leakage ) row.evaluation_result = EvaluateResult( @@ -1296,7 +1343,7 @@ def test_glm_streaming_kimi_multi_tool_calls(row: EvaluationRow) -> EvaluationRo num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_missing_required_param(row: EvaluationRow) -> EvaluationRow: +def test_streaming_tool_missing_required_param(row: EvaluationRow) -> EvaluationRow: """Detect whether required parameters are omitted during streaming.""" assistant_msg = row.last_assistant_message() @@ -1334,11 +1381,13 @@ def test_glm_streaming_tool_missing_required_param(row: EvaluationRow) -> Evalua ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) - 
all_checks_passed = missing_required and finish_reason_present and no_forbidden_tags and no_xml_tags + all_checks_passed = ( + missing_required and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + ) row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, @@ -1365,7 +1414,7 @@ def test_glm_streaming_tool_missing_required_param(row: EvaluationRow) -> Evalua num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_view_range_format(row: EvaluationRow) -> EvaluationRow: +def test_streaming_tool_array_parameters(row: EvaluationRow) -> EvaluationRow: """Check streamed arguments keep view_range as an integer array.""" assistant_msg = row.last_assistant_message() @@ -1408,11 +1457,13 @@ def test_glm_streaming_tool_view_range_format(row: EvaluationRow) -> EvaluationR ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) - all_checks_passed = view_range_valid and finish_reason_present and no_forbidden_tags and no_xml_tags + all_checks_passed = ( + view_range_valid and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + ) row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, @@ -1435,7 +1486,7 @@ def test_glm_streaming_tool_view_range_format(row: EvaluationRow) -> EvaluationR num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_naming(row: EvaluationRow) -> EvaluationRow: +def test_streaming_tool_naming_fields(row: EvaluationRow) -> EvaluationRow: """Confirm tool arguments include required naming fields.""" assistant_msg = row.last_assistant_message() @@ -1473,11 +1524,13 @@ def test_glm_streaming_tool_naming(row: EvaluationRow) -> EvaluationRow: ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) - all_checks_passed = args_valid and finish_reason_present and no_forbidden_tags and no_xml_tags + all_checks_passed = ( + args_valid and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + ) row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, @@ -1500,7 +1553,7 @@ def test_glm_streaming_tool_naming(row: EvaluationRow) -> EvaluationRow: num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_command(row: EvaluationRow) -> EvaluationRow: +def test_streaming_tool_command_execution(row: EvaluationRow) -> EvaluationRow: """Validate command execution tool receives the correct parameters.""" assistant_msg = row.last_assistant_message() @@ -1533,11 +1586,13 @@ def test_glm_streaming_tool_command(row: EvaluationRow) -> EvaluationRow: ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) - all_checks_passed = command_valid and finish_reason_present and no_forbidden_tags and no_xml_tags + all_checks_passed = ( + command_valid and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + ) row.evaluation_result = EvaluateResult( score=1.0 if 
all_checks_passed else 0.0, @@ -1564,7 +1619,7 @@ def test_glm_streaming_tool_command(row: EvaluationRow) -> EvaluationRow: num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_parameter_formats(row: EvaluationRow) -> EvaluationRow: +def test_streaming_tool_parameter_types(row: EvaluationRow) -> EvaluationRow: """Ensure streamed parameters respect expected JSON types.""" assistant_msg = row.last_assistant_message() @@ -1602,11 +1657,13 @@ def test_glm_streaming_tool_parameter_formats(row: EvaluationRow) -> EvaluationR ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning_str ) - all_checks_passed = types_valid and finish_reason_present and no_forbidden_tags and no_xml_tags + all_checks_passed = ( + types_valid and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + ) row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, @@ -1629,10 +1686,11 @@ def test_glm_streaming_tool_parameter_formats(row: EvaluationRow) -> EvaluationR num_runs=1, mode="pointwise", ) -def test_glm_streaming_tool_recovery(row: EvaluationRow) -> EvaluationRow: +def test_streaming_tool_retry_behavior(row: EvaluationRow) -> EvaluationRow: """Check whether the assistant retries tool calls when instructed to recover.""" assistant_msg = row.last_assistant_message() + print(f"assistant_msg: {assistant_msg}") finish_reason = row.execution_metadata.finish_reason _debug_log_assistant_message("tool_recovery", assistant_msg, finish_reason) content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" @@ -1662,11 +1720,13 @@ def test_glm_streaming_tool_recovery(row: EvaluationRow) -> EvaluationRow: ), } - finish_reason_present, no_forbidden_tags, no_xml_tags = _augment_metrics_with_common_checks( + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( metrics, finish_reason, content_str, reasoning ) - all_checks_passed = multiple_attempts and finish_reason_present and no_forbidden_tags and no_xml_tags + all_checks_passed = ( + multiple_attempts and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + ) row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, @@ -1677,3 +1737,1815 @@ def test_glm_streaming_tool_recovery(row: EvaluationRow) -> EvaluationRow: metrics=metrics, ) return row + + +# ============================================================================ +# Reasoning Effort Tests +# ============================================================================ + +REASONING_DISABLED_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is 2+2? 
Explain your answer."), + ] +) +REASONING_DISABLED_ROW.input_metadata.dataset_info = { + "test_name": "reasoning_disabled", + "description": "Verify reasoning_content is empty when reasoning_effort=none", +} + + +@evaluation_test( + input_rows=[[REASONING_DISABLED_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model + "reasoning_effort": "none", # Explicitly disable reasoning + "max_tokens": DEFAULT_MAX_TOKENS, + "temperature": 0.0, + "stream": True, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_reasoning_effort_none_no_reasoning(row: EvaluationRow) -> EvaluationRow: + """ + Verify that when reasoning_effort=none, reasoning_content is empty. + + Tests that reasoning-capable models respect the reasoning_effort=none parameter. + """ + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + + has_content = bool(content_str.strip()) + reasoning_empty = reasoning_str == "" + finish_reason_stop = finish_reason == "stop" + + metrics = { + "has_content": MetricResult( + score=1.0 if has_content else 0.0, + is_score_valid=True, + reason="Content present" if has_content else "No content", + data={"content_preview": content_str[:100]}, + ), + "reasoning_empty": MetricResult( + score=1.0 if reasoning_empty else 0.0, + is_score_valid=True, + reason="reasoning_content is empty (as expected)" + if reasoning_empty + else f"Unexpected reasoning_content: {reasoning_str[:100]}", + data={"reasoning_length": len(reasoning_str)}, + ), + "finish_reason_stop": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + has_content + and reasoning_empty + and finish_reason_stop + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + # Build detailed failure reason + failure_reasons = [] + if not has_content: + failure_reasons.append("no content") + if not reasoning_empty: + failure_reasons.append(f"reasoning_content present (len={len(reasoning_str)})") + if not finish_reason_stop: + failure_reasons.append(f"finish_reason={finish_reason}") + if not finish_reason_present: + failure_reasons.append("finish_reason null") + if not no_forbidden_tags: + failure_reasons.append("forbidden tags detected") + if not no_xml_tags: + failure_reasons.append("XML tags detected") + + reason = ( + "reasoning_effort=none respected" if all_checks_passed else f"Compliance failed: {', '.join(failure_reasons)}" + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason=reason, + metrics=metrics, + ) + return row + + +REASONING_ENABLED_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message( + role="user", + content="Solve this problem: If Alice has 3 apples and Bob gives her 5 more, how many does she have?", + ), + ] +) +REASONING_ENABLED_ROW.input_metadata.dataset_info = { + 
"test_name": "reasoning_enabled", + "description": "Verify reasoning_content is present when reasoning_effort=low", +} + + +@evaluation_test( + input_rows=[[REASONING_ENABLED_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model + "reasoning_effort": "low", # Enable reasoning + "max_tokens": DEFAULT_MAX_TOKENS, + "temperature": 0.0, + "stream": True, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_reasoning_effort_low_has_reasoning(row: EvaluationRow) -> EvaluationRow: + """ + Verify that when reasoning_effort=low, reasoning_content is present. + + Tests that reasoning-capable models generate reasoning when requested. + """ + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + + has_content = bool(content_str.strip()) + reasoning_present = bool(reasoning_str) + finish_reason_stop = finish_reason == "stop" + + metrics = { + "has_content": MetricResult( + score=1.0 if has_content else 0.0, + is_score_valid=True, + reason="Content present" if has_content else "No content", + data={"content_preview": content_str[:100]}, + ), + "reasoning_present": MetricResult( + score=1.0 if reasoning_present else 0.0, + is_score_valid=True, + reason="reasoning_content present (as expected)" + if reasoning_present + else "reasoning_content missing when reasoning_effort=low", + data={"reasoning_length": len(reasoning_str), "reasoning_preview": reasoning_str[:200]}, + ), + "finish_reason_stop": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + has_content + and reasoning_present + and finish_reason_stop + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + # Build detailed failure reason + failure_reasons = [] + if not has_content: + failure_reasons.append("no content") + if not reasoning_present: + failure_reasons.append("reasoning_content missing") + if not finish_reason_stop: + failure_reasons.append(f"finish_reason={finish_reason}") + if not finish_reason_present: + failure_reasons.append("finish_reason null") + if not no_forbidden_tags: + failure_reasons.append("forbidden tags detected") + if not no_xml_tags: + failure_reasons.append("XML tags detected") + if not no_reasoning_leakage: + failure_reasons.append("thinking phrases in content") + + reason = ( + "reasoning_effort=low produces reasoning_content" + if all_checks_passed + else f"Compliance failed: {', '.join(failure_reasons)}" + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason=reason, + metrics=metrics, + ) + return row + + +# ============================================================================ +# Tools + Reasoning Combination Test +# ============================================================================ + +TOOLS_WITH_REASONING_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a 
helpful assistant with access to tools."), + Message( + role="user", + content="What's the weather in San Francisco? Think through which tool to use and why.", + ), + ], + tools=[ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ], +) +TOOLS_WITH_REASONING_ROW.input_metadata.dataset_info = { + "test_name": "tools_with_reasoning", + "description": "Verify tools and reasoning work together in streaming", +} + + +@evaluation_test( + input_rows=[[TOOLS_WITH_REASONING_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model + "reasoning_effort": "low", # Enable reasoning + "max_tokens": DEFAULT_MAX_TOKENS, + "temperature": 0.0, + "stream": True, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_streaming_tools_with_reasoning(row: EvaluationRow) -> EvaluationRow: + """ + Verify that streaming works correctly when BOTH tools and reasoning are present. + + Requirements: + - reasoning_content should be present + - tool_calls should be present + - finish_reason should be "tool_calls" + - No XML tags or reasoning leakage + """ + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + reasoning_present = bool(reasoning_str) + has_tool_calls = len(calls) > 0 + finish_reason_tool_calls = finish_reason == "tool_calls" + + # Validate tool call has required params + tool_call_valid = False + if has_tool_calls: + for name, args in calls: + if name == "get_current_weather" and isinstance(args, dict): + location = args.get("location", "") + if "san francisco" in location.lower() or "sf" in location.lower(): + tool_call_valid = True + break + + metrics = { + "reasoning_present": MetricResult( + score=1.0 if reasoning_present else 0.0, + is_score_valid=True, + reason="reasoning_content present" if reasoning_present else "reasoning_content missing", + data={"reasoning_length": len(reasoning_str), "reasoning_preview": reasoning_str[:200]}, + ), + "has_tool_calls": MetricResult( + score=1.0 if has_tool_calls else 0.0, + is_score_valid=True, + reason="Tool calls present" if has_tool_calls else "No tool calls", + data={"tool_call_count": len(calls)}, + ), + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason_tool_calls else 0.0, + is_score_valid=True, + reason="finish_reason is tool_calls" + if finish_reason_tool_calls + else f"Unexpected finish_reason: {finish_reason}", + ), + "tool_call_valid": MetricResult( + score=1.0 if tool_call_valid else 0.0, + is_score_valid=has_tool_calls, + reason="Tool call has correct location" if tool_call_valid else "Tool call missing required params", + data={"tool_calls": [{"name": name, "args": args} for name, args in calls]}, + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + 
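    # Pass only if reasoning is present, a well-formed get_current_weather call targets
    # San Francisco, finish_reason is "tool_calls", and the shared tag/leakage checks hold.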
all_checks_passed = ( + reasoning_present + and has_tool_calls + and finish_reason_tool_calls + and tool_call_valid + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + # Build detailed failure reason + failure_reasons = [] + if not reasoning_present: + failure_reasons.append("reasoning_content missing") + if not has_tool_calls: + failure_reasons.append("no tool calls") + if not finish_reason_tool_calls: + failure_reasons.append(f"finish_reason={finish_reason}") + if not tool_call_valid: + failure_reasons.append("tool params invalid") + if not finish_reason_present: + failure_reasons.append("finish_reason null") + if not no_forbidden_tags: + failure_reasons.append("forbidden tags detected") + if not no_xml_tags: + failure_reasons.append("XML tags detected") + if not no_reasoning_leakage: + failure_reasons.append("thinking phrases in content") + + reason = ( + "Tools + reasoning work together in streaming" + if all_checks_passed + else f"Compliance failed: {', '.join(failure_reasons)}" + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason=reason, + metrics=metrics, + ) + return row + + +# ============================================================================ +# Streaming Consistency Test (Shadow Test) +# ============================================================================ + +STREAMING_CONSISTENCY_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="Count from 1 to 5 and explain why you're counting."), + ] +) +STREAMING_CONSISTENCY_ROW.input_metadata.dataset_info = { + "test_name": "streaming_consistency", + "description": "Shadow test: verify stream=true produces identical output to stream=false", +} + + +@evaluation_test( + input_rows=[[STREAMING_CONSISTENCY_ROW]], + completion_params=[ + { + "model": os.getenv("EP_MODEL", DEFAULT_MODEL_ID), + "max_tokens": os.getenv("EP_MAX_TOKENS", DEFAULT_MAX_TOKENS), + "temperature": 0.0, # Deterministic for consistency + "stream": False, # Will be overridden by custom rollout + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +async def test_streaming_output_consistency(row: EvaluationRow) -> EvaluationRow: + """ + Shadow stress test for streaming consistency. + + Strategy: + 1. Run request with stream=false, capture output + 2. Run request with stream=true + forced_generation (same tokens), concat chunks + 3. 
Verify concatenated streaming output matches non-streaming output + + This catches bugs like: "you're5" (streaming) vs "you're 5" (non-streaming) + """ + from openai import AsyncOpenAI + + api_key = os.environ.get("FIREWORKS_API_KEY") + if not api_key: + row.evaluation_result = EvaluateResult( + score=0.0, + is_score_valid=False, + reason="FIREWORKS_API_KEY not set", + ) + return row + + model = os.getenv("EP_MODEL", DEFAULT_MODEL_ID) + messages = [msg.model_dump() for msg in row.messages] + + try: + async with AsyncOpenAI( + api_key=api_key, + base_url="https://api.fireworks.ai/inference/v1", + ) as client: + # Step 1: Get non-streaming output + response_non_stream = await client.chat.completions.create( + model=model, + messages=messages, # type: ignore[arg-type] + temperature=0.0, + max_tokens=DEFAULT_MAX_TOKENS, + stream=False, + ) + + non_stream_content = response_non_stream.choices[0].message.content or "" + non_stream_tool_calls = response_non_stream.choices[0].message.tool_calls + non_stream_finish = response_non_stream.choices[0].finish_reason + + # Step 2: Get streaming output with forced generation + # Extract token IDs from non-streaming response if available + # (Note: OpenAI API doesn't expose token IDs directly, so we'll just verify + # that streaming with same params produces same output) + stream_response = await client.chat.completions.create( + model=model, + messages=messages, # type: ignore[arg-type] + temperature=0.0, + max_tokens=DEFAULT_MAX_TOKENS, + stream=True, + ) + + # Concatenate streaming chunks + stream_content_parts = [] + stream_finish = None + stream_tool_calls = None + + async for chunk in stream_response: + if chunk.choices and len(chunk.choices) > 0: + delta = chunk.choices[0].delta + if delta.content: + stream_content_parts.append(delta.content) + if delta.tool_calls: + stream_tool_calls = delta.tool_calls + if chunk.choices[0].finish_reason: + stream_finish = chunk.choices[0].finish_reason + + stream_content = "".join(stream_content_parts) + + # Step 3: Compare outputs + content_match = non_stream_content == stream_content + finish_match = non_stream_finish == stream_finish + + # Tool calls comparison (basic check) + tool_calls_match = True + if non_stream_tool_calls and stream_tool_calls: + tool_calls_match = len(non_stream_tool_calls) == len(stream_tool_calls) + elif non_stream_tool_calls or stream_tool_calls: + tool_calls_match = False + + metrics = { + "content_match": MetricResult( + score=1.0 if content_match else 0.0, + is_score_valid=True, + reason="Content identical" + if content_match + else f"Content mismatch: non-stream='{non_stream_content[:100]}' vs stream='{stream_content[:100]}'", + data={ + "non_stream_length": len(non_stream_content), + "stream_length": len(stream_content), + }, + ), + "finish_reason_match": MetricResult( + score=1.0 if finish_match else 0.0, + is_score_valid=True, + reason=f"finish_reason: {non_stream_finish}" + if finish_match + else f"Mismatch: {non_stream_finish} vs {stream_finish}", + ), + "tool_calls_match": MetricResult( + score=1.0 if tool_calls_match else 0.0, + is_score_valid=True, + reason="Tool calls consistent" if tool_calls_match else "Tool call count mismatch", + ), + } + + all_match = content_match and finish_match and tool_calls_match + + row.evaluation_result = EvaluateResult( + score=1.0 if all_match else 0.0, + is_score_valid=True, + reason="Streaming output matches non-streaming" if all_match else "Streaming inconsistency detected", + metrics=metrics, + ) + + except Exception as e: + 
row.evaluation_result = EvaluateResult( + score=0.0, + is_score_valid=False, + reason=f"Test execution failed: {e}", + ) + + return row + + +# ============================================================================ +# Non-Streaming Tests (Mirror of Streaming Tests) +# ============================================================================ + + +# Test 1: Structured Output (Non-Streaming) +@evaluation_test( + input_rows=[[STRUCTURED_OUTPUT_ROW]], + completion_params=[ + { + "model": DEFAULT_MODEL_ID, + "stream": False, # Non-streaming + "temperature": 1.0, + "top_p": 1.0, + "max_tokens": DEFAULT_MAX_TOKENS, + "response_format": STRUCTURED_RESPONSE_FORMAT, + "reasoning_effort": "none", + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=1.0, + num_runs=1, + mode="pointwise", +) +def test_non_streaming_structured_output(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Validate structured output with JSON schema.""" + assistant_msg = row.last_assistant_message() + if assistant_msg is None: + row.evaluation_result = EvaluateResult( + score=0.0, + is_score_valid=False, + reason="No assistant message produced", + metrics={}, + ) + return row + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("structured_output_non_stream", assistant_msg, finish_reason) + + content_str = _coerce_content_to_str(assistant_msg.content) + reasoning_str = assistant_msg.reasoning_content or "" + parsed_content = _safe_json_loads(content_str) + parsed_reasoning = _safe_json_loads(reasoning_str) if reasoning_str else None + + required_fields = {"location", "temperature", "conditions"} + + content_present = bool(content_str.strip()) + content_is_json = parsed_content is not None + required_keys_present = content_is_json and required_fields <= set(parsed_content.keys()) + temperature_is_number = content_is_json and isinstance(parsed_content.get("temperature"), (int, float)) + location_valid = content_is_json and parsed_content.get("location") in {"London", "New York"} + reasoning_contains_payload = parsed_reasoning is not None + finish_reason_expected = finish_reason == "stop" + + metrics = { + "content_is_json": MetricResult( + score=1.0 if content_is_json else 0.0, + is_score_valid=True, + reason="Assistant content parsed as JSON" if content_is_json else "Failed to parse JSON", + data={"content": content_str}, + ), + "required_keys_present": MetricResult( + score=1.0 if required_keys_present else 0.0, + is_score_valid=content_is_json, + reason=("All required keys present" if required_keys_present else "Missing required keys"), + data={"parsed_content": parsed_content}, + ), + "temperature_is_number": MetricResult( + score=1.0 if temperature_is_number else 0.0, + is_score_valid=content_is_json, + reason="Temperature is numeric" if temperature_is_number else "Temperature not numeric", + data={"temperature": parsed_content.get("temperature") if parsed_content else None}, + ), + "reasoning_contains_payload": MetricResult( + score=0.0 if reasoning_contains_payload else 1.0, + is_score_valid=True, + reason="Reasoning is empty" if not reasoning_contains_payload else "Payload leaked to reasoning", + data={"reasoning": reasoning_str}, + ), + "finish_reason_stop": MetricResult( + score=1.0 if finish_reason_expected else 0.0, + is_score_valid=True, + reason=( + "finish_reason is stop" if finish_reason_expected else f"Unexpected finish_reason: {finish_reason}" + ), + data={"finish_reason": finish_reason}, + ), + } + + 
finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + content_present + and content_is_json + and required_keys_present + and temperature_is_number + and location_valid + and not reasoning_contains_payload + and finish_reason_expected + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason=( + "Structured output returned in assistant content" + if all_checks_passed + else "Structured output missing or response malformed" + ), + metrics=metrics, + ) + return row + + +# Test 2: Simple Completion (Non-Streaming) +_SIMPLE_COMPLETION_NON_STREAM_ROW = _build_row_from_payload("simple-completion-non-stream", PEER_SIMPLE_STREAM_PAYLOAD) + + +@evaluation_test( + input_rows=[[_SIMPLE_COMPLETION_NON_STREAM_ROW]], + completion_params=[ + { + **_build_completion_params_from_payload(PEER_SIMPLE_STREAM_PAYLOAD), + "stream": False, # Override to non-streaming + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_non_streaming_simple_completion(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Validate plain text completion.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("simple_completion_non_stream", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + + has_content = bool(content_str.strip()) + finish_reason_stop = finish_reason == "stop" + reasoning_empty = reasoning_str == "" + has_tool_calls = bool(assistant_msg and assistant_msg.tool_calls) + + metrics = { + "has_content": MetricResult( + score=1.0 if has_content else 0.0, + is_score_valid=True, + reason="Assistant content present" if has_content else "Assistant content empty", + data={"preview": content_str[:120]}, + ), + "finish_reason": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected finish_reason: {finish_reason}", + ), + "reasoning_empty": MetricResult( + score=1.0 if reasoning_empty else 0.0, + is_score_valid=True, + reason="Reasoning is empty" if reasoning_empty else "Unexpected reasoning output", + data={"reasoning": reasoning_str}, + ), + "tool_calls_absent": MetricResult( + score=1.0 if not has_tool_calls else 0.0, + is_score_valid=True, + reason="No tool calls emitted" if not has_tool_calls else "Unexpected tool calls emitted", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + has_content + and finish_reason_stop + and reasoning_empty + and not has_tool_calls + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Simple completion compliant" if all_checks_passed else "Simple completion failed compliance checks", + metrics=metrics, + ) + return row + 
+ +# Test 3: Tool Call (Non-Streaming) +TOOL_CALL_NON_STREAM_ROW = EvaluationRow( + messages=[ + Message(role="system", content=TOOL_SYSTEM_PROMPT), + Message(role="user", content="What's the weather in Boston in Fahrenheit?"), + ], + tools=WEATHER_TOOL_DEFINITION, +) +TOOL_CALL_NON_STREAM_ROW.input_metadata.dataset_info = { + "test_name": "tool_call_non_stream", + "description": "Non-streaming tool call test", +} + + +@evaluation_test( + input_rows=[[TOOL_CALL_NON_STREAM_ROW]], + completion_params=[ + { + "model": DEFAULT_MODEL_ID, + "stream": False, # Non-streaming + "temperature": 1.0, + "top_p": 1.0, + "max_tokens": DEFAULT_MAX_TOKENS, + "reasoning_effort": "none", + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_non_streaming_single_tool_call(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Validate single tool call.""" + assistant_msg = row.last_assistant_message() + if assistant_msg is None: + row.evaluation_result = EvaluateResult( + score=0.0, + is_score_valid=False, + reason="No assistant message produced", + metrics={}, + ) + return row + + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("tool_call_non_stream", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) + reasoning_str = (assistant_msg.reasoning_content or "").strip() + tool_calls = assistant_msg.tool_calls or [] + tool_calls_for_metrics: list[Any] = [] + for tc in tool_calls: + if hasattr(tc, "model_dump"): + try: + tool_calls_for_metrics.append(tc.model_dump(exclude_none=True)) + except Exception: + tool_calls_for_metrics.append(str(tc)) + elif isinstance(tc, dict): + tool_calls_for_metrics.append(tc) + else: + tool_calls_for_metrics.append(str(tc)) + tool_call_count = row.execution_metadata.tool_call_count + + has_tool_call = len(tool_calls) > 0 + exactly_one_tool_call = len(tool_calls) == 1 + finish_reason_tool_calls = finish_reason == "tool_calls" + tool_call_count_matches = tool_call_count == len(tool_calls) + + tool_call_arguments_valid = False + parsed_arguments = None + if exactly_one_tool_call: + # Use helper to normalize tool call (handles both dict and Pydantic objects) + name, parsed_arguments = _normalize_tool_call(tool_calls[0]) + tool_call_arguments_valid = ( + isinstance(parsed_arguments, dict) + and ("boston" in (parsed_arguments.get("location") or "").lower()) + and parsed_arguments.get("unit") == "fahrenheit" + ) + + base_checks_passed = ( + has_tool_call + and exactly_one_tool_call + and finish_reason_tool_calls + and tool_call_arguments_valid + and tool_call_count_matches + ) + + metrics = { + "has_tool_call": MetricResult( + score=1.0 if has_tool_call else 0.0, + is_score_valid=True, + reason="Assistant produced at least one tool call" if has_tool_call else "No tool calls returned", + data={"tool_call_count": len(tool_calls)}, + ), + "single_tool_call": MetricResult( + score=1.0 if exactly_one_tool_call else 0.0, + is_score_valid=has_tool_call, + reason=("Exactly one tool call" if exactly_one_tool_call else "Unexpected number of tool calls"), + data={"tool_calls": tool_calls_for_metrics}, + ), + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason_tool_calls else 0.0, + is_score_valid=True, + reason=( + "finish_reason is tool_calls" + if finish_reason_tool_calls + else f"Unexpected finish_reason: {finish_reason}" + ), + data={"finish_reason": finish_reason}, + ), + 
"tool_call_arguments_valid": MetricResult( + score=1.0 if tool_call_arguments_valid else 0.0, + is_score_valid=exactly_one_tool_call, + reason=("Tool call arguments valid" if tool_call_arguments_valid else "Tool call arguments invalid"), + data={"arguments": parsed_arguments}, + ), + "tool_call_count_matches": MetricResult( + score=1.0 if tool_call_count_matches else 0.0, + is_score_valid=True, + reason=( + "tool_call_count matches returned calls" + if tool_call_count_matches + else f"tool_call_count mismatch (metadata={tool_call_count}, actual={len(tool_calls)})" + ), + data={"metadata_tool_call_count": tool_call_count}, + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + base_checks_passed and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Tool call completed correctly" if all_checks_passed else "Tool call behaviour invalid", + metrics=metrics, + ) + return row + + +# Test 4: Multiple Tool Calls (Non-Streaming) +_MULTI_TOOL_CALLS_NON_STREAM_ROW = _build_row_from_payload("multi-tool-calls-non-stream", MULTI_TOOL_CALLS_PAYLOAD) + + +@evaluation_test( + input_rows=[[_MULTI_TOOL_CALLS_NON_STREAM_ROW]], + completion_params=[ + { + **_build_completion_params_from_payload(MULTI_TOOL_CALLS_PAYLOAD), + "stream": False, # Override to non-streaming + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + aggregation_method="mean", + passed_threshold=0.0, + num_runs=1, + mode="pointwise", +) +def test_non_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Validate multiple tool calls.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + _debug_log_assistant_message("multi_tool_calls_non_stream", assistant_msg, finish_reason) + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + names = [name for name, _ in calls if name] + has_weather = names.count("get_current_weather") >= 2 + has_air_quality = "get_air_quality" in names + + metrics = { + "tool_calls_count": MetricResult( + score=1.0 if len(calls) >= 3 else 0.0, + is_score_valid=True, + reason="Three or more tool calls" if len(calls) >= 3 else "Fewer than three tool calls", + data={"tool_calls": names}, + ), + "includes_weather_calls": MetricResult( + score=1.0 if has_weather else 0.0, + is_score_valid=True, + reason="Weather tool called for multiple cities" if has_weather else "Insufficient weather tool calls", + ), + "includes_air_quality": MetricResult( + score=1.0 if has_air_quality else 0.0, + is_score_valid=True, + reason="Air quality tool called" if has_air_quality else "Air quality tool missing", + ), + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason == "tool_calls" else 0.0, + is_score_valid=True, + reason=( + "finish_reason is tool_calls" + if finish_reason == "tool_calls" + else f"Unexpected finish_reason: {finish_reason}" + ), + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, 
reasoning_str + ) + + all_checks_passed = ( + len(calls) >= 3 + and has_weather + and has_air_quality + and finish_reason == "tool_calls" + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Multiple tool calls emitted" + if all_checks_passed + else "Expected multi-tool response missing or invalid", + metrics=metrics, + ) + return row + + +# Test 5: Reasoning Disabled (Non-Streaming) +REASONING_DISABLED_NON_STREAM_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is 2+2? Explain your answer."), + ] +) +REASONING_DISABLED_NON_STREAM_ROW.input_metadata.dataset_info = { + "test_name": "reasoning_disabled_non_stream", + "description": "Non-streaming: verify reasoning_content empty when reasoning_effort=none", +} + + +@evaluation_test( + input_rows=[[REASONING_DISABLED_NON_STREAM_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "reasoning_effort": "none", + "max_tokens": DEFAULT_MAX_TOKENS, + "temperature": 0.0, + "stream": False, # Non-streaming + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_reasoning_effort_none_no_reasoning_non_stream(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Verify reasoning_content empty when reasoning_effort=none.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + + has_content = bool(content_str.strip()) + reasoning_empty = reasoning_str == "" + finish_reason_stop = finish_reason == "stop" + + metrics = { + "has_content": MetricResult( + score=1.0 if has_content else 0.0, + is_score_valid=True, + reason="Content present" if has_content else "No content", + data={"content_preview": content_str[:100]}, + ), + "reasoning_empty": MetricResult( + score=1.0 if reasoning_empty else 0.0, + is_score_valid=True, + reason="reasoning_content is empty (as expected)" + if reasoning_empty + else f"Unexpected reasoning_content: {reasoning_str[:100]}", + data={"reasoning_length": len(reasoning_str)}, + ), + "finish_reason_stop": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + has_content + and reasoning_empty + and finish_reason_stop + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + # Build detailed failure reason + failure_reasons = [] + if not has_content: + failure_reasons.append("no content") + if not reasoning_empty: + failure_reasons.append(f"reasoning_content present (len={len(reasoning_str)})") + if not finish_reason_stop: + failure_reasons.append(f"finish_reason={finish_reason}") + if not finish_reason_present: + failure_reasons.append("finish_reason null") + if not no_forbidden_tags: + failure_reasons.append("forbidden tags detected") + if not 
no_xml_tags: + failure_reasons.append("XML tags detected") + + reason = ( + "reasoning_effort=none respected" if all_checks_passed else f"Compliance failed: {', '.join(failure_reasons)}" + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason=reason, + metrics=metrics, + ) + return row + + +# Test 6: Reasoning Enabled (Non-Streaming) +REASONING_ENABLED_NON_STREAM_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message( + role="user", + content="Solve this problem: If Alice has 3 apples and Bob gives her 5 more, how many does she have?", + ), + ] +) +REASONING_ENABLED_NON_STREAM_ROW.input_metadata.dataset_info = { + "test_name": "reasoning_enabled_non_stream", + "description": "Non-streaming: verify reasoning_content present when reasoning_effort=low", +} + + +@evaluation_test( + input_rows=[[REASONING_ENABLED_NON_STREAM_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "reasoning_effort": "low", + "max_tokens": DEFAULT_MAX_TOKENS, + "temperature": 0.0, + "stream": False, # Non-streaming + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_reasoning_effort_low_has_reasoning_non_stream(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Verify reasoning_content present when reasoning_effort=low.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + + has_content = bool(content_str.strip()) + reasoning_present = bool(reasoning_str) + finish_reason_stop = finish_reason == "stop" + + metrics = { + "has_content": MetricResult( + score=1.0 if has_content else 0.0, + is_score_valid=True, + reason="Content present" if has_content else "No content", + data={"content_preview": content_str[:100]}, + ), + "reasoning_present": MetricResult( + score=1.0 if reasoning_present else 0.0, + is_score_valid=True, + reason="reasoning_content present (as expected)" + if reasoning_present + else "reasoning_content missing when reasoning_effort=low", + data={"reasoning_length": len(reasoning_str), "reasoning_preview": reasoning_str[:200]}, + ), + "finish_reason_stop": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected finish_reason: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + has_content + and reasoning_present + and finish_reason_stop + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + # Build detailed failure reason + failure_reasons = [] + if not has_content: + failure_reasons.append("no content") + if not reasoning_present: + failure_reasons.append("reasoning_content missing") + if not finish_reason_stop: + failure_reasons.append(f"finish_reason={finish_reason}") + if not finish_reason_present: + failure_reasons.append("finish_reason null") + if not no_forbidden_tags: + failure_reasons.append("forbidden tags detected") + if not no_xml_tags: + failure_reasons.append("XML tags detected") + if not 
no_reasoning_leakage: + failure_reasons.append("thinking phrases in content") + + reason = ( + "reasoning_effort=low produces reasoning_content" + if all_checks_passed + else f"Compliance failed: {', '.join(failure_reasons)}" + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason=reason, + metrics=metrics, + ) + return row + + +# Test 7: Tools + Reasoning (Non-Streaming) +TOOLS_WITH_REASONING_NON_STREAM_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful assistant with access to tools."), + Message( + role="user", + content="What's the weather in San Francisco? Think through which tool to use and why.", + ), + ], + tools=[ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ], +) +TOOLS_WITH_REASONING_NON_STREAM_ROW.input_metadata.dataset_info = { + "test_name": "tools_with_reasoning_non_stream", + "description": "Non-streaming: verify tools and reasoning work together", +} + + +@evaluation_test( + input_rows=[[TOOLS_WITH_REASONING_NON_STREAM_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "reasoning_effort": "low", + "max_tokens": DEFAULT_MAX_TOKENS, + "temperature": 0.0, + "stream": False, # Non-streaming + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_non_streaming_tools_with_reasoning(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Verify tools and reasoning work together.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + reasoning_present = bool(reasoning_str) + has_tool_calls = len(calls) > 0 + finish_reason_tool_calls = finish_reason == "tool_calls" + + # Validate tool call has required params + tool_call_valid = False + if has_tool_calls: + for name, args in calls: + if name == "get_current_weather" and isinstance(args, dict): + location = args.get("location", "") + if "san francisco" in location.lower() or "sf" in location.lower(): + tool_call_valid = True + break + + metrics = { + "reasoning_present": MetricResult( + score=1.0 if reasoning_present else 0.0, + is_score_valid=True, + reason="reasoning_content present" if reasoning_present else "reasoning_content missing", + data={"reasoning_length": len(reasoning_str), "reasoning_preview": reasoning_str[:200]}, + ), + "has_tool_calls": MetricResult( + score=1.0 if has_tool_calls else 0.0, + is_score_valid=True, + reason="Tool calls present" if has_tool_calls else "No tool calls", + data={"tool_call_count": len(calls)}, + ), + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason_tool_calls else 0.0, + is_score_valid=True, + reason="finish_reason is tool_calls" + if finish_reason_tool_calls + else f"Unexpected finish_reason: {finish_reason}", + ), + "tool_call_valid": MetricResult( + score=1.0 if tool_call_valid else 0.0, + 
is_score_valid=has_tool_calls,
+            reason="Tool call has correct location" if tool_call_valid else "Tool call missing required params",
+            data={"tool_calls": [{"name": name, "args": args} for name, args in calls]},
+        ),
+    }
+
+    finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks(
+        metrics, finish_reason, content_str, reasoning_str
+    )
+
+    all_checks_passed = (
+        reasoning_present
+        and has_tool_calls
+        and finish_reason_tool_calls
+        and tool_call_valid
+        and finish_reason_present
+        and no_forbidden_tags
+        and no_xml_tags
+        and no_reasoning_leakage
+    )
+
+    # Build detailed failure reason
+    failure_reasons = []
+    if not reasoning_present:
+        failure_reasons.append("reasoning_content missing")
+    if not has_tool_calls:
+        failure_reasons.append("no tool calls")
+    if not finish_reason_tool_calls:
+        failure_reasons.append(f"finish_reason={finish_reason}")
+    if not tool_call_valid:
+        failure_reasons.append("tool params invalid")
+    if not finish_reason_present:
+        failure_reasons.append("finish_reason null")
+    if not no_forbidden_tags:
+        failure_reasons.append("forbidden tags detected")
+    if not no_xml_tags:
+        failure_reasons.append("XML tags detected")
+    if not no_reasoning_leakage:
+        failure_reasons.append("thinking phrases in content")
+
+    reason = (
+        "Tools + reasoning work together in non-streaming mode"
+        if all_checks_passed
+        else f"Compliance failed: {', '.join(failure_reasons)}"
+    )
+
+    row.evaluation_result = EvaluateResult(
+        score=1.0 if all_checks_passed else 0.0,
+        is_score_valid=True,
+        reason=reason,
+        metrics=metrics,
+    )
+    return row
+
+
+# ============================================================================
+# Missing Permutations: Reasoning + Structured JSON
+# ============================================================================
+
+STRUCTURED_OUTPUT_WITH_REASONING_ROW = EvaluationRow(
+    messages=[
+        Message(role="system", content="You are a helpful math assistant."),
+        Message(
+            role="user",
+            content="Solve this problem step by step: If a train travels 120 km in 2 hours, what is its average speed?
Return your answer as JSON with 'speed_kmh' and 'explanation' fields.", + ), + ] +) +STRUCTURED_OUTPUT_WITH_REASONING_ROW.input_metadata.dataset_info = { + "test_name": "structured_output_with_reasoning_stream", + "description": "Streaming: structured JSON + reasoning", +} + +STRUCTURED_JSON_SCHEMA = { + "type": "json_object", + "schema": { + "type": "object", + "properties": { + "speed_kmh": {"type": "number", "description": "Speed in km/h"}, + "explanation": {"type": "string", "description": "Brief explanation"}, + }, + "required": ["speed_kmh", "explanation"], + }, +} + + +@evaluation_test( + input_rows=[[STRUCTURED_OUTPUT_WITH_REASONING_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "stream": True, + "reasoning_effort": "low", + "response_format": STRUCTURED_JSON_SCHEMA, + "temperature": 0.0, + "max_tokens": DEFAULT_MAX_TOKENS, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_streaming_structured_output_with_reasoning(row: EvaluationRow) -> EvaluationRow: + """Validate structured JSON output with reasoning in streaming mode.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + parsed_content = _safe_json_loads(content_str) + + content_is_json = parsed_content is not None + has_required_keys = content_is_json and {"speed_kmh", "explanation"} <= set(parsed_content.keys()) + speed_is_number = content_is_json and isinstance(parsed_content.get("speed_kmh"), (int, float)) + reasoning_present = bool(reasoning_str) + finish_reason_stop = finish_reason == "stop" + + metrics = { + "content_is_json": MetricResult( + score=1.0 if content_is_json else 0.0, + is_score_valid=True, + reason="Content is valid JSON" if content_is_json else "Failed to parse JSON", + data={"content": content_str[:200]}, + ), + "has_required_keys": MetricResult( + score=1.0 if has_required_keys else 0.0, + is_score_valid=content_is_json, + reason="Required keys present" if has_required_keys else "Missing required keys", + data={"parsed": parsed_content}, + ), + "speed_is_number": MetricResult( + score=1.0 if speed_is_number else 0.0, + is_score_valid=content_is_json, + reason="speed_kmh is numeric" if speed_is_number else "speed_kmh not numeric", + ), + "reasoning_present": MetricResult( + score=1.0 if reasoning_present else 0.0, + is_score_valid=True, + reason="reasoning_content present" if reasoning_present else "reasoning_content missing", + data={"reasoning_length": len(reasoning_str), "reasoning_preview": reasoning_str[:200]}, + ), + "finish_reason_stop": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + content_is_json + and has_required_keys + and speed_is_number + and reasoning_present + and finish_reason_stop + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + 
reason="Structured JSON + reasoning work together" if all_checks_passed else "JSON or reasoning invalid", + metrics=metrics, + ) + return row + + +# Non-streaming version +STRUCTURED_OUTPUT_WITH_REASONING_NON_STREAM_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful math assistant."), + Message( + role="user", + content="Solve this problem step by step: If a train travels 120 km in 2 hours, what is its average speed? Return your answer as JSON with 'speed_kmh' and 'explanation' fields.", + ), + ] +) +STRUCTURED_OUTPUT_WITH_REASONING_NON_STREAM_ROW.input_metadata.dataset_info = { + "test_name": "structured_output_with_reasoning_non_stream", + "description": "Non-streaming: structured JSON + reasoning", +} + + +@evaluation_test( + input_rows=[[STRUCTURED_OUTPUT_WITH_REASONING_NON_STREAM_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "stream": False, + "reasoning_effort": "low", + "response_format": STRUCTURED_JSON_SCHEMA, + "temperature": 0.0, + "max_tokens": DEFAULT_MAX_TOKENS, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_non_streaming_structured_output_with_reasoning(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Validate structured JSON output with reasoning.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + parsed_content = _safe_json_loads(content_str) + + content_is_json = parsed_content is not None + has_required_keys = content_is_json and {"speed_kmh", "explanation"} <= set(parsed_content.keys()) + speed_is_number = content_is_json and isinstance(parsed_content.get("speed_kmh"), (int, float)) + reasoning_present = bool(reasoning_str) + finish_reason_stop = finish_reason == "stop" + + metrics = { + "content_is_json": MetricResult( + score=1.0 if content_is_json else 0.0, + is_score_valid=True, + reason="Content is valid JSON" if content_is_json else "Failed to parse JSON", + data={"content": content_str[:200]}, + ), + "has_required_keys": MetricResult( + score=1.0 if has_required_keys else 0.0, + is_score_valid=content_is_json, + reason="Required keys present" if has_required_keys else "Missing required keys", + data={"parsed": parsed_content}, + ), + "speed_is_number": MetricResult( + score=1.0 if speed_is_number else 0.0, + is_score_valid=content_is_json, + reason="speed_kmh is numeric" if speed_is_number else "speed_kmh not numeric", + ), + "reasoning_present": MetricResult( + score=1.0 if reasoning_present else 0.0, + is_score_valid=True, + reason="reasoning_content present" if reasoning_present else "reasoning_content missing", + data={"reasoning_length": len(reasoning_str), "reasoning_preview": reasoning_str[:200]}, + ), + "finish_reason_stop": MetricResult( + score=1.0 if finish_reason_stop else 0.0, + is_score_valid=True, + reason="finish_reason is stop" if finish_reason_stop else f"Unexpected: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + content_is_json + and has_required_keys + and speed_is_number + and reasoning_present + and finish_reason_stop + and finish_reason_present + and 
no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Structured JSON + reasoning work together" if all_checks_passed else "JSON or reasoning invalid", + metrics=metrics, + ) + return row + + +# ============================================================================ +# Missing Permutations: Reasoning + Multiple Tools +# ============================================================================ + +MULTIPLE_TOOLS_WITH_REASONING_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful assistant with access to tools."), + Message( + role="user", + content="Get the weather for Boston, San Francisco, and Seattle (all in Fahrenheit). Think about which cities to check and explain your approach.", + ), + ], + tools=[ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + }, + }, + } + ], +) +MULTIPLE_TOOLS_WITH_REASONING_ROW.input_metadata.dataset_info = { + "test_name": "multiple_tools_with_reasoning_stream", + "description": "Streaming: multiple tool calls + reasoning", +} + + +@evaluation_test( + input_rows=[[MULTIPLE_TOOLS_WITH_REASONING_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "stream": True, + "reasoning_effort": "low", + "temperature": 0.0, + "max_tokens": DEFAULT_MAX_TOKENS, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_streaming_multiple_tools_with_reasoning(row: EvaluationRow) -> EvaluationRow: + """Validate multiple tool calls with reasoning in streaming mode.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + reasoning_present = bool(reasoning_str) + has_multiple_tools = len(calls) >= 3 + finish_reason_tool_calls = finish_reason == "tool_calls" + + # Check that all 3 cities are covered + cities_covered = set() + for name, args in calls: + if name == "get_current_weather" and isinstance(args, dict): + location = (args.get("location") or "").lower() + if "boston" in location: + cities_covered.add("boston") + if "san francisco" in location or "sf" in location: + cities_covered.add("san_francisco") + if "seattle" in location: + cities_covered.add("seattle") + + all_cities_covered = len(cities_covered) == 3 + + metrics = { + "reasoning_present": MetricResult( + score=1.0 if reasoning_present else 0.0, + is_score_valid=True, + reason="reasoning_content present" if reasoning_present else "reasoning_content missing", + data={"reasoning_length": len(reasoning_str), "reasoning_preview": reasoning_str[:200]}, + ), + "has_multiple_tools": MetricResult( + score=1.0 if has_multiple_tools else 0.0, + is_score_valid=True, + reason=f"{len(calls)} tool calls (expected 3+)" if has_multiple_tools else f"Only {len(calls)} tool calls", + data={"tool_call_count": len(calls)}, + ), + 
"all_cities_covered": MetricResult( + score=1.0 if all_cities_covered else 0.0, + is_score_valid=has_multiple_tools, + reason="All 3 cities covered" if all_cities_covered else f"Only {len(cities_covered)} cities covered", + data={"cities": list(cities_covered), "tool_calls": [{"name": n, "args": a} for n, a in calls]}, + ), + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason_tool_calls else 0.0, + is_score_valid=True, + reason="finish_reason is tool_calls" if finish_reason_tool_calls else f"Unexpected: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + reasoning_present + and has_multiple_tools + and all_cities_covered + and finish_reason_tool_calls + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Multiple tools + reasoning work together" + if all_checks_passed + else "Multiple tools or reasoning invalid", + metrics=metrics, + ) + return row + + +# Non-streaming version +MULTIPLE_TOOLS_WITH_REASONING_NON_STREAM_ROW = EvaluationRow( + messages=[ + Message(role="system", content="You are a helpful assistant with access to tools."), + Message( + role="user", + content="Get the weather for Boston, San Francisco, and Seattle (all in Fahrenheit). Think about which cities to check and explain your approach.", + ), + ], + tools=[ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + }, + }, + } + ], +) +MULTIPLE_TOOLS_WITH_REASONING_NON_STREAM_ROW.input_metadata.dataset_info = { + "test_name": "multiple_tools_with_reasoning_non_stream", + "description": "Non-streaming: multiple tool calls + reasoning", +} + + +@evaluation_test( + input_rows=[[MULTIPLE_TOOLS_WITH_REASONING_NON_STREAM_ROW]], + completion_params=[ + { + "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "stream": False, + "reasoning_effort": "low", + "temperature": 0.0, + "max_tokens": DEFAULT_MAX_TOKENS, + } + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=1.0, + mode="pointwise", +) +def test_non_streaming_multiple_tools_with_reasoning(row: EvaluationRow) -> EvaluationRow: + """Non-streaming version: Validate multiple tool calls with reasoning.""" + assistant_msg = row.last_assistant_message() + finish_reason = row.execution_metadata.finish_reason + + content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" + reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" + calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) + + reasoning_present = bool(reasoning_str) + has_multiple_tools = len(calls) >= 3 + finish_reason_tool_calls = finish_reason == "tool_calls" + + # Check that all 3 cities are covered + cities_covered = set() + for name, args in calls: + if name == "get_current_weather" and isinstance(args, dict): + location = (args.get("location") or "").lower() + if "boston" in location: + cities_covered.add("boston") + if "san 
francisco" in location or "sf" in location: + cities_covered.add("san_francisco") + if "seattle" in location: + cities_covered.add("seattle") + + all_cities_covered = len(cities_covered) == 3 + + metrics = { + "reasoning_present": MetricResult( + score=1.0 if reasoning_present else 0.0, + is_score_valid=True, + reason="reasoning_content present" if reasoning_present else "reasoning_content missing", + data={"reasoning_length": len(reasoning_str), "reasoning_preview": reasoning_str[:200]}, + ), + "has_multiple_tools": MetricResult( + score=1.0 if has_multiple_tools else 0.0, + is_score_valid=True, + reason=f"{len(calls)} tool calls (expected 3+)" if has_multiple_tools else f"Only {len(calls)} tool calls", + data={"tool_call_count": len(calls)}, + ), + "all_cities_covered": MetricResult( + score=1.0 if all_cities_covered else 0.0, + is_score_valid=has_multiple_tools, + reason="All 3 cities covered" if all_cities_covered else f"Only {len(cities_covered)} cities covered", + data={"cities": list(cities_covered), "tool_calls": [{"name": n, "args": a} for n, a in calls]}, + ), + "finish_reason_tool_calls": MetricResult( + score=1.0 if finish_reason_tool_calls else 0.0, + is_score_valid=True, + reason="finish_reason is tool_calls" if finish_reason_tool_calls else f"Unexpected: {finish_reason}", + ), + } + + finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( + metrics, finish_reason, content_str, reasoning_str + ) + + all_checks_passed = ( + reasoning_present + and has_multiple_tools + and all_cities_covered + and finish_reason_tool_calls + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage + ) + + row.evaluation_result = EvaluateResult( + score=1.0 if all_checks_passed else 0.0, + is_score_valid=True, + reason="Multiple tools + reasoning work together" + if all_checks_passed + else "Multiple tools or reasoning invalid", + metrics=metrics, + ) + return row diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index ad78046d..e9276373 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -97,10 +97,13 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: assert isinstance(response, ModelResponse), "Response should be ModelResponse" assert isinstance(response.choices[0], Choices), "Response choice should be a Choices" + assistant_message = response.choices[0].message finish_reason = getattr(response.choices[0], "finish_reason", None) - assistant_message = response.choices[0].message + # Extract content assistant_content = assistant_message.content or "" + + # Extract reasoning content (if present) reasoning_content = getattr(assistant_message, "reasoning_content", None) if reasoning_content is None: reasoning_content = getattr(assistant_message, "reasoning", None) @@ -109,6 +112,8 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: reasoning_content = json.dumps(reasoning_content) except Exception: reasoning_content = str(reasoning_content) + + # Extract tool calls tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None converted_tool_calls = None From 688e87f81c2af20a5392264964f53cf797ed61a8 Mon Sep 17 00:00:00 2001 From: Shrey Modi Date: Mon, 1 Dec 2025 12:14:17 -0800 Subject: [PATCH 13/13] yo --- .../test_glm_streaming_compliance.py | 193 ++++++------------ 1 file changed, 60 
insertions(+), 133 deletions(-) diff --git a/eval_protocol/benchmarks/test_glm_streaming_compliance.py b/eval_protocol/benchmarks/test_glm_streaming_compliance.py index b570cfa4..07af260b 100644 --- a/eval_protocol/benchmarks/test_glm_streaming_compliance.py +++ b/eval_protocol/benchmarks/test_glm_streaming_compliance.py @@ -18,7 +18,7 @@ from eval_protocol.pytest.evaluation_test import evaluation_test -DEFAULT_MODEL_ID = "fireworks_ai/accounts/fireworks/models/glm-4p6" +DEFAULT_MODEL_ID = "fireworks_ai/accounts/pyroworks/deployedModels/minimax-m2-zmi4qk9f" DEFAULT_MAX_TOKENS = 10000 @@ -153,7 +153,34 @@ def _safe_json_loads(payload: str) -> Any | None: "content": "Call test_brace_bug with param1='test_value', param2=42, and param3=true", } ], - "tools": WEATHER_TOOL_DEFINITION, + "tools": [ + { + "type": "function", + "function": { + "name": "test_brace_bug", + "description": "A test function to validate JSON brace handling in tool arguments", + "parameters": { + "type": "object", + "properties": { + "param1": { + "type": "string", + "description": "A string parameter", + }, + "param2": { + "type": "integer", + "description": "An integer parameter", + }, + "param3": { + "type": "boolean", + "description": "A boolean parameter", + }, + }, + "required": ["param1", "param2", "param3"], + "additionalProperties": False, + }, + }, + } + ], "temperature": 0.1, "top_p": 1, } @@ -468,48 +495,6 @@ def _safe_json_loads(payload: str) -> Any | None: "stream": True, } -PEER_TOOL_RECOVERY_FAILURE_PAYLOAD = { - "messages": [ - { - "role": "user", - "content": ( - "View the file at /tmp/test.txt. If that fails, try again with the correct parameters. " - "Keep retrying until it works." - ), - } - ], - "tools": [ - { - "type": "function", - "function": { - "name": "view", - "description": "View a file or directory", - "strict": True, - "parameters": { - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Path to the file or directory to view", - }, - "type": { - "type": "string", - "enum": ["file", "directory"], - "description": "Type of the path (file or directory)", - }, - }, - "required": ["path", "type"], - "additionalProperties": False, - }, - }, - } - ], - "tool_choice": "required", - "temperature": 0.1, - "max_tokens": 4000, - "stream": True, -} - def _build_row_from_payload(case: str, payload: dict[str, Any]) -> EvaluationRow: messages = [ @@ -1329,13 +1314,13 @@ def test_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow: return row -_PEER_TOOL_MISSING_REQUIRED_ROW = _build_row_from_payload( - "peer-tool-missing-required-param", PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD +_PEER_TOOL_REQUIRED_PARAMS_ROW = _build_row_from_payload( + "peer-tool-required-params", PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD ) @evaluation_test( - input_rows=[[_PEER_TOOL_MISSING_REQUIRED_ROW]], + input_rows=[[_PEER_TOOL_REQUIRED_PARAMS_ROW]], completion_params=[_build_completion_params_from_payload(PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD)], rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", @@ -1343,22 +1328,23 @@ def test_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow: num_runs=1, mode="pointwise", ) -def test_streaming_tool_missing_required_param(row: EvaluationRow) -> EvaluationRow: - """Detect whether required parameters are omitted during streaming.""" +def test_streaming_tool_required_params_present(row: EvaluationRow) -> EvaluationRow: + """Verify that tool calls include all required parameters during streaming.""" 
assistant_msg = row.last_assistant_message() finish_reason = row.execution_metadata.finish_reason - _debug_log_assistant_message("tool_missing_required_param", assistant_msg, finish_reason) + _debug_log_assistant_message("tool_required_params", assistant_msg, finish_reason) content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) - missing_required = False + required_params_present = False arguments = None for _, args in calls: if args: arguments = args - missing_required = "type" not in args or args.get("type") not in {"file", "directory"} + # Check that required 'type' param is present and valid + required_params_present = "type" in args and args.get("type") in {"file", "directory"} metrics = { "tool_call_emitted": MetricResult( @@ -1366,10 +1352,12 @@ def test_streaming_tool_missing_required_param(row: EvaluationRow) -> Evaluation is_score_valid=True, reason="Tool call emitted" if calls else "No tool call emitted", ), - "missing_required_param": MetricResult( - score=1.0 if missing_required else 0.0, + "required_params_present": MetricResult( + score=1.0 if required_params_present else 0.0, is_score_valid=bool(calls), - reason="Required parameter missing or invalid" if missing_required else "All required parameters present", + reason="All required parameters present" + if required_params_present + else "Required parameter missing or invalid", data={"arguments": arguments}, ), "finish_reason": MetricResult( @@ -1386,15 +1374,19 @@ def test_streaming_tool_missing_required_param(row: EvaluationRow) -> Evaluation ) all_checks_passed = ( - missing_required and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage + required_params_present + and finish_reason_present + and no_forbidden_tags + and no_xml_tags + and no_reasoning_leakage ) row.evaluation_result = EvaluateResult( score=1.0 if all_checks_passed else 0.0, is_score_valid=True, - reason="Detected missing required parameter" + reason="All required parameters included in tool call" if all_checks_passed - else "Required parameters satisfied or response invalid", + else "Required parameters missing or response invalid", metrics=metrics, ) return row @@ -1674,71 +1666,6 @@ def test_streaming_tool_parameter_types(row: EvaluationRow) -> EvaluationRow: return row -_PEER_TOOL_RECOVERY_ROW = _build_row_from_payload("peer-tool-recovery-failure", PEER_TOOL_RECOVERY_FAILURE_PAYLOAD) - - -@evaluation_test( - input_rows=[[_PEER_TOOL_RECOVERY_ROW]], - completion_params=[_build_completion_params_from_payload(PEER_TOOL_RECOVERY_FAILURE_PAYLOAD)], - rollout_processor=SingleTurnRolloutProcessor(), - aggregation_method="mean", - passed_threshold=0.0, - num_runs=1, - mode="pointwise", -) -def test_streaming_tool_retry_behavior(row: EvaluationRow) -> EvaluationRow: - """Check whether the assistant retries tool calls when instructed to recover.""" - - assistant_msg = row.last_assistant_message() - print(f"assistant_msg: {assistant_msg}") - finish_reason = row.execution_metadata.finish_reason - _debug_log_assistant_message("tool_recovery", assistant_msg, finish_reason) - content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else "" - calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else []) - reasoning = (assistant_msg.reasoning_content or "").strip() if assistant_msg else "" - - 
multiple_attempts = len(calls) >= 2 - metrics = { - "tool_call_attempts": MetricResult( - score=1.0 if multiple_attempts else 0.0, - is_score_valid=True, - reason="Multiple tool call attempts" if multiple_attempts else "Single/no tool call attempt", - data={"tool_call_count": len(calls)}, - ), - "reasoning_present": MetricResult( - score=1.0 if reasoning else 0.0, - is_score_valid=True, - reason="Reasoning present" if reasoning else "No reasoning provided", - data={"reasoning": reasoning[:160]}, - ), - "finish_reason": MetricResult( - score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0, - is_score_valid=True, - reason="finish_reason acceptable" - if finish_reason in {"tool_calls", "stop"} - else f"Unexpected finish_reason: {finish_reason}", - ), - } - - finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks( - metrics, finish_reason, content_str, reasoning - ) - - all_checks_passed = ( - multiple_attempts and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage - ) - - row.evaluation_result = EvaluateResult( - score=1.0 if all_checks_passed else 0.0, - is_score_valid=True, - reason="Multiple recovery attempts observed" - if all_checks_passed - else "Recovery attempts missing or response invalid", - metrics=metrics, - ) - return row - - # ============================================================================ # Reasoning Effort Tests # ============================================================================ @@ -1759,7 +1686,7 @@ def test_streaming_tool_retry_behavior(row: EvaluationRow) -> EvaluationRow: input_rows=[[REASONING_DISABLED_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model + "model": DEFAULT_MODEL_ID, # Reasoning-capable model "reasoning_effort": "none", # Explicitly disable reasoning "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.0, @@ -1869,7 +1796,7 @@ def test_reasoning_effort_none_no_reasoning(row: EvaluationRow) -> EvaluationRow input_rows=[[REASONING_ENABLED_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model + "model": DEFAULT_MODEL_ID, # Reasoning-capable model "reasoning_effort": "low", # Enable reasoning "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.0, @@ -2004,7 +1931,7 @@ def test_reasoning_effort_low_has_reasoning(row: EvaluationRow) -> EvaluationRow input_rows=[[TOOLS_WITH_REASONING_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model + "model": DEFAULT_MODEL_ID, # Reasoning-capable model "reasoning_effort": "low", # Enable reasoning "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.0, @@ -2727,7 +2654,7 @@ def test_non_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow: input_rows=[[REASONING_DISABLED_NON_STREAM_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "model": DEFAULT_MODEL_ID, "reasoning_effort": "none", "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.0, @@ -2834,7 +2761,7 @@ def test_reasoning_effort_none_no_reasoning_non_stream(row: EvaluationRow) -> Ev input_rows=[[REASONING_ENABLED_NON_STREAM_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "model": DEFAULT_MODEL_ID, "reasoning_effort": "low", "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.0, @@ -2962,7 +2889,7 @@ def 
test_reasoning_effort_low_has_reasoning_non_stream(row: EvaluationRow) -> Ev input_rows=[[TOOLS_WITH_REASONING_NON_STREAM_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "model": DEFAULT_MODEL_ID, "reasoning_effort": "low", "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.0, @@ -3108,7 +3035,7 @@ def test_non_streaming_tools_with_reasoning(row: EvaluationRow) -> EvaluationRow input_rows=[[STRUCTURED_OUTPUT_WITH_REASONING_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "model": DEFAULT_MODEL_ID, "stream": True, "reasoning_effort": "low", "response_format": STRUCTURED_JSON_SCHEMA, @@ -3211,7 +3138,7 @@ def test_streaming_structured_output_with_reasoning(row: EvaluationRow) -> Evalu input_rows=[[STRUCTURED_OUTPUT_WITH_REASONING_NON_STREAM_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "model": DEFAULT_MODEL_ID, "stream": False, "reasoning_effort": "low", "response_format": STRUCTURED_JSON_SCHEMA, @@ -3334,7 +3261,7 @@ def test_non_streaming_structured_output_with_reasoning(row: EvaluationRow) -> E input_rows=[[MULTIPLE_TOOLS_WITH_REASONING_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "model": DEFAULT_MODEL_ID, "stream": True, "reasoning_effort": "low", "temperature": 0.0, @@ -3461,7 +3388,7 @@ def test_streaming_multiple_tools_with_reasoning(row: EvaluationRow) -> Evaluati input_rows=[[MULTIPLE_TOOLS_WITH_REASONING_NON_STREAM_ROW]], completion_params=[ { - "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", + "model": DEFAULT_MODEL_ID, "stream": False, "reasoning_effort": "low", "temperature": 0.0,