diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 218f9d1d..7be2bb45 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -264,6 +264,7 @@ def __init__( self.project_id = project_id self.base_url = base_url.rstrip("/") self.timeout = timeout + self._session = requests.Session() def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -> List[Dict[str, Any]]: """Fetch logs from Fireworks tracing gateway /logs endpoint. @@ -287,14 +288,14 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) - last_error: Optional[str] = None for url in urls_to_try: try: - response = requests.get(url, params=params, timeout=self.timeout, headers=headers) - if response.status_code == 404: - # Try next variant - last_error = f"404 for {url}" - continue - response.raise_for_status() - data = response.json() or {} - break + with self._session.get(url, params=params, timeout=self.timeout, headers=headers) as response: + if response.status_code == 404: + # Try next variant (must close response to release connection) + last_error = f"404 for {url}" + continue + response.raise_for_status() + data = response.json() or {} + break except requests.exceptions.RequestException as e: last_error = str(e) continue @@ -412,22 +413,20 @@ def get_evaluation_rows( result = None try: - response = requests.get(url, params=params, timeout=self.timeout, headers=headers) - response.raise_for_status() - result = response.json() - except requests.exceptions.HTTPError as e: - error_msg = str(e) - - # Try to extract detail message from response - if e.response is not None: - try: - error_detail = e.response.json().get("detail", {}) - error_msg = error_detail or e.response.text - except Exception: # In case e.response.json() fails - error_msg = f"Proxy error: {e.response.text}" - - logger.error("Failed to fetch traces from proxy (HTTP %s): %s", e.response.status_code, error_msg) - return eval_rows + with self._session.get(url, params=params, timeout=self.timeout, headers=headers) as response: + if response.status_code >= 400: + error_msg: str = response.text + try: + payload = response.json() + if isinstance(payload, dict) and "detail" in payload: + detail = payload.get("detail") + if detail: + error_msg = str(detail) + except Exception: + pass + logger.error("Failed to fetch traces from proxy (HTTP %s): %s", response.status_code, error_msg) + return eval_rows + result = response.json() except requests.exceptions.RequestException as e: # Non-HTTP errors (network issues, timeouts, etc.) logger.error("Failed to fetch traces from proxy: %s", str(e)) @@ -451,3 +450,10 @@ def get_evaluation_rows( logger.info("Successfully converted %d traces to evaluation rows", len(eval_rows)) return eval_rows + + def close(self) -> None: + """Close underlying HTTP resources.""" + try: + self._session.close() + except Exception: + pass diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index ab42bdcd..818008ae 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -54,6 +54,7 @@ def __init__( self._timeout_seconds = timeout_seconds self._output_data_loader = output_data_loader or default_fireworks_output_data_loader self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url) + self._session = requests.Session() def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: tasks: List[asyncio.Task[EvaluationRow]] = [] @@ -94,8 +95,8 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: def _post_init() -> None: url = f"{remote_base_url}/init" try: - r = requests.post(url, json=init_payload.model_dump(), timeout=300) - r.raise_for_status() + with self._session.post(url, json=init_payload.model_dump(), timeout=300) as r: + r.raise_for_status() except requests.exceptions.Timeout: raise TimeoutError( f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds." @@ -108,9 +109,9 @@ def _post_init() -> None: def _get_status() -> Dict[str, Any]: url = f"{remote_base_url}/status" - r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15) - r.raise_for_status() - return r.json() + with self._session.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15) as r: + r.raise_for_status() + return r.json() continue_polling_status = True while time.time() < deadline: @@ -204,4 +205,12 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: return tasks def cleanup(self) -> None: + try: + self._tracing_adapter.close() + except Exception: + pass + try: + self._session.close() + except Exception: + pass return None