Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(
self.project_id = project_id
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self._session = requests.Session()

def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -> List[Dict[str, Any]]:
"""Fetch logs from Fireworks tracing gateway /logs endpoint.
Expand All @@ -287,14 +288,14 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
last_error: Optional[str] = None
for url in urls_to_try:
try:
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
if response.status_code == 404:
# Try next variant
last_error = f"404 for {url}"
continue
response.raise_for_status()
data = response.json() or {}
break
with self._session.get(url, params=params, timeout=self.timeout, headers=headers) as response:
if response.status_code == 404:
# Try next variant (must close response to release connection)
last_error = f"404 for {url}"
continue
response.raise_for_status()
data = response.json() or {}
break
except requests.exceptions.RequestException as e:
last_error = str(e)
continue
Expand Down Expand Up @@ -412,22 +413,20 @@ def get_evaluation_rows(

result = None
try:
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
response.raise_for_status()
result = response.json()
except requests.exceptions.HTTPError as e:
error_msg = str(e)

# Try to extract detail message from response
if e.response is not None:
try:
error_detail = e.response.json().get("detail", {})
error_msg = error_detail or e.response.text
except Exception: # In case e.response.json() fails
error_msg = f"Proxy error: {e.response.text}"

logger.error("Failed to fetch traces from proxy (HTTP %s): %s", e.response.status_code, error_msg)
return eval_rows
with self._session.get(url, params=params, timeout=self.timeout, headers=headers) as response:
if response.status_code >= 400:
error_msg: str = response.text
try:
payload = response.json()
if isinstance(payload, dict) and "detail" in payload:
detail = payload.get("detail")
if detail:
error_msg = str(detail)
except Exception:
pass
logger.error("Failed to fetch traces from proxy (HTTP %s): %s", response.status_code, error_msg)
return eval_rows
result = response.json()
except requests.exceptions.RequestException as e:
# Non-HTTP errors (network issues, timeouts, etc.)
logger.error("Failed to fetch traces from proxy: %s", str(e))
Expand All @@ -451,3 +450,10 @@ def get_evaluation_rows(

logger.info("Successfully converted %d traces to evaluation rows", len(eval_rows))
return eval_rows

def close(self) -> None:
"""Close underlying HTTP resources."""
try:
self._session.close()
except Exception:
pass
19 changes: 14 additions & 5 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
self._timeout_seconds = timeout_seconds
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
self._session = requests.Session()

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
tasks: List[asyncio.Task[EvaluationRow]] = []
Expand Down Expand Up @@ -94,8 +95,8 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
def _post_init() -> None:
url = f"{remote_base_url}/init"
try:
r = requests.post(url, json=init_payload.model_dump(), timeout=300)
r.raise_for_status()
with self._session.post(url, json=init_payload.model_dump(), timeout=300) as r:
r.raise_for_status()
except requests.exceptions.Timeout:
raise TimeoutError(
f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds."
Expand All @@ -108,9 +109,9 @@ def _post_init() -> None:

def _get_status() -> Dict[str, Any]:
url = f"{remote_base_url}/status"
r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
r.raise_for_status()
return r.json()
with self._session.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15) as r:
r.raise_for_status()
return r.json()

continue_polling_status = True
while time.time() < deadline:
Expand Down Expand Up @@ -204,4 +205,12 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
return tasks

def cleanup(self) -> None:
try:
self._tracing_adapter.close()
except Exception:
pass
try:
self._session.close()
except Exception:
pass
return None
Loading