diff --git a/agentlightning/execution/client_server.py b/agentlightning/execution/client_server.py
index 950610b40..2b3335685 100644
--- a/agentlightning/execution/client_server.py
+++ b/agentlightning/execution/client_server.py
@@ -124,13 +124,21 @@ def __init__(
         )
         self.allowed_exit_codes = tuple(allowed_exit_codes)
 
+        # This flag is set to True after server launches and False before server stops
+        # Clients check this flag when requests fail - if server is not online, silently ignore errors
+        # Be mindful of performance: all processes need to synchronously read this flag
+        ctx = multiprocessing.get_context()
+        self._server_online = ctx.Value("b", False)  # 'b' = signed char, False = 0
+
     async def _execute_algorithm(
         self, algorithm: AlgorithmBundle, store: LightningStore, stop_evt: ExecutionEvent
     ) -> None:
         wrapper_store: LightningStore | None = None
         if self.managed_store:
             logger.info("Starting LightningStore server on %s:%s", self.server_host, self.server_port)
-            wrapper_store = LightningStoreServer(store, host=self.server_host, port=self.server_port)
+            wrapper_store = LightningStoreServer(
+                store, host=self.server_host, port=self.server_port, server_online_flag=self._server_online
+            )
             server_started = False
         else:
             wrapper_store = store
@@ -173,7 +181,9 @@ async def _execute_runner(
     ) -> None:
         if self.managed_store:
             # If managed, we actually do not use the provided store
-            client_store = LightningStoreClient(f"http://{self.server_host}:{self.server_port}")
+            client_store = LightningStoreClient(
+                f"http://{self.server_host}:{self.server_port}", server_online_flag=self._server_online
+            )
         else:
             client_store = store
         try:
diff --git a/agentlightning/store/client_server.py b/agentlightning/store/client_server.py
index b496135ee..03d33b08a 100644
--- a/agentlightning/store/client_server.py
+++ b/agentlightning/store/client_server.py
@@ -78,6 +78,15 @@
 T_model = TypeVar("T_model", bound=BaseModel)
 
 
+class ServerShutdownError(Exception):
+    """Raised when the server is shutting down and requests cannot be completed.
+
+    This exception is raised instead of ServerDisconnectedError when we detect
+    that the server is permanently unavailable (e.g., during graceful shutdown).
+    Callers should handle this gracefully without dumping full tracebacks.
+    """
+
+
 class RolloutRequest(BaseModel):
     input: TaskInput
     mode: Optional[Literal["train", "val", "test"]] = None
@@ -252,6 +261,8 @@ def __init__(
         launch_mode: LaunchMode = "thread",
         launcher_args: PythonServerLauncherArgs | None = None,
         n_workers: int = 1,
+        prometheus: bool = False,
+        server_online_flag: Any = None,
         tracker: MetricsBackend | None = None,
     ):
         super().__init__()
