Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

### Fixes
* Moved ZepbenTokenAuth to use python dataclasses instead of `zepben.ewb.dataclassy`, existing code should work as is.
* `TypeError`s occurring in `StepAction`s will no longer silently pass

### Notes
* None.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,9 @@ def copy_step_actions(self, other: Traversal[T, D]) -> D:

async def apply_step_actions(self, item: T, context: StepContext) -> D:
for it in self.step_actions:
try:
await it.apply(item, context)
except TypeError:
pass
_apply = it.apply(item, context)
if inspect.iscoroutine(_apply):
await _apply
return self

def add_context_value_computer(self, computer: ContextValueComputer[T]) -> D:
Expand Down
13 changes: 13 additions & 0 deletions test/services/network/tracing/traversal/test_step_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
import pytest
from pytest import raises

from zepben.ewb import StepAction, StepContext
Expand Down Expand Up @@ -48,3 +49,15 @@ def _apply(self, item: T, context: StepContext):
step_action.apply(expected_item, expected_ctx)

assert captured == [(expected_item, expected_ctx)]

@pytest.mark.asyncio
async def test_async_step_action(self):
captured = []

class MyStepAction(StepAction):
async def _apply(self, item: T, context: StepContext):
captured.append(item)

step_action = MyStepAction()
await step_action.apply(1, None)
assert captured == [1]
36 changes: 36 additions & 0 deletions test/services/network/tracing/traversal/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,39 @@ async def test_multiple_start_items_respect_can_stop_on_start(self):
await traversal.run(can_stop_on_start_item=False)

assert steps == [1, 11, 2, 12]

@pytest.mark.asyncio
async def test_can_use_async_step_action(self):
steps = []

class MyStepAction(StepAction):
async def _apply(self, item: T, context: StepContext):
steps.append(item)

traversal = (
_create_traversal(queue=TraversalQueue.breadth_first())
.add_stop_condition(lambda item, x: True)
.add_step_action(MyStepAction())
.add_start_item(1)
.add_start_item(11)
)
await traversal.run(can_stop_on_start_item=False)

assert steps == [1, 11, 2, 12]

@pytest.mark.asyncio
async def test_errors_in_step_action_arent_masked(self):
class MyStepAction(StepAction):
async def _apply(self, item: T, context: StepContext):
# noinspection PyTypeChecker
int(1 + "abc")

traversal = (
_create_traversal(queue=TraversalQueue.breadth_first())
.add_stop_condition(lambda item, x: True)
.add_step_action(MyStepAction())
.add_start_item(1)
.add_start_item(11)
)
with pytest.raises(TypeError):
await traversal.run(can_stop_on_start_item=False)