Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
FIREWORKS_API_KEY="your_fireworks_api_key_here"
FIREWORKS_ACCOUNT_ID="your_fireworks_account_id_here" # e.g., "fireworks" or your specific account

# OpenAI Credentials (for using OpenAI models as judge)
OPENAI_API_KEY="your_openai_api_key_here"

# Optional: If targeting a non-production Fireworks API endpoint
# FIREWORKS_API_BASE="https://dev.api.fireworks.ai"

Expand Down
73 changes: 73 additions & 0 deletions eval_protocol/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,79 @@ def get_fireworks_api_base() -> str:
return api_base


def get_extra_headers() -> Dict[str, str]:
"""
Retrieves extra headers from the FIREWORKS_EXTRA_HEADERS environment variable.

The value should be a JSON object mapping header names to values.
Example: FIREWORKS_EXTRA_HEADERS='{"x-custom-header": "value", "x-another": "value2"}'

Returns:
Dictionary of extra headers, or empty dict if not set or invalid.
"""
import json

extra_headers_str = os.environ.get("FIREWORKS_EXTRA_HEADERS")
if not extra_headers_str:
return {}

try:
extra_headers = json.loads(extra_headers_str)
if isinstance(extra_headers, dict):
# Ensure all values are strings
return {str(k): str(v) for k, v in extra_headers.items()}
else:
logger.warning("FIREWORKS_EXTRA_HEADERS must be a JSON object, got %s", type(extra_headers).__name__)
return {}
except json.JSONDecodeError as e:
logger.warning("Failed to parse FIREWORKS_EXTRA_HEADERS as JSON: %s", e)
return {}


def get_platform_headers(
api_key: Optional[str] = None,
content_type: Optional[str] = "application/json",
include_extra_headers: bool = True,
) -> Dict[str, str]:
"""
Builds standard headers for Fireworks platform API requests.

This centralizes header construction including:
- Authorization bearer token
- Content-Type
- User-Agent
- Extra headers from FIREWORKS_EXTRA_HEADERS env var (JSON format)

Args:
api_key: The API key for authorization. If None, resolves via get_fireworks_api_key().
content_type: The Content-Type header value. Set to None to omit.
include_extra_headers: Whether to include extra headers from FIREWORKS_EXTRA_HEADERS env var.

Returns:
Dictionary of headers for platform API requests.
"""
from .common_utils import get_user_agent

resolved_api_key = api_key or get_fireworks_api_key()

headers: Dict[str, str] = {
"User-Agent": get_user_agent(),
}

if resolved_api_key:
headers["Authorization"] = f"Bearer {resolved_api_key}"

if content_type:
headers["Content-Type"] = content_type

# Include extra headers if set in environment
if include_extra_headers:
extra = get_extra_headers()
headers.update(extra)

return headers