@@ -301,12 +312,15 @@ def __init__(
         # LightningStoreServer holds a plain Python object (self.store) in one process
         # (the process that runs uvicorn/FastAPI).
         # When you multiprocessing.Process(...) and call methods on a different LightningStore instance
-        # (or on a copy inherited via fork), you’re mutating another process’s memory, not the server’s memory.
+        # (or on a copy inherited via fork), you're mutating another process's memory, not the server's memory.
         # So we need to track the owner process (whoever creates the server),
         # and only mutate the store in that process.
         self._owner_pid = os.getpid()
         self._client: Optional[LightningStoreClient] = None
 
+        # Set to True after server launches, False before server stops
+        self._server_online_flag = server_online_flag
+
     @property
     def capabilities(self) -> LightningStoreCapabilities:
         """Return the capabilities of the store."""
@@ -416,6 +430,11 @@ async def start(self):
         end_time = time.time()
         server_logger.info(f"Lightning store server started in {end_time - start_time:.2f} seconds")
 
+        # Set server online flag to True after server has launched
+        if self._server_online_flag is not None:
+            with self._server_online_flag.get_lock():
+                self._server_online_flag.value = True
+
     async def run_forever(self):
         """Runs the FastAPI server indefinitely."""
         server_logger.info(
@@ -428,6 +447,11 @@ async def stop(self):
 
         You need to call this method in the same process as the server was created in.
         """
+        # Set server online flag to False before server stops
+        if self._server_online_flag is not None:
+            with self._server_online_flag.get_lock():
+                self._server_online_flag.value = False
+
         server_logger.info("Stopping the lightning store server...")
         await self.server_launcher.stop()
         server_logger.info("Lightning store server stopped.")
@@ -1355,6 +1379,7 @@ def __init__(
         health_retry_delays: Sequence[float] = (0.1, 0.2, 0.5),
         request_timeout: float = 30.0,
         connection_timeout: float = 5.0,
+        server_online_flag: Any = None,
     ):
         self.server_address_root = server_address.rstrip("/")
         self.server_address = self.server_address_root + API_V1_AGL_PREFIX
@@ -1371,7 +1396,9 @@ def __init__(
 
         # Store whether the dequeue was successful in history
         self._dequeue_was_successful: bool = False
-        self._dequeue_first_unsuccessful: bool = True
+
+        # When requests fail, check this flag - if server is not online, silently ignore errors
+        self._server_online_flag = server_online_flag
 
     @property
     def capabilities(self) -> LightningStoreCapabilities:
@@ -1421,7 +1448,6 @@ def __setstate__(self, state: Dict[str, Any]):
         self._request_timeout = state["_request_timeout"]
         self._connection_timeout = state["_connection_timeout"]
         self._dequeue_was_successful = False
-        self._dequeue_first_unsuccessful = True
 
     async def _get_session(self) -> aiohttp.ClientSession:
         # In the proxy process, FastAPI middleware calls
@@ -1459,6 +1485,7 @@ async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool:
         """
         Probe the server's /health until it responds 200 or retries are exhausted.
         Returns True if healthy, False otherwise.
+        When this returns False, it indicates the server is shutting down or permanently unavailable.
         """
         if not self._health_retry_delays:
             client_logger.info("No health retry delays configured; skipping health checks.")
@@ -1477,8 +1504,9 @@
             client_logger.warning(f"Server is not healthy yet. Retrying in {delay} seconds.")
             if delay > 0.0:
                 await asyncio.sleep(delay)
 
-        client_logger.error(
-            f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts"
+        client_logger.warning(
+            f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts. "
+            "Server appears to be shutting down."
         )
         return False
 
@@ -1540,10 +1568,54 @@ async def _request_json(
                 last_exc = net_exc
                 client_logger.info(f"Network/session issue: {net_exc} - will retry the request {method}: {path}")
                 if not await self._wait_until_healthy(session):
+                    # Check shared flag - if server is not online, silently ignore error
+                    if self._server_online_flag is not None:
+                        with self._server_online_flag.get_lock():
+                            is_online = bool(self._server_online_flag.value)
+                        if not is_online:
+                            client_logger.debug(
+                                f"Server is not online (shared flag). Silently ignoring {type(net_exc).__name__} for {method}: {path}"
+                            )
+                            # Silently ignore - return None to indicate failure was expected
+                            return None
                     break  # server is not healthy, do not retry
+            except asyncio.CancelledError as cancel_exc:
+                # Cancellation can occur during async operations, especially during shutdown
+                client_logger.debug(f"Request cancelled: {method}: {path}", exc_info=True)
+                # Check shared flag - if server is not online, silently ignore error
+                if self._server_online_flag is not None:
+                    with self._server_online_flag.get_lock():
+                        is_online = bool(self._server_online_flag.value)
+                    if not is_online:
+                        client_logger.debug(
+                            f"Server is not online (shared flag). Silently ignoring CancelledError for {method}: {path}"
+                        )
+                        # Silently ignore - return None to indicate failure was expected
+                        return None
+                # If flag not available or server is online, re-raise cancellation
+                raise cancel_exc
 
         # exhausted retries
         assert last_exc is not None
+        # Before raising, check shared flag - if server is not online, silently ignore error
+        if isinstance(
+            last_exc,
+            (
+                aiohttp.ServerDisconnectedError,
+                aiohttp.ClientConnectorError,
+                aiohttp.ClientOSError,
+                asyncio.TimeoutError,
+            ),
+        ):
+            if self._server_online_flag is not None:
+                with self._server_online_flag.get_lock():
+                    is_online = bool(self._server_online_flag.value)
+                if not is_online:
+                    client_logger.debug(
+                        f"Server is not online (shared flag). Silently ignoring {type(last_exc).__name__} for {method}: {path}"
+                    )
+                    # Silently ignore - return None to indicate failure was expected
+                    return None
         raise last_exc
 
     async def close(self):
@@ -1649,10 +1721,18 @@ async def _dequeue_batch(
                 self._dequeue_was_successful = True
             return [AttemptedRollout.model_validate(item) for item in data]
         except Exception as e:
+            # Check shared flag - if server is not online, silently ignore error
+            if self._server_online_flag is not None:
+                with self._server_online_flag.get_lock():
+                    is_online = bool(self._server_online_flag.value)
+                if not is_online:
+                    client_logger.debug(
+                        f"Server is not online (shared flag). Silently ignoring dequeue_rollout failure: {e}"
+                    )
+                    return None
+            # Log warning if server was online and dequeue was successful before (transition from online to offline)
             if self._dequeue_was_successful:
-                if self._dequeue_first_unsuccessful:
-                    client_logger.warning(f"dequeue_rollout failed with exception: {e}")
-                    self._dequeue_first_unsuccessful = False
+                client_logger.warning(f"dequeue_rollout failed with exception: {e}")
                 client_logger.debug("dequeue_rollout failed with exception. Details:", exc_info=True)
             # Else ignore the exception because the server is not ready yet
             return []
@@ -1916,16 +1996,23 @@ async def add_otel_span(
         readable_span: ReadableSpan,
         sequence_id: int | None = None,
     ) -> Optional[Span]:
-        # unchanged logic, now benefits from retries inside add_span/get_next_span_sequence_id
-        if sequence_id is None:
-            sequence_id = await self.get_next_span_sequence_id(rollout_id, attempt_id)
-        span = Span.from_opentelemetry(
-            readable_span,
-            rollout_id=rollout_id,
-            attempt_id=attempt_id,
-            sequence_id=sequence_id,
-        )
-        return await self.add_span(span)
+        try:
+            # unchanged logic, now benefits from retries inside add_span/get_next_span_sequence_id
+            if sequence_id is None:
+                sequence_id = await self.get_next_span_sequence_id(rollout_id, attempt_id)
+            span = Span.from_opentelemetry(
+                readable_span,
+                rollout_id=rollout_id,
+                attempt_id=attempt_id,
+                sequence_id=sequence_id,
+            )
+            return await self.add_span(span)
+        except (ServerShutdownError, asyncio.CancelledError):
+            # Server is shutting down or request was cancelled - handle gracefully without traceback
+            client_logger.debug(
+                f"Server is shutting down or request cancelled. Skipping add_otel_span for rollout {rollout_id}, attempt {attempt_id}."
+            )
+            return None
 
     async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
         """Wait for rollouts to complete.
@@ -2030,22 +2117,29 @@ async def update_attempt(
         last_heartbeat_time: float | Unset = UNSET,
         metadata: Optional[Dict[str, Any]] | Unset = UNSET,
     ) -> Attempt:
-        payload: Dict[str, Any] = {}
-        if not isinstance(status, Unset):
-            payload["status"] = status
-        if not isinstance(worker_id, Unset):
-            payload["worker_id"] = worker_id
-        if not isinstance(last_heartbeat_time, Unset):
-            payload["last_heartbeat_time"] = last_heartbeat_time
-        if not isinstance(metadata, Unset):
-            payload["metadata"] = metadata
-
-        data = await self._request_json(
-            "post",
-            f"/rollouts/{rollout_id}/attempts/{attempt_id}",
-            json=payload,
-        )
-        return Attempt.model_validate(data)
+        try:
+            payload: Dict[str, Any] = {}
+            if not isinstance(status, Unset):
+                payload["status"] = status
+            if not isinstance(worker_id, Unset):
+                payload["worker_id"] = worker_id
+            if not isinstance(last_heartbeat_time, Unset):
+                payload["last_heartbeat_time"] = last_heartbeat_time
+            if not isinstance(metadata, Unset):
+                payload["metadata"] = metadata
+
+            data = await self._request_json(
+                "post",
+                f"/rollouts/{rollout_id}/attempts/{attempt_id}",
+                json=payload,
+            )
+            return Attempt.model_validate(data)
+        except (ServerShutdownError, asyncio.CancelledError):
+            # Server is shutting down or request was cancelled - log at debug level and re-raise
+            client_logger.debug(
+                f"Server is shutting down or request cancelled. update_attempt failed for rollout {rollout_id}, attempt {attempt_id}."
+            )
+            raise
 
     async def query_workers(
         self,