diff --git a/src/frequenz/client/base/streaming.py b/src/frequenz/client/base/streaming.py index 9ec415d..c7f2ea7 100644 --- a/src/frequenz/client/base/streaming.py +++ b/src/frequenz/client/base/streaming.py @@ -8,7 +8,7 @@ from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta -from typing import Generic, Literal, TypeAlias, TypeVar, overload +from typing import AsyncIterable, Generic, Literal, TypeAlias, TypeVar, overload import grpc.aio @@ -19,10 +19,6 @@ _logger = logging.getLogger(__name__) -RequestT = TypeVar("RequestT") -"""The request type of the stream.""" - - InputT = TypeVar("InputT") """The input type of the stream.""" @@ -80,20 +76,31 @@ class GrpcStreamBroadcaster(Generic[InputT, OutputT]): Example: ```python - from typing import Any from frequenz.client.base import ( GrpcStreamBroadcaster, StreamFatalError, StreamRetrying, StreamStarted, ) - from frequenz.channels import Receiver + from frequenz.channels import Receiver # Assuming Receiver is available + + # Dummy async iterable for demonstration + async def async_range(fail_after: int = -1) -> AsyncIterable[int]: + for i in range(10): + if fail_after != -1 and i >= fail_after: + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.UNAVAILABLE, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + details="Simulated error" + ) + yield i + await asyncio.sleep(0.1) async def main(): - stub: Any = ... # The gRPC stub streamer = GrpcStreamBroadcaster( stream_name="example_stream", - stream_method=stub.MyStreamingMethod, + stream_method=lambda: async_range(fail_after=3), transform=lambda msg: msg * 2, # transform messages retry_on_exhausted_stream=False, ) @@ -149,7 +156,7 @@ async def consume_data(): def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments self, stream_name: str, - stream_method: Callable[[], grpc.aio.UnaryStreamCall[RequestT, InputT]], + stream_method: Callable[[], AsyncIterable[InputT]], transform: Callable[[InputT], OutputT], retry_strategy: retry.Strategy | None = None, retry_on_exhausted_stream: bool = False, @@ -275,22 +282,14 @@ async def _run(self) -> None: while True: error: Exception | None = None - first_message_received = False _logger.info("%s: starting to stream", self._stream_name) try: call = self._stream_method() - # We await for the initial metadata before sending a - # StreamStarted event. This is the best indication we have of a - # successful connection without delaying it until the first - # message is received, which might happen a long time after the - # "connection" was established. - await call.initial_metadata() if self._event_sender: await self._event_sender.send(StreamStarted()) async for msg in call: - first_message_received = True try: transformed = self._transform(msg) except Exception: # pylint: disable=broad-exception-caught @@ -306,9 +305,6 @@ async def _run(self) -> None: except grpc.aio.AioRpcError as err: error = err - if first_message_received: - self._retry_strategy.reset() - if error is None and not self._retry_on_exhausted_stream: _logger.info( "%s: connection closed, stream exhausted", self._stream_name diff --git a/tests/streaming/test_grpc_stream_broadcaster.py b/tests/streaming/test_grpc_stream_broadcaster.py index d16f876..ff74b5f 100644 --- a/tests/streaming/test_grpc_stream_broadcaster.py +++ b/tests/streaming/test_grpc_stream_broadcaster.py @@ -5,7 +5,7 @@ import asyncio import logging -from collections.abc import AsyncIterator, Callable +from collections.abc import AsyncIterator from contextlib import AsyncExitStack from datetime import timedelta from unittest import mock @@ -56,18 +56,6 @@ def make_error() -> grpc.aio.AioRpcError: ) -def unary_stream_call_mock( - name: str, side_effect: Callable[[], AsyncIterator[object]] -) -> mock.MagicMock: - """Create a new mocked unary stream call.""" - # Sadly we can't use spec here because grpc.aio.UnaryStreamCall seems to be - # dynamic and mock doesn't find `__aiter__` in it when creating the spec. - call_mock = mock.MagicMock(name=name) - call_mock.__aiter__.side_effect = side_effect - call_mock.initial_metadata = mock.AsyncMock() - return call_mock - - @pytest.fixture async def ok_helper( no_retry: mock.MagicMock, # pylint: disable=redefined-outer-name @@ -83,15 +71,9 @@ async def asynciter() -> AsyncIterator[int]: yield i await asyncio.sleep(0) # Yield control to the event loop - rpc_mock = mock.MagicMock( - name="ok_helper_method", - side_effect=lambda: unary_stream_call_mock( - "ok_helper_unary_stream_call", asynciter - ), - ) helper = streaming.GrpcStreamBroadcaster( stream_name="test_helper", - stream_method=rpc_mock, + stream_method=asynciter, transform=_transformer, retry_strategy=no_retry, retry_on_exhausted_stream=retry_on_exhausted_stream, @@ -140,31 +122,6 @@ async def __anext__(self) -> int: raise self._error return self._current - async def initial_metadata(self) -> None: - """Mock initial metadata method.""" - if self._current >= self._num_successes: - raise self._error - - -def erroring_rpc_mock( - error: Exception, - ready_event: asyncio.Event, - *, - num_successes: int = 0, - should_error_on_initial_metadata_too: bool = False, -) -> mock.MagicMock: - """Fixture for mocked erroring rpc.""" - # In this case we want to keep the state of the erroring call - erroring_iter = _ErroringAsyncIter(error, ready_event, num_successes=num_successes) - call_mock = unary_stream_call_mock( - "erroring_unary_stream_call", lambda: erroring_iter - ) - if should_error_on_initial_metadata_too: - call_mock.initial_metadata.side_effect = erroring_iter.initial_metadata - rpc_mock = mock.MagicMock(name="erroring_rpc", return_value=call_mock) - - return rpc_mock - @pytest.mark.parametrize("retry_on_exhausted_stream", [True]) async def test_streaming_success_retry_on_exhausted( @@ -256,7 +213,7 @@ async def test_streaming_error( # pylint: disable=too-many-arguments helper = streaming.GrpcStreamBroadcaster( stream_name="test_helper", - stream_method=erroring_rpc_mock( + stream_method=lambda: _ErroringAsyncIter( error, receiver_ready_event, num_successes=successes ), transform=_transformer, @@ -316,9 +273,7 @@ async def asynciter() -> AsyncIterator[int]: rpc_mock = mock.MagicMock( name="ok_helper_method", - side_effect=lambda: unary_stream_call_mock( - "ok_helper_unary_stream_call", asynciter - ), + side_effect=asynciter, ) helper = streaming.GrpcStreamBroadcaster( @@ -388,7 +343,7 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments mock_retry.get_progress.return_value = "mock progress" helper = streaming.GrpcStreamBroadcaster( stream_name="test_helper", - stream_method=erroring_rpc_mock(error, receiver_ready_event), + stream_method=lambda: _ErroringAsyncIter(error, receiver_ready_event), transform=_transformer, retry_strategy=mock_retry, ) @@ -422,18 +377,10 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments ] -@pytest.mark.parametrize( - "include_events", [True, False], ids=["with_events", "without_events"] -) -@pytest.mark.parametrize( - "error_in_metadata", - [True, False], - ids=["with_initial_metadata_error", "iterator_error_only"], -) +@pytest.mark.parametrize("include_events", [True, False]) async def test_messages_on_retry( receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name include_events: bool, - error_in_metadata: bool, ) -> None: """Test that messages are sent on retry.""" # We need to use a specific instance for all the test here because 2 errors created @@ -443,11 +390,8 @@ async def test_messages_on_retry( helper = streaming.GrpcStreamBroadcaster( stream_name="test_helper", - stream_method=erroring_rpc_mock( - error, - receiver_ready_event, - num_successes=2, - should_error_on_initial_metadata_too=error_in_metadata, + stream_method=lambda: _ErroringAsyncIter( + error, receiver_ready_event, num_successes=2 ), transform=_transformer, retry_strategy=retry.LinearBackoff(limit=1, interval=0.0, jitter=0.0), @@ -466,57 +410,15 @@ async def test_messages_on_retry( assert items == [ "transformed_0", "transformed_1", + "transformed_0", + "transformed_1", ] if include_events: - extra_events: list[StreamEvent] = [] - if not error_in_metadata: - extra_events.append(StreamStarted()) assert events == [ StreamStarted(), StreamRetrying(timedelta(seconds=0.0), error), - *extra_events, + StreamStarted(), StreamFatalError(error), ] else: assert events == [] - - -@mock.patch( - "frequenz.client.base.streaming.asyncio.sleep", autospec=True, wraps=asyncio.sleep -) -async def test_retry_reset( - mock_sleep: mock.MagicMock, - receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name -) -> None: - """Test that retry strategy resets after a successful start.""" - # Use a mock retry strategy so we can assert reset() was called. - mock_retry = mock.MagicMock(spec=retry.Strategy) - # Simulate one retry interval then exhaustion. - mock_retry.next_interval.side_effect = [0.01, 0.01, None] - mock_retry.copy.return_value = mock_retry - mock_retry.get_progress.return_value = "mock progress" - - # The rpc will yield one message then raise, so the strategy should be reset - # after the successful start (i.e. after first message received). - helper = streaming.GrpcStreamBroadcaster( - stream_name="test_helper", - stream_method=erroring_rpc_mock( - make_error(), receiver_ready_event, num_successes=1 - ), - transform=_transformer, - retry_strategy=mock_retry, - retry_on_exhausted_stream=True, - ) - - async with AsyncExitStack() as stack: - stack.push_async_callback(helper.stop) - - receiver = helper.new_receiver() - receiver_ready_event.set() - _ = await _split_message(receiver) - - # reset() should have been called once after the successful start. - mock_retry.reset.assert_called_once() - - # One sleep for the single retry interval. - mock_sleep.assert_has_calls([mock.call(0.01), mock.call(0.01)])