def verify_api_key_and_get_account_id(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
Expand Down
16 changes: 4 additions & 12 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import requests
from pydantic import ValidationError

from ..auth import get_fireworks_api_base, get_fireworks_api_key
from ..auth import get_fireworks_api_base, get_fireworks_api_key, get_platform_headers
from ..common_utils import get_user_agent
from ..fireworks_rft import (
build_default_output_model,
Expand Down Expand Up @@ -175,11 +175,7 @@ def _poll_evaluator_status(
Returns:
True if evaluator becomes ACTIVE, False if timeout or BUILD_FAILED
"""
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=api_key, content_type="application/json")

check_url = f"{api_base}/v1/{evaluator_resource_name}"
timeout_seconds = timeout_minutes * 60
Expand Down Expand Up @@ -517,11 +513,7 @@ def _upload_and_ensure_evaluator(
# Optional short-circuit: if evaluator already exists and not forcing, skip upload path
if not force:
try:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=api_key, content_type="application/json")
resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10)
if resp.ok:
state = resp.json().get("state", "STATE_UNSPECIFIED")
Expand Down Expand Up @@ -702,7 +694,7 @@ def _create_rft_job(
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
if getattr(args, "evaluation_dataset", None):
body["evaluationDataset"] = args.evaluation_dataset

output_model_arg = getattr(args, "output_model", None)
if output_model_arg:
if len(output_model_arg) > 63:
Expand Down
13 changes: 3 additions & 10 deletions eval_protocol/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from eval_protocol.auth import (
get_fireworks_account_id,
get_fireworks_api_key,
get_platform_headers,
verify_api_key_and_get_account_id,
)
from eval_protocol.common_utils import get_user_agent
Expand Down Expand Up @@ -403,11 +404,7 @@ def preview(self, sample_file, max_samples=5):
account_id = "pyroworks-dev"

url = f"{api_base}/v1/accounts/{account_id}/evaluators:previewEvaluator"
headers = {
"Authorization": f"Bearer {auth_token}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=auth_token, content_type="application/json")
logger.info(f"Previewing evaluator using API endpoint: {url} with account: {account_id}")
logger.debug(f"Preview API Request URL: {url}")
logger.debug(f"Preview API Request Headers: {json.dumps(headers, indent=2)}")
Expand Down Expand Up @@ -749,11 +746,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False)
account_id = "pyroworks-dev"

base_url = f"{self.api_base}/v1/{parent}/evaluatorsV2"
headers = {
"Authorization": f"Bearer {auth_token}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=auth_token, content_type="application/json")

self._ensure_requirements_present(os.getcwd())

Expand Down
41 changes: 20 additions & 21 deletions eval_protocol/fireworks_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

import requests

from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key
from .common_utils import get_user_agent
from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key, get_platform_headers


def _map_api_host_to_app_host(api_base: str) -> str:
Expand Down Expand Up @@ -142,11 +141,17 @@ def create_dataset_from_jsonl(
display_name: Optional[str],
jsonl_path: str,
) -> Tuple[str, Dict[str, Any]]:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
import os

# DEBUG: Check environment variable
extra_headers_env = os.environ.get("FIREWORKS_EXTRA_HEADERS", "<NOT SET>")
print(f"[DEBUG] FIREWORKS_EXTRA_HEADERS env: {extra_headers_env}")
Comment on lines +146 to +148

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Stop printing extra header secrets during dataset creation

create_dataset_from_jsonl now unconditionally prints the value of FIREWORKS_EXTRA_HEADERS (and the constructed headers below) to stdout. If callers set that env var with API tokens or other sensitive headers, every dataset creation leaks those secrets into logs/terminals because there is no debug guard or masking.

Useful? React with 👍 / 👎.


headers = get_platform_headers(api_key=api_key, content_type="application/json")

# DEBUG: Print headers (mask auth token)
debug_headers = {k: (v[:20] + "..." if k == "Authorization" else v) for k, v in headers.items()}
print(f"[DEBUG] Headers being sent: {debug_headers}")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Debug print statements accidentally committed to production code

Debug print statements were left in the create_dataset_from_jsonl function. These lines print [DEBUG] prefixed messages to stdout, including the FIREWORKS_EXTRA_HEADERS environment variable value and partial authorization tokens (first 20 characters). This debugging code pollutes output for production users and potentially leaks sensitive information.

Fix in Cursor Fix in Web

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Debug print statements accidentally committed to production code

Debug print statements were left in the create_dataset_from_jsonl function. These lines print [DEBUG] prefixed messages to stdout, including the FIREWORKS_EXTRA_HEADERS environment variable value and partial authorization tokens (first 20 characters). This debugging code pollutes output for production users and potentially leaks sensitive information.

Fix in Cursor Fix in Web

# Count examples quickly
example_count = 0
with open(jsonl_path, "r", encoding="utf-8") as f:
Expand All @@ -171,10 +176,8 @@ def create_dataset_from_jsonl(
upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload"
with open(jsonl_path, "rb") as f:
files = {"file": f}
up_headers = {
"Authorization": f"Bearer {api_key}",
"User-Agent": get_user_agent(),
}
# For file uploads, omit Content-Type (let requests set multipart boundary)
up_headers = get_platform_headers(api_key=api_key, content_type=None)
up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600)
if up_resp.status_code not in (200, 201):
raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}")
Expand All @@ -196,12 +199,8 @@ def create_reinforcement_fine_tuning_job(
# Remove from body and append as query param
body.pop("jobId", None)
url = f"{url}?{urlencode({'reinforcementFineTuningJobId': job_id})}"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=api_key, content_type="application/json")
headers["Accept"] = "application/json"
resp = requests.post(url, json=body, headers=headers, timeout=60)
if resp.status_code not in (200, 201):
raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}")
Expand All @@ -217,22 +216,22 @@ def build_default_dataset_id(evaluator_id: str) -> str:
def build_default_output_model(evaluator_id: str) -> str:
base = evaluator_id.lower().replace("_", "-")
uuid_suffix = str(uuid.uuid4())[:4]

# suffix is "-rft-{4chars}" -> 9 chars
suffix_len = 9
max_len = 63

# Check if we need to truncate
if len(base) + suffix_len > max_len:
# Calculate hash of the full base to preserve uniqueness
hash_digest = hashlib.sha256(base.encode("utf-8")).hexdigest()[:6]
# New structure: {truncated_base}-{hash}-{uuid_suffix}
# Space needed for "-{hash}" is 1 + 6 = 7
hash_part_len = 7

allowed_base_len = max_len - suffix_len - hash_part_len
truncated_base = base[:allowed_base_len].strip("-")

return f"{truncated_base}-{hash_digest}-rft-{uuid_suffix}"

return f"{base}-rft-{uuid_suffix}"
Expand Down
18 changes: 4 additions & 14 deletions eval_protocol/platform_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
get_fireworks_account_id,
get_fireworks_api_base,
get_fireworks_api_key,
get_platform_headers,
)
from eval_protocol.common_utils import get_user_agent

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -93,11 +93,7 @@ def create_or_update_fireworks_secret(
logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.")
return False

headers = {
"Authorization": f"Bearer {resolved_api_key}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=resolved_api_key, content_type="application/json")

# The secret_id for GET/PATCH/DELETE operations is the key_name.
# The 'name' field in the gatewaySecret model for POST/PATCH is a bit ambiguous.
Expand Down Expand Up @@ -219,10 +215,7 @@ def get_fireworks_secret(
logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.")
return None

headers = {
"Authorization": f"Bearer {resolved_api_key}",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=resolved_api_key, content_type=None)
resource_id = _normalize_secret_resource_id(key_name)

try:
Expand Down Expand Up @@ -259,10 +252,7 @@ def delete_fireworks_secret(
logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.")
return False

headers = {
"Authorization": f"Bearer {resolved_api_key}",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=resolved_api_key, content_type=None)
resource_id = _normalize_secret_resource_id(key_name)

try:
Expand Down
14 changes: 4 additions & 10 deletions eval_protocol/pytest/handle_persist_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import re
from typing import Any

from eval_protocol.common_utils import get_user_agent
from eval_protocol.directory_utils import find_eval_protocol_dir
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.store_experiment_link import store_experiment_link
Expand All @@ -16,6 +15,7 @@
get_fireworks_account_id,
verify_api_key_and_get_account_id,
get_fireworks_api_base,
get_platform_headers,
)

import requests
Expand Down Expand Up @@ -130,11 +130,7 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name:
continue

api_base = get_fireworks_api_base()
headers = {
"Authorization": f"Bearer {fireworks_api_key}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
headers = get_platform_headers(api_key=fireworks_api_key, content_type="application/json")

# Make dataset first

Expand Down Expand Up @@ -167,10 +163,8 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name:
upload_url = f"{api_base}/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload"
with open(exp_file, "rb") as f:
files = {"file": f}
upload_headers = {
"Authorization": f"Bearer {fireworks_api_key}",
"User-Agent": get_user_agent(),
}
# For file uploads, omit Content-Type (let requests set multipart boundary)
upload_headers = get_platform_headers(api_key=fireworks_api_key, content_type=None)
upload_response = requests.post(upload_url, files=files, headers=upload_headers)

# Skip if upload failed
Expand Down
Loading