diff --git a/changelog.md b/changelog.md index db99e148..fc82c498 100644 --- a/changelog.md +++ b/changelog.md @@ -11,6 +11,7 @@ ### Fixes * Moved ZepbenTokenAuth to use python dataclasses instead of `zepben.ewb.dataclassy`, existing code should work as is. +* `TypeError`s occurring in `StepAction`s will no longer silently pass ### Notes * None. diff --git a/src/zepben/ewb/services/network/tracing/traversal/traversal.py b/src/zepben/ewb/services/network/tracing/traversal/traversal.py index e2aa884d..4f6a12fb 100644 --- a/src/zepben/ewb/services/network/tracing/traversal/traversal.py +++ b/src/zepben/ewb/services/network/tracing/traversal/traversal.py @@ -390,10 +390,9 @@ def copy_step_actions(self, other: Traversal[T, D]) -> D: async def apply_step_actions(self, item: T, context: StepContext) -> D: for it in self.step_actions: - try: - await it.apply(item, context) - except TypeError: - pass + _apply = it.apply(item, context) + if inspect.iscoroutine(_apply): + await _apply return self def add_context_value_computer(self, computer: ContextValueComputer[T]) -> D: diff --git a/test/services/network/tracing/traversal/test_step_action.py b/test/services/network/tracing/traversal/test_step_action.py index f5af4909..1fe8246c 100644 --- a/test/services/network/tracing/traversal/test_step_action.py +++ b/test/services/network/tracing/traversal/test_step_action.py @@ -2,6 +2,7 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. +import pytest from pytest import raises from zepben.ewb import StepAction, StepContext @@ -48,3 +49,15 @@ def _apply(self, item: T, context: StepContext): step_action.apply(expected_item, expected_ctx) assert captured == [(expected_item, expected_ctx)] + + @pytest.mark.asyncio + async def test_async_step_action(self): + captured = [] + + class MyStepAction(StepAction): + async def _apply(self, item: T, context: StepContext): + captured.append(item) + + step_action = MyStepAction() + await step_action.apply(1, None) + assert captured == [1] diff --git a/test/services/network/tracing/traversal/test_traversal.py b/test/services/network/tracing/traversal/test_traversal.py index 32b6f120..d5ca7913 100644 --- a/test/services/network/tracing/traversal/test_traversal.py +++ b/test/services/network/tracing/traversal/test_traversal.py @@ -510,3 +510,39 @@ async def test_multiple_start_items_respect_can_stop_on_start(self): await traversal.run(can_stop_on_start_item=False) assert steps == [1, 11, 2, 12] + + @pytest.mark.asyncio + async def test_can_use_async_step_action(self): + steps = [] + + class MyStepAction(StepAction): + async def _apply(self, item: T, context: StepContext): + steps.append(item) + + traversal = ( + _create_traversal(queue=TraversalQueue.breadth_first()) + .add_stop_condition(lambda item, x: True) + .add_step_action(MyStepAction()) + .add_start_item(1) + .add_start_item(11) + ) + await traversal.run(can_stop_on_start_item=False) + + assert steps == [1, 11, 2, 12] + + @pytest.mark.asyncio + async def test_errors_in_step_action_arent_masked(self): + class MyStepAction(StepAction): + async def _apply(self, item: T, context: StepContext): + # noinspection PyTypeChecker + int(1 + "abc") + + traversal = ( + _create_traversal(queue=TraversalQueue.breadth_first()) + .add_stop_condition(lambda item, x: True) + .add_step_action(MyStepAction()) + .add_start_item(1) + .add_start_item(11) + ) + with pytest.raises(TypeError): + await traversal.run(can_stop_on_start_item=False)