diff --git a/docs/source/guides/explanations/sigproc.rst b/docs/source/guides/explanations/sigproc.rst index e4dd5f9e..f052e874 100644 --- a/docs/source/guides/explanations/sigproc.rst +++ b/docs/source/guides/explanations/sigproc.rst @@ -329,7 +329,7 @@ Often, all that is required is the following (e.g., for a custom transformer): import ezmsg.core as ez from ezmsg.util.messages.axisarray import AxisArray - from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit + from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit class CustomTransformerSettings(ez.Settings): diff --git a/docs/source/guides/tutorials/signalprocessing.rst b/docs/source/guides/tutorials/signalprocessing.rst index 0156b995..2d410be3 100644 --- a/docs/source/guides/tutorials/signalprocessing.rst +++ b/docs/source/guides/tutorials/signalprocessing.rst @@ -132,7 +132,7 @@ Add the following import statements to the top of the `downsample.py` file: ) import ezmsg.core as ez - from ezmsg.sigproc.base import ( + from ezmsg.baseproc import ( BaseStatefulTransformer, BaseTransformerUnit, processor_state, diff --git a/src/ezmsg/sigproc/affinetransform.py b/src/ezmsg/sigproc/affinetransform.py index 7035a96a..624bfd94 100644 --- a/src/ezmsg/sigproc/affinetransform.py +++ b/src/ezmsg/sigproc/affinetransform.py @@ -1,3 +1,13 @@ +"""Affine transformations via matrix multiplication: y = Ax or y = Ax + B. + +For full matrix transformations where channels are mixed (off-diagonal weights), +use :obj:`AffineTransformTransformer` or the `AffineTransform` unit. + +For simple per-channel scaling and offset (diagonal weights only), use +:obj:`LinearTransformTransformer` from :mod:`ezmsg.sigproc.linear` instead, +which is more efficient as it avoids matrix multiplication. +""" + import os from pathlib import Path @@ -17,7 +27,6 @@ class AffineTransformSettings(ez.Settings): """ Settings for :obj:`AffineTransform`. - See :obj:`affine_transform` for argument details. """ weights: np.ndarray | str | Path @@ -39,6 +48,19 @@ class AffineTransformState: class AffineTransformTransformer( BaseStatefulTransformer[AffineTransformSettings, AxisArray, AxisArray, AffineTransformState] ): + """Apply affine transformation via matrix multiplication: y = Ax or y = Ax + B. + + Use this transformer when you need full matrix transformations that mix + channels (off-diagonal weights), such as spatial filters or projections. + + For simple per-channel scaling and offset where each output channel depends + only on its corresponding input channel (diagonal weight matrix), use + :obj:`LinearTransformTransformer` instead, which is more efficient. + + The weights matrix can include an offset row (stacked as [A|B]) where the + input is automatically augmented with a column of ones to compute y = Ax + B. + """ + def __call__(self, message: AxisArray) -> AxisArray: # Override __call__ so we can shortcut if weights are None. if self.settings.weights is None or ( diff --git a/src/ezmsg/sigproc/butterworthzerophase.py b/src/ezmsg/sigproc/butterworthzerophase.py index 22ebe681..1bd2fbff 100644 --- a/src/ezmsg/sigproc/butterworthzerophase.py +++ b/src/ezmsg/sigproc/butterworthzerophase.py @@ -4,10 +4,10 @@ import ezmsg.core as ez import numpy as np import scipy.signal +from ezmsg.baseproc import SettingsType from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.messages.util import replace -from ezmsg.sigproc.base import SettingsType from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun from ezmsg.sigproc.filter import ( BACoeffs, diff --git a/src/ezmsg/sigproc/denormalize.py b/src/ezmsg/sigproc/denormalize.py index 0e7f94ea..9e7050c5 100644 --- a/src/ezmsg/sigproc/denormalize.py +++ b/src/ezmsg/sigproc/denormalize.py @@ -1,14 +1,13 @@ import ezmsg.core as ez import numpy as np import numpy.typing as npt -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.messages.util import replace - -from ezmsg.sigproc.base import ( +from ezmsg.baseproc import ( BaseStatefulTransformer, BaseTransformerUnit, processor_state, ) +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace class DenormalizeSettings(ez.Settings): diff --git a/src/ezmsg/sigproc/detrend.py b/src/ezmsg/sigproc/detrend.py index 0b405b53..c6091273 100644 --- a/src/ezmsg/sigproc/detrend.py +++ b/src/ezmsg/sigproc/detrend.py @@ -1,7 +1,7 @@ import scipy.signal as sps +from ezmsg.baseproc import BaseTransformerUnit from ezmsg.util.messages.axisarray import AxisArray, replace -from ezmsg.sigproc.base import BaseTransformerUnit from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer diff --git a/src/ezmsg/sigproc/diff.py b/src/ezmsg/sigproc/diff.py index 90136d0f..15aaffdc 100644 --- a/src/ezmsg/sigproc/diff.py +++ b/src/ezmsg/sigproc/diff.py @@ -1,14 +1,13 @@ import ezmsg.core as ez import numpy as np import numpy.typing as npt -from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis -from ezmsg.util.messages.util import replace - -from ezmsg.sigproc.base import ( +from ezmsg.baseproc import ( BaseStatefulTransformer, BaseTransformerUnit, processor_state, ) +from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis +from ezmsg.util.messages.util import replace class DiffSettings(ez.Settings): diff --git a/src/ezmsg/sigproc/extract_axis.py b/src/ezmsg/sigproc/extract_axis.py index 220e7467..8af74165 100644 --- a/src/ezmsg/sigproc/extract_axis.py +++ b/src/ezmsg/sigproc/extract_axis.py @@ -1,9 +1,8 @@ import ezmsg.core as ez import numpy as np +from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit from ezmsg.util.messages.axisarray import AxisArray, replace -from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit - class ExtractAxisSettings(ez.Settings): axis: str = "freq" diff --git a/src/ezmsg/sigproc/filter.py b/src/ezmsg/sigproc/filter.py index 989d37cd..2ef2340e 100644 --- a/src/ezmsg/sigproc/filter.py +++ b/src/ezmsg/sigproc/filter.py @@ -6,10 +6,7 @@ import numpy as np import numpy.typing as npt import scipy.signal -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.messages.util import replace - -from ezmsg.sigproc.base import ( +from ezmsg.baseproc import ( BaseConsumerUnit, BaseStatefulTransformer, BaseTransformerUnit, @@ -17,6 +14,8 @@ TransformerType, processor_state, ) +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace @dataclass diff --git a/src/ezmsg/sigproc/fir_hilbert.py b/src/ezmsg/sigproc/fir_hilbert.py index 41f37a6b..8cdd3c5d 100644 --- a/src/ezmsg/sigproc/fir_hilbert.py +++ b/src/ezmsg/sigproc/fir_hilbert.py @@ -4,10 +4,10 @@ import ezmsg.core as ez import numpy as np import scipy.signal as sps +from ezmsg.baseproc import BaseStatefulTransformer, processor_state from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.messages.util import replace -from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state from ezmsg.sigproc.filter import ( BACoeffs, BaseFilterByDesignTransformerUnit, diff --git a/src/ezmsg/sigproc/linear.py b/src/ezmsg/sigproc/linear.py new file mode 100644 index 00000000..4612aa72 --- /dev/null +++ b/src/ezmsg/sigproc/linear.py @@ -0,0 +1,118 @@ +"""Apply a linear transformation: output = scale * input + offset. + +Supports per-element scale and offset along a specified axis. +Uses Array API for compatibility with numpy, cupy, torch, etc. + +For full matrix transformations, use :obj:`AffineTransformTransformer` instead. +""" + +import ezmsg.core as ez +import numpy as np +import numpy.typing as npt +from array_api_compat import get_namespace +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + +from .base import BaseStatefulTransformer, BaseTransformerUnit, processor_state + + +class LinearTransformSettings(ez.Settings): + scale: float | list[float] | npt.ArrayLike = 1.0 + """Scale factor(s). Can be a scalar (applied to all elements) or an array + matching the size of the specified axis for per-element scaling.""" + + offset: float | list[float] | npt.ArrayLike = 0.0 + """Offset value(s). Can be a scalar (applied to all elements) or an array + matching the size of the specified axis for per-element offset.""" + + axis: str | None = None + """Axis along which to apply per-element scale/offset. If None, scalar + scale/offset are broadcast to all elements.""" + + +@processor_state +class LinearTransformState: + scale: npt.NDArray = None + """Prepared scale array for broadcasting.""" + + offset: npt.NDArray = None + """Prepared offset array for broadcasting.""" + + +class LinearTransformTransformer( + BaseStatefulTransformer[LinearTransformSettings, AxisArray, AxisArray, LinearTransformState] +): + """Apply linear transformation: output = scale * input + offset. + + This transformer is optimized for element-wise linear operations with + optional per-channel (or per-axis) coefficients. For full matrix + transformations, use :obj:`AffineTransformTransformer` instead. + + Examples: + # Uniform scaling and offset + >>> transformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0)) + + # Per-channel scaling (e.g., for 3-channel data along "ch" axis) + >>> transformer = LinearTransformTransformer(LinearTransformSettings( + ... scale=[0.5, 1.0, 2.0], + ... offset=[0.0, 0.1, 0.2], + ... axis="ch" + ... )) + """ + + def _hash_message(self, message: AxisArray) -> int: + """Hash based on shape and axis to detect when broadcast shapes need recalculation.""" + axis = self.settings.axis + if axis is not None: + axis_idx = message.get_axis_idx(axis) + return hash((message.data.ndim, axis_idx, message.data.shape[axis_idx])) + return hash(message.data.ndim) + + def _reset_state(self, message: AxisArray) -> None: + """Prepare scale/offset arrays with proper broadcast shapes.""" + xp = get_namespace(message.data) + ndim = message.data.ndim + + scale = self.settings.scale + offset = self.settings.offset + + # Convert settings to arrays + if isinstance(scale, (list, np.ndarray)): + scale = xp.asarray(scale, dtype=xp.float64) + else: + # Scalar: create a 0-d array + scale = xp.asarray(float(scale), dtype=xp.float64) + + if isinstance(offset, (list, np.ndarray)): + offset = xp.asarray(offset, dtype=xp.float64) + else: + # Scalar: create a 0-d array + offset = xp.asarray(float(offset), dtype=xp.float64) + + # If axis is specified and we have 1-d arrays, reshape for proper broadcasting + if self.settings.axis is not None and ndim > 0: + axis_idx = message.get_axis_idx(self.settings.axis) + + if scale.ndim == 1: + # Create shape for broadcasting: all 1s except at axis_idx + broadcast_shape = [1] * ndim + broadcast_shape[axis_idx] = scale.shape[0] + scale = xp.reshape(scale, broadcast_shape) + + if offset.ndim == 1: + broadcast_shape = [1] * ndim + broadcast_shape[axis_idx] = offset.shape[0] + offset = xp.reshape(offset, broadcast_shape) + + self._state.scale = scale + self._state.offset = offset + + def _process(self, message: AxisArray) -> AxisArray: + result = message.data * self._state.scale + self._state.offset + return replace(message, data=result) + + +class LinearTransform(BaseTransformerUnit[LinearTransformSettings, AxisArray, AxisArray, LinearTransformTransformer]): + """Unit wrapper for LinearTransformTransformer.""" + + SETTINGS = LinearTransformSettings diff --git a/src/ezmsg/sigproc/rollingscaler.py b/src/ezmsg/sigproc/rollingscaler.py index 477649e9..a4390c2c 100644 --- a/src/ezmsg/sigproc/rollingscaler.py +++ b/src/ezmsg/sigproc/rollingscaler.py @@ -3,14 +3,14 @@ import ezmsg.core as ez import numpy as np import numpy.typing as npt -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.messages.util import replace - -from ezmsg.sigproc.base import ( +from ezmsg.baseproc import ( BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit, processor_state, ) +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + from ezmsg.sigproc.sampler import SampleMessage diff --git a/tests/unit/test_activation.py b/tests/unit/test_activation.py index 0d22a3c8..06cd77d6 100644 --- a/tests/unit/test_activation.py +++ b/tests/unit/test_activation.py @@ -4,7 +4,12 @@ from ezmsg.util.messages.axisarray import AxisArray from frozendict import frozendict -from ezmsg.sigproc.activation import ACTIVATIONS, ActivationFunction, activation +from ezmsg.sigproc.activation import ( + ACTIVATIONS, + ActivationFunction, + ActivationSettings, + ActivationTransformer, +) @pytest.mark.parametrize("function", [_ for _ in ActivationFunction] + ActivationFunction.options()) @@ -25,8 +30,8 @@ def msg_generator(): ) yield msg - proc = activation(function=function) - out_msgs = [proc.send(_) for _ in msg_generator()] + proc = ActivationTransformer(ActivationSettings(function=function)) + out_msgs = [proc(_) for _ in msg_generator()] out_dat = AxisArray.concatenate(*out_msgs, dim="time").data if function in ACTIVATIONS: diff --git a/tests/unit/test_affine_transform.py b/tests/unit/test_affine_transform.py index 68a5cff3..20355dc2 100644 --- a/tests/unit/test_affine_transform.py +++ b/tests/unit/test_affine_transform.py @@ -4,11 +4,16 @@ import numpy as np from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.sigproc.affinetransform import affine_transform, common_rereference +from ezmsg.sigproc.affinetransform import ( + AffineTransformSettings, + AffineTransformTransformer, + CommonRereferenceSettings, + CommonRereferenceTransformer, +) from tests.helpers.util import assert_messages_equal -def test_affine_generator(): +def test_affine_transform(): n_times = 13 n_chans = 64 in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) @@ -20,16 +25,16 @@ def test_affine_generator(): backup = [copy.deepcopy(msg_in)] - gen = affine_transform(weights=np.eye(n_chans), axis="ch") - msg_out = gen.send(msg_in) + xformer = AffineTransformTransformer(AffineTransformSettings(weights=np.eye(n_chans), axis="ch")) + msg_out = xformer(msg_in) assert msg_out.data.shape == in_dat.shape assert np.allclose(msg_out.data, in_dat) assert not np.may_share_memory(msg_out.data, in_dat) assert_messages_equal([msg_in], backup) - # Send again just to make sure the generator doesn't crash - _ = gen.send(msg_in) + # Call again just to make sure the transformer doesn't crash + _ = xformer(msg_in) # Test with weights from a CSV file. csv_path = Path(__file__).parents[1] / "resources" / "xform.csv" @@ -37,27 +42,29 @@ def test_affine_generator(): expected_out = in_dat @ weights.T # Same result: expected_out = np.vstack([(step[None, :] * weights).sum(axis=1) for step in in_dat]) - gen = affine_transform(weights=csv_path, axis="ch", right_multiply=False) - msg_out = gen.send(msg_in) + xformer = AffineTransformTransformer(AffineTransformSettings(weights=csv_path, axis="ch", right_multiply=False)) + msg_out = xformer(msg_in) assert np.allclose(msg_out.data, expected_out) assert len(msg_out.axes["ch"].data) == weights.shape[0] assert (msg_out.axes["ch"].data[:-1] == msg_in.axes["ch"].data).all() # Try again as str, not Path - gen = affine_transform(weights=str(csv_path), axis="ch", right_multiply=False) - msg_out = gen.send(msg_in) + xformer = AffineTransformTransformer( + AffineTransformSettings(weights=str(csv_path), axis="ch", right_multiply=False) + ) + msg_out = xformer(msg_in) assert np.allclose(msg_out.data, expected_out) assert len(msg_out.axes["ch"].data) == weights.shape[0] # Try again as direct ndarray - gen = affine_transform(weights=weights, axis="ch", right_multiply=False) - msg_out = gen.send(msg_in) + xformer = AffineTransformTransformer(AffineTransformSettings(weights=weights, axis="ch", right_multiply=False)) + msg_out = xformer(msg_in) assert np.allclose(msg_out.data, expected_out) assert len(msg_out.axes["ch"].data) == weights.shape[0] # One more time, but we pre-transpose the weights and do not override right_multiply - gen = affine_transform(weights=weights.T, axis="ch", right_multiply=True) - msg_out = gen.send(msg_in) + xformer = AffineTransformTransformer(AffineTransformSettings(weights=weights.T, axis="ch", right_multiply=True)) + msg_out = xformer(msg_in) assert np.allclose(msg_out.data, expected_out) assert len(msg_out.axes["ch"].data) == weights.shape[0] @@ -70,9 +77,9 @@ def test_affine_passthrough(): backup = [copy.deepcopy(msg_in)] - gen = affine_transform(weights="passthrough", axis="does not matter") - msg_out = gen.send(msg_in) - # We wouldn't want out_data is in_dat ezmsg pipeline but it's fine for the generator + xformer = AffineTransformTransformer(AffineTransformSettings(weights="passthrough", axis="does not matter")) + msg_out = xformer(msg_in) + # We wouldn't want out_data is in_dat ezmsg pipeline but it's fine for the transformer assert msg_out.data is in_dat assert_messages_equal([msg_out], backup) @@ -85,8 +92,8 @@ def test_common_rereference(): backup = [copy.deepcopy(msg_in)] - gen = common_rereference(mode="mean", axis="ch", include_current=True) - msg_out = gen.send(msg_in) + xformer = CommonRereferenceTransformer(CommonRereferenceSettings(mode="mean", axis="ch", include_current=True)) + msg_out = xformer(msg_in) assert np.array_equal( msg_out.data, msg_in.data - np.mean(msg_in.data, axis=1, keepdims=True), @@ -103,17 +110,17 @@ def test_common_rereference(): expected_out.append(msg_in.data[..., ch_ix] - np.mean(msg_in.data[..., idx], axis=1)) expected_out = np.stack(expected_out).T - gen = common_rereference(mode="mean", axis="ch", include_current=False) - msg_out = gen.send(msg_in) # 41 us + xformer = CommonRereferenceTransformer(CommonRereferenceSettings(mode="mean", axis="ch", include_current=False)) + msg_out = xformer(msg_in) # 41 us assert np.allclose(msg_out.data, expected_out) - # Instead of CAR, we could use affine_transform with weights that reproduce CAR. + # Instead of CAR, we could use AffineTransformTransformer with weights that reproduce CAR. # However, this method is 30x slower than above. (Actual difference varies depending on data shape). if False: weights = -np.ones((n_chans, n_chans)) / (n_chans - 1) np.fill_diagonal(weights, 1) - gen = affine_transform(weights=weights, axis="ch") - msg_out = gen.send(msg_in) + xformer = AffineTransformTransformer(AffineTransformSettings(weights=weights, axis="ch")) + msg_out = xformer(msg_in) assert np.allclose(msg_out.data, expected_out) @@ -123,7 +130,7 @@ def test_car_passthrough(): in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) msg_in = AxisArray(in_dat, dims=["time", "ch"]) - gen = common_rereference(mode="passthrough") - msg_out = gen.send(msg_in) + xformer = CommonRereferenceTransformer(CommonRereferenceSettings(mode="passthrough")) + msg_out = xformer(msg_in) assert np.array_equal(msg_out.data, in_dat) assert np.may_share_memory(msg_out.data, in_dat) diff --git a/tests/unit/test_aggregate.py b/tests/unit/test_aggregate.py index 9b59e77e..3647b9b3 100644 --- a/tests/unit/test_aggregate.py +++ b/tests/unit/test_aggregate.py @@ -10,7 +10,8 @@ AggregateSettings, AggregateTransformer, AggregationFunction, - ranged_aggregate, + RangedAggregateSettings, + RangedAggregateTransformer, ) from tests.helpers.util import assert_messages_equal @@ -61,8 +62,8 @@ def test_aggregate(agg_func: AggregationFunction): backup = [copy.deepcopy(_) for _ in in_msgs] - gen = ranged_aggregate(axis=targ_ax, bands=bands, operation=agg_func) - out_msgs = [gen.send(_) for _ in in_msgs] + xformer = RangedAggregateTransformer(RangedAggregateSettings(axis=targ_ax, bands=bands, operation=agg_func)) + out_msgs = [xformer(_) for _ in in_msgs] assert_messages_equal(in_msgs, backup) @@ -96,8 +97,8 @@ def test_aggregate(agg_func: AggregationFunction): def test_arg_aggregate(agg_func: AggregationFunction): bands = [(5.0, 20.0), (30.0, 50.0)] in_msgs = [_ for _ in get_msg_gen()] - gen = ranged_aggregate(axis="freq", bands=bands, operation=agg_func) - out_msgs = [gen.send(_) for _ in in_msgs] + xformer = RangedAggregateTransformer(RangedAggregateSettings(axis="freq", bands=bands, operation=agg_func)) + out_msgs = [xformer(_) for _ in in_msgs] if agg_func == AggregationFunction.ARGMIN: expected_vals = np.array([np.min(_) for _ in bands]) @@ -111,8 +112,10 @@ def test_arg_aggregate(agg_func: AggregationFunction): def test_trapezoid(): bands = [(5.0, 20.0), (30.0, 50.0)] in_msgs = [_ for _ in get_msg_gen()] - gen = ranged_aggregate(axis="freq", bands=bands, operation=AggregationFunction.TRAPEZOID) - out_msgs = [gen.send(_) for _ in in_msgs] + xformer = RangedAggregateTransformer( + RangedAggregateSettings(axis="freq", bands=bands, operation=AggregationFunction.TRAPEZOID) + ) + out_msgs = [xformer(_) for _ in in_msgs] out_dat = AxisArray.concatenate(*out_msgs, dim="time").data @@ -145,15 +148,17 @@ def test_aggregate_handle_change(change_ax: str): ) ] - gen = ranged_aggregate( - axis="freq", - bands=[(5.0, 20.0), (30.0, 50.0)], - operation=AggregationFunction.MEAN, + xformer = RangedAggregateTransformer( + RangedAggregateSettings( + axis="freq", + bands=[(5.0, 20.0), (30.0, 50.0)], + operation=AggregationFunction.MEAN, + ) ) - out_msgs1 = [gen.send(_) for _ in in_msgs1] + out_msgs1 = [xformer(_) for _ in in_msgs1] print(len(out_msgs1)) - out_msgs2 = [gen.send(_) for _ in in_msgs2] + out_msgs2 = [xformer(_) for _ in in_msgs2] print(len(out_msgs2)) diff --git a/tests/unit/test_bandpower.py b/tests/unit/test_bandpower.py index 4dc2a7a6..227d71ad 100644 --- a/tests/unit/test_bandpower.py +++ b/tests/unit/test_bandpower.py @@ -3,7 +3,12 @@ import numpy as np from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.sigproc.bandpower import AggregationFunction, SpectrogramSettings, bandpower +from ezmsg.sigproc.bandpower import ( + AggregationFunction, + BandPowerSettings, + BandPowerTransformer, + SpectrogramSettings, +) from tests.helpers.util import ( assert_messages_equal, create_messages_with_periodic_signal, @@ -39,15 +44,17 @@ def test_bandpower(): # while being processed. backup = [copy.deepcopy(_) for _ in messages] - gen = bandpower( - spectrogram_settings=SpectrogramSettings( - window_dur=win_dur, - window_shift=0.1, - ), - bands=bands, - aggregation=AggregationFunction.MEAN, + xformer = BandPowerTransformer( + BandPowerSettings( + spectrogram_settings=SpectrogramSettings( + window_dur=win_dur, + window_shift=0.1, + ), + bands=bands, + aggregation=AggregationFunction.MEAN, + ) ) - results = [gen.send(_) for _ in messages] + results = [xformer(_) for _ in messages] assert_messages_equal(messages, backup) diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py index ce8b9965..b8562889 100644 --- a/tests/unit/test_base.py +++ b/tests/unit/test_base.py @@ -5,9 +5,7 @@ from unittest.mock import MagicMock import pytest -from ezmsg.util.generator import consumer - -from ezmsg.sigproc.base import ( +from ezmsg.baseproc import ( BaseAdaptiveTransformer, BaseAsyncTransformer, BaseConsumer, @@ -28,6 +26,8 @@ _get_processor_message_type, processor_state, ) +from ezmsg.util.generator import consumer + from ezmsg.sigproc.cheby import ChebyshevFilterTransformer from ezmsg.sigproc.filter import FilterByDesignState diff --git a/tests/unit/test_butter.py b/tests/unit/test_butter.py index 49fa8aa8..cbc6d3c9 100644 --- a/tests/unit/test_butter.py +++ b/tests/unit/test_butter.py @@ -4,10 +4,7 @@ from ezmsg.util.messages.axisarray import AxisArray from frozendict import frozendict -from ezmsg.sigproc.butterworthfilter import ( - ButterworthFilterSettings as LegacyButterSettings, -) -from ezmsg.sigproc.butterworthfilter import butter +from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, ButterworthFilterTransformer @pytest.mark.parametrize( @@ -31,7 +28,7 @@ def test_butterworth_legacy_filter_settings(cutoff: float, cuton: float, order: If cuton is larger than cutoff we assume bandstop. order (int): The order of the filter. """ - btype, Wn = LegacyButterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs() + btype, Wn = ButterworthFilterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs() if cuton is None: assert btype == "lowpass" assert Wn == cutoff @@ -92,7 +89,7 @@ def test_butterworth( in_dat = np.arange(np.prod(dat_shape), dtype=float).reshape(*dat_shape) # Calculate Expected Result - btype, Wn = LegacyButterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs() + btype, Wn = ButterworthFilterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs() coefs = scipy.signal.butter(order, Wn, btype=btype, output=coef_type, fs=fs) tmp_dat = np.moveaxis(in_dat, time_ax, -1) if coef_type == "ba": @@ -127,25 +124,29 @@ def test_butterworth( # Test axis_name `None` when target axis idx is 0. axis_name = "time" if time_ax != 0 else None - gen = butter( - axis=axis_name, - order=order, - cuton=cuton, - cutoff=cutoff, - coef_type=coef_type, + xformer = ButterworthFilterTransformer( + ButterworthFilterSettings( + axis=axis_name, + order=order, + cuton=cuton, + cutoff=cutoff, + coef_type=coef_type, + ) ) - result = np.concatenate([gen.send(_).data for _ in messages], axis=time_ax) + result = np.concatenate([xformer(_).data for _ in messages], axis=time_ax) assert np.allclose(result, expected) def test_butterworth_empty_msg(): - proc = butter( - axis="time", - order=2, - cuton=0.1, - cutoff=1.0, - coef_type="sos", + proc = ButterworthFilterTransformer( + ButterworthFilterSettings( + axis="time", + order=2, + cuton=0.1, + cutoff=1.0, + coef_type="sos", + ) ) msg_in = AxisArray( data=np.zeros((0, 2)), @@ -156,7 +157,7 @@ def test_butterworth_empty_msg(): }, key="test_butterworth_empty_msg", ) - res = proc.send(msg_in) + res = proc(msg_in) assert res.data.size == 0 @@ -196,12 +197,14 @@ def _calc_power(msg, targ_freqs): assert np.argmax(power0[:, 1]) == 1 # 40Hz should be the strongest frequency in ch1 # Initialize filter - lowpass at 30Hz should pass 10Hz but attenuate 40Hz - proc = butter( - axis="time", - order=4, - cutoff=30.0, # Lowpass at 30Hz - cuton=None, - coef_type="sos", + proc = ButterworthFilterTransformer( + ButterworthFilterSettings( + axis="time", + order=4, + cutoff=30.0, # Lowpass at 30Hz + cuton=None, + coef_type="sos", + ) ) # Process first message @@ -232,8 +235,6 @@ def _calc_power(msg, targ_freqs): assert power3[0][0] / power2[0][0] < 0.1 # 10Hz should be _further_ attenuated # Test update_settings with complete new settings object, includes coef_type change. - from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings - new_settings = ButterworthFilterSettings( axis="time", order=2, diff --git a/tests/unit/test_butterworthzerophase.py b/tests/unit/test_butterworthzerophase.py index 00f63ef6..5d0d4714 100644 --- a/tests/unit/test_butterworthzerophase.py +++ b/tests/unit/test_butterworthzerophase.py @@ -103,7 +103,7 @@ def test_butterworth_zero_phase_matches_scipy( padlen=padlen, ) - out = zp.send(msg).data + out = zp(msg).data assert np.allclose(out, expected, atol=1e-10, rtol=1e-7) @@ -118,7 +118,7 @@ def test_butterworth_zero_phase_empty_msg(): }, key="test_butterworth_zero_phase_empty", ) - res = zp.send(msg) + res = zp(msg) assert res.data.size == 0 diff --git a/tests/unit/test_coordinatespaces.py b/tests/unit/test_coordinatespaces.py index 43967781..86b4a75d 100644 --- a/tests/unit/test_coordinatespaces.py +++ b/tests/unit/test_coordinatespaces.py @@ -127,7 +127,7 @@ def test_cart2pol_basic(self): backup = [copy.deepcopy(msg_in)] transformer = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) - msg_out = transformer.send(msg_in) + msg_out = transformer(msg_in) # Check shape preserved assert msg_out.data.shape == data.shape @@ -148,7 +148,7 @@ def test_pol2cart_basic(self): msg_in = AxisArray(data, dims=["time", "ch"]) transformer = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.POL2CART, axis="ch")) - msg_out = transformer.send(msg_in) + msg_out = transformer(msg_in) # Check values expected_x = np.array([1.0, 0.0, 1.0]) @@ -163,11 +163,11 @@ def test_roundtrip_transformer(self): # cart -> pol c2p = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) - msg_polar = c2p.send(msg_in) + msg_polar = c2p(msg_in) # pol -> cart p2c = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.POL2CART, axis="ch")) - msg_back = p2c.send(msg_polar) + msg_back = p2c(msg_polar) assert np.allclose(msg_back.data, data) @@ -178,7 +178,7 @@ def test_default_axis(self): # No axis specified - should use last dim ("xy") transformer = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL)) - msg_out = transformer.send(msg_in) + msg_out = transformer(msg_in) assert msg_out.data.shape == data.shape assert np.isclose(msg_out.data[0, 0], 1.0) # r=1 for (1,0) @@ -191,7 +191,7 @@ def test_axis_not_last(self): msg_in = AxisArray(data, dims=["ch", "time"]) transformer = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) - msg_out = transformer.send(msg_in) + msg_out = transformer(msg_in) assert msg_out.data.shape == data.shape # First column: (1, 0) -> r=1 @@ -206,10 +206,10 @@ def test_3d_array(self): msg_in = AxisArray(data, dims=["batch", "time", "ch"]) c2p = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) - msg_polar = c2p.send(msg_in) + msg_polar = c2p(msg_in) p2c = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.POL2CART, axis="ch")) - msg_back = p2c.send(msg_polar) + msg_back = p2c(msg_polar) assert msg_polar.data.shape == data.shape assert np.allclose(msg_back.data, data) @@ -222,7 +222,7 @@ def test_wrong_axis_size_raises(self): transformer = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) with pytest.raises(ValueError, match="exactly 2 elements"): - transformer.send(msg_in) + transformer(msg_in) def test_axis_labels_updated_cart2pol(self): """Test that axis labels are updated for cart2pol.""" @@ -235,7 +235,7 @@ def test_axis_labels_updated_cart2pol(self): ) transformer = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) - msg_out = transformer.send(msg_in) + msg_out = transformer(msg_in) assert "ch" in msg_out.axes assert list(msg_out.axes["ch"].data) == ["r", "theta"] @@ -251,7 +251,7 @@ def test_axis_labels_updated_pol2cart(self): ) transformer = CoordinateSpacesTransformer(CoordinateSpacesSettings(mode=CoordinateMode.POL2CART, axis="ch")) - msg_out = transformer.send(msg_in) + msg_out = transformer(msg_in) assert "ch" in msg_out.axes assert list(msg_out.axes["ch"].data) == ["x", "y"] @@ -268,5 +268,5 @@ def test_multiple_sends(self): for _ in range(5): data = np.random.randn(10, 2) msg_in = AxisArray(data, dims=["time", "ch"]) - msg_out = transformer.send(msg_in) + msg_out = transformer(msg_in) assert msg_out.data.shape == data.shape diff --git a/tests/unit/test_filterbank.py b/tests/unit/test_filterbank.py index 47459394..09b0de7c 100644 --- a/tests/unit/test_filterbank.py +++ b/tests/unit/test_filterbank.py @@ -3,7 +3,7 @@ import scipy.signal as sps from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.sigproc.filterbank import FilterbankMode, filterbank +from ezmsg.sigproc.filterbank import FilterbankMode, FilterbankSettings, FilterbankTransformer from tests.helpers.util import gaussian, make_chirp @@ -76,10 +76,10 @@ def test_filterbank(mode: str, kernel_type: str): kernels.append(k) # Prep filterbank - gen = filterbank(kernels=kernels, mode=mode, axis="time") + proc = FilterbankTransformer(settings=FilterbankSettings(kernels=kernels, mode=mode, axis="time")) # Pass the messages - out_messages = [gen.send(msg_in) for msg_in in in_messages] + out_messages = [proc(msg_in) for msg_in in in_messages] result = AxisArray.concatenate(*out_messages, dim="time") assert result.key == "test_filterbank" diff --git a/tests/unit/test_linear.py b/tests/unit/test_linear.py new file mode 100644 index 00000000..2e391c4d --- /dev/null +++ b/tests/unit/test_linear.py @@ -0,0 +1,307 @@ +import copy + +import numpy as np +import pytest +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.sigproc.linear import ( + LinearTransformSettings, + LinearTransformTransformer, +) +from tests.helpers.util import assert_messages_equal + + +class TestLinearTransformScalar: + """Tests for scalar scale and offset.""" + + def test_scale_only(self): + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + backup = [copy.deepcopy(msg_in)] + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0)) + msg_out = xformer(msg_in) + + assert_messages_equal([msg_in], backup) + assert np.allclose(msg_out.data, data * 2.0) + + def test_offset_only(self): + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + backup = [copy.deepcopy(msg_in)] + + xformer = LinearTransformTransformer(LinearTransformSettings(offset=10.0)) + msg_out = xformer(msg_in) + + assert_messages_equal([msg_in], backup) + assert np.allclose(msg_out.data, data + 10.0) + + def test_scale_and_offset(self): + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + backup = [copy.deepcopy(msg_in)] + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0)) + msg_out = xformer(msg_in) + + assert_messages_equal([msg_in], backup) + assert np.allclose(msg_out.data, data * 2.0 + 1.0) + + def test_identity(self): + """Default settings should be identity transform.""" + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer(LinearTransformSettings()) + msg_out = xformer(msg_in) + + assert np.allclose(msg_out.data, data) + + def test_negative_scale(self): + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=-1.0, offset=0.0)) + msg_out = xformer(msg_in) + + assert np.allclose(msg_out.data, -data) + + +class TestLinearTransformPerChannel: + """Tests for per-channel scale and offset.""" + + def test_per_channel_scale(self): + data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + backup = [copy.deepcopy(msg_in)] + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=[1.0, 2.0, 3.0], axis="ch")) + msg_out = xformer(msg_in) + + assert_messages_equal([msg_in], backup) + expected = data * np.array([1.0, 2.0, 3.0]) + assert np.allclose(msg_out.data, expected) + + def test_per_channel_offset(self): + data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + backup = [copy.deepcopy(msg_in)] + + xformer = LinearTransformTransformer(LinearTransformSettings(offset=[10.0, 20.0, 30.0], axis="ch")) + msg_out = xformer(msg_in) + + assert_messages_equal([msg_in], backup) + expected = data + np.array([10.0, 20.0, 30.0]) + assert np.allclose(msg_out.data, expected) + + def test_per_channel_scale_and_offset(self): + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + backup = [copy.deepcopy(msg_in)] + + xformer = LinearTransformTransformer( + LinearTransformSettings( + scale=[2.0, 0.5], + offset=[1.0, -1.0], + axis="ch", + ) + ) + msg_out = xformer(msg_in) + + assert_messages_equal([msg_in], backup) + expected = np.array( + [ + [1.0 * 2.0 + 1.0, 2.0 * 0.5 - 1.0], + [3.0 * 2.0 + 1.0, 4.0 * 0.5 - 1.0], + ] + ) + assert np.allclose(msg_out.data, expected) + + def test_velocity_to_beta_use_case(self): + """Test the velocity magnitude/angle to beta scaling use case.""" + # Simulate CART2POL output: magnitude ~314, angle 0-2π + data = np.array( + [ + [100.0, 0.0], + [200.0, np.pi], + [314.0, 2 * np.pi], + ] + ) + msg_in = AxisArray(data, dims=["time", "ch"]) + + # Scale to beta range 0.5-2.0 + xformer = LinearTransformTransformer( + LinearTransformSettings( + scale=[1.5 / 314, 1.5 / (2 * np.pi)], + offset=[0.5, 0.5], + axis="ch", + ) + ) + msg_out = xformer(msg_in) + + # Check output is in expected beta range + assert msg_out.data.min() >= 0.5 - 1e-10 + assert msg_out.data.max() <= 2.0 + 1e-10 + + # Check specific values + expected = np.array( + [ + [100.0 * 1.5 / 314 + 0.5, 0.0 * 1.5 / (2 * np.pi) + 0.5], + [200.0 * 1.5 / 314 + 0.5, np.pi * 1.5 / (2 * np.pi) + 0.5], + [314.0 * 1.5 / 314 + 0.5, 2 * np.pi * 1.5 / (2 * np.pi) + 0.5], + ] + ) + assert np.allclose(msg_out.data, expected) + + +class TestLinearTransformDifferentAxes: + """Tests for operating on different axes.""" + + def test_time_axis(self): + """Test per-sample scaling along time axis.""" + data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=[1.0, 2.0, 3.0], axis="time")) + msg_out = xformer(msg_in) + + expected = np.array( + [ + [1.0 * 1.0, 2.0 * 1.0], + [3.0 * 2.0, 4.0 * 2.0], + [5.0 * 3.0, 6.0 * 3.0], + ] + ) + assert np.allclose(msg_out.data, expected) + + def test_3d_data_middle_axis(self): + """Test with 3D data operating on middle axis.""" + data = np.ones((2, 3, 4)) # time, ch, freq + msg_in = AxisArray(data, dims=["time", "ch", "freq"]) + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=[1.0, 2.0, 3.0], axis="ch")) + msg_out = xformer(msg_in) + + assert msg_out.data.shape == (2, 3, 4) + assert np.allclose(msg_out.data[:, 0, :], 1.0) + assert np.allclose(msg_out.data[:, 1, :], 2.0) + assert np.allclose(msg_out.data[:, 2, :], 3.0) + + +class TestLinearTransformState: + """Tests for state persistence and hash-based reset.""" + + def test_state_persistence(self): + """State should persist across multiple calls with same shape.""" + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=[2.0, 0.5], offset=[1.0, -1.0], axis="ch")) + + # First call + msg_out1 = xformer(msg_in) + + # Second call - should use cached state + msg_out2 = xformer(msg_in) + + assert np.allclose(msg_out1.data, msg_out2.data) + + def test_state_reset_on_shape_change(self): + """State should reset when input shape changes.""" + xformer = LinearTransformTransformer(LinearTransformSettings(scale=[1.0, 2.0], axis="ch")) + + # First call with 2 channels + data1 = np.array([[1.0, 2.0]]) + msg1 = AxisArray(data1, dims=["time", "ch"]) + out1 = xformer(msg1) + assert np.allclose(out1.data, np.array([[1.0, 4.0]])) + + # Second call with 3 channels - should reset state + # Note: This will fail because scale array doesn't match new shape + # The transformer should handle this gracefully or raise an error + data2 = np.array([[1.0, 2.0, 3.0]]) + msg2 = AxisArray(data2, dims=["time", "ch"]) + with pytest.raises((ValueError, IndexError)): + xformer(msg2) + + def test_state_reset_on_ndim_change(self): + """State should reset when input ndim changes.""" + xformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0)) + + # First call with 2D data + data1 = np.array([[1.0, 2.0]]) + msg1 = AxisArray(data1, dims=["time", "ch"]) + out1 = xformer(msg1) + assert np.allclose(out1.data, data1 * 2.0 + 1.0) + + # Second call with 3D data + data2 = np.ones((2, 3, 4)) + msg2 = AxisArray(data2, dims=["time", "ch", "freq"]) + out2 = xformer(msg2) + assert np.allclose(out2.data, data2 * 2.0 + 1.0) + + +class TestLinearTransformEdgeCases: + """Tests for edge cases.""" + + def test_single_element(self): + data = np.array([[1.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0)) + msg_out = xformer(msg_in) + + assert np.allclose(msg_out.data, np.array([[3.0]])) + + def test_empty_data(self): + data = np.array([]).reshape(0, 2) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0)) + msg_out = xformer(msg_in) + + assert msg_out.data.shape == (0, 2) + + def test_large_values(self): + data = np.array([[1e10, 1e-10]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0)) + msg_out = xformer(msg_in) + + expected = data * 2.0 + 1.0 + assert np.allclose(msg_out.data, expected) + + def test_preserves_dtype(self): + """Output dtype should match computation dtype (float64).""" + data = np.array([[1, 2], [3, 4]], dtype=np.int32) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=0.5)) + msg_out = xformer(msg_in) + + # Result should be float due to float scale/offset + assert msg_out.data.dtype in [np.float64, np.float32] + assert np.allclose(msg_out.data, np.array([[2.5, 4.5], [6.5, 8.5]])) + + def test_numpy_array_settings(self): + """Test with numpy arrays in settings instead of lists.""" + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray(data, dims=["time", "ch"]) + + xformer = LinearTransformTransformer( + LinearTransformSettings( + scale=np.array([2.0, 0.5]), + offset=np.array([1.0, -1.0]), + axis="ch", + ) + ) + msg_out = xformer(msg_in) + + expected = np.array( + [ + [1.0 * 2.0 + 1.0, 2.0 * 0.5 - 1.0], + [3.0 * 2.0 + 1.0, 4.0 * 0.5 - 1.0], + ] + ) + assert np.allclose(msg_out.data, expected) diff --git a/tests/unit/test_math.py b/tests/unit/test_math.py index 8940fbdb..2133c012 100644 --- a/tests/unit/test_math.py +++ b/tests/unit/test_math.py @@ -2,12 +2,12 @@ import pytest from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.sigproc.math.abs import abs -from ezmsg.sigproc.math.clip import clip -from ezmsg.sigproc.math.difference import const_difference -from ezmsg.sigproc.math.invert import invert -from ezmsg.sigproc.math.log import log -from ezmsg.sigproc.math.scale import scale +from ezmsg.sigproc.math.abs import AbsTransformer +from ezmsg.sigproc.math.clip import ClipSettings, ClipTransformer +from ezmsg.sigproc.math.difference import ConstDifferenceSettings, ConstDifferenceTransformer +from ezmsg.sigproc.math.invert import InvertTransformer +from ezmsg.sigproc.math.log import LogSettings, LogTransformer +from ezmsg.sigproc.math.scale import ScaleSettings, ScaleTransformer def test_abs(): @@ -15,8 +15,8 @@ def test_abs(): n_chans = 255 in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) msg_in = AxisArray(in_dat, dims=["time", "ch"]) - proc = abs() - msg_out = proc.send(msg_in) + xformer = AbsTransformer() + msg_out = xformer(msg_in) assert np.array_equal(msg_out.data, np.abs(in_dat)) @@ -28,8 +28,8 @@ def test_clip(a_min: float, a_max: float): in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) msg_in = AxisArray(in_dat, dims=["time", "ch"]) - proc = clip(a_min, a_max) - msg_out = proc.send(msg_in) + xformer = ClipTransformer(ClipSettings(a_min=a_min, a_max=a_max)) + msg_out = xformer(msg_in) assert all(msg_out.data[np.where(in_dat < a_min)] == a_min) assert all(msg_out.data[np.where(in_dat > a_max)] == a_max) @@ -43,8 +43,8 @@ def test_const_difference(value: float, subtrahend: bool): in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) msg_in = AxisArray(in_dat, dims=["time", "ch"]) - proc = const_difference(value, subtrahend) - msg_out = proc.send(msg_in) + xformer = ConstDifferenceTransformer(ConstDifferenceSettings(value=value, subtrahend=subtrahend)) + msg_out = xformer(msg_in) assert np.array_equal(msg_out.data, (in_dat - value) if subtrahend else (value - in_dat)) @@ -53,22 +53,22 @@ def test_invert(): n_chans = 255 in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) msg_in = AxisArray(in_dat, dims=["time", "ch"]) - proc = invert() - msg_out = proc.send(msg_in) + xformer = InvertTransformer() + msg_out = xformer(msg_in) assert np.array_equal(msg_out.data, 1 / in_dat) @pytest.mark.parametrize("base", [np.e, 2, 10]) @pytest.mark.parametrize("dtype", [int, float]) -@pytest.mark.parametrize("clip", [False, True]) -def test_log(base: float, dtype, clip: bool): +@pytest.mark.parametrize("clip_zero", [False, True]) +def test_log(base: float, dtype, clip_zero: bool): n_times = 130 n_chans = 255 in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans).astype(dtype) msg_in = AxisArray(in_dat, dims=["time", "ch"]) - proc = log(base, clip_zero=clip) - msg_out = proc.send(msg_in) - if clip and dtype is float: + xformer = LogTransformer(LogSettings(base=base, clip_zero=clip_zero)) + msg_out = xformer(msg_in) + if clip_zero and dtype is float: in_dat = np.clip(in_dat, a_min=np.finfo(msg_in.data.dtype).tiny, a_max=None) assert np.array_equal(msg_out.data, np.log(in_dat) / np.log(base)) @@ -80,8 +80,8 @@ def test_scale(scale_factor: float): in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) msg_in = AxisArray(in_dat, dims=["time", "ch"]) - proc = scale(scale_factor) - msg_out = proc.send(msg_in) + xformer = ScaleTransformer(ScaleSettings(scale=scale_factor)) + msg_out = xformer(msg_in) assert msg_out.data.shape == (n_times, n_chans) assert np.array_equal(msg_out.data, in_dat * scale_factor) diff --git a/tests/unit/test_scaler.py b/tests/unit/test_scaler.py index bd916787..95efdd3d 100644 --- a/tests/unit/test_scaler.py +++ b/tests/unit/test_scaler.py @@ -11,7 +11,6 @@ AdaptiveStandardScalerSettings, AdaptiveStandardScalerTransformer, scaler, - scaler_np, ) from tests.helpers.util import assert_messages_equal @@ -53,26 +52,13 @@ def test_scaler(fixture_arrays): backup = copy.deepcopy(test_input) tau = 0.010913566679372915 - """ - Test legacy interface. Should be deprecated. - """ - gen = scaler_np(time_constant=tau, axis="time") - outputs = [] - for chunk in test_input: - outputs.append(gen.send(chunk)) - output = AxisArray.concatenate(*outputs, dim="time") - assert np.allclose(output.data, expected_result, atol=1e-3) - assert_messages_equal(test_input, backup) - - """ - Test new interface - """ xformer = AdaptiveStandardScalerTransformer(time_constant=tau, axis="time") outputs = [] for chunk in test_input: outputs.append(xformer(chunk)) output = AxisArray.concatenate(*outputs, dim="time") assert np.allclose(output.data, expected_result, atol=1e-3) + assert_messages_equal(test_input, backup) def _make_scaler_test_msg(data: np.ndarray, fs: float = 1000.0) -> AxisArray: diff --git a/tests/unit/test_slicer.py b/tests/unit/test_slicer.py index cb4d31e6..5e2c2915 100644 --- a/tests/unit/test_slicer.py +++ b/tests/unit/test_slicer.py @@ -4,7 +4,7 @@ import pytest from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.sigproc.slicer import parse_slice, slicer +from ezmsg.sigproc.slicer import SlicerSettings, SlicerTransformer, parse_slice from tests.helpers.util import assert_messages_equal @@ -22,7 +22,7 @@ def test_parse_slice(): assert parse_slice("4:64, 68:100") == (slice(4, 64), slice(68, 100)) -def test_slicer_generator(): +def test_slicer_transformer(): n_times = 13 n_chans = 255 in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) @@ -33,40 +33,40 @@ def test_slicer_generator(): "time": AxisArray.TimeAxis(fs=100.0, offset=0.1), "ch": AxisArray.CoordinateAxis(data=np.array([f"Ch{_}" for _ in range(n_chans)]), dims=["ch"]), }, - key="test_slicer_generator", + key="test_slicer_transformer", ) backup = [copy.deepcopy(msg_in)] - gen = slicer(selection=":2", axis="ch") - msg_out = gen.send(msg_in) + xformer = SlicerTransformer(SlicerSettings(selection=":2", axis="ch")) + msg_out = xformer(msg_in) assert_messages_equal([msg_in], backup) assert msg_out.data.shape == (n_times, 2) assert np.array_equal(msg_out.data, in_dat[:, :2]) assert np.may_share_memory(msg_out.data, in_dat) - gen = slicer(selection="::3", axis="ch") - msg_out = gen.send(msg_in) + xformer = SlicerTransformer(SlicerSettings(selection="::3", axis="ch")) + msg_out = xformer(msg_in) assert_messages_equal([msg_in], backup) assert msg_out.data.shape == (n_times, n_chans // 3) assert np.array_equal(msg_out.data, in_dat[:, ::3]) assert np.may_share_memory(msg_out.data, in_dat) - gen = slicer(selection="4:64", axis="ch") - msg_out = gen.send(msg_in) + xformer = SlicerTransformer(SlicerSettings(selection="4:64", axis="ch")) + msg_out = xformer(msg_in) assert_messages_equal([msg_in], backup) assert msg_out.data.shape == (n_times, 60) assert np.array_equal(msg_out.data, in_dat[:, 4:64]) assert np.may_share_memory(msg_out.data, in_dat) # Discontiguous slices leads to a copy - gen = slicer(selection="1, 3:5", axis="ch") - msg_out = gen.send(msg_in) + xformer = SlicerTransformer(SlicerSettings(selection="1, 3:5", axis="ch")) + msg_out = xformer(msg_in) assert_messages_equal([msg_in], backup) assert np.array_equal(msg_out.data, msg_in.data[:, [1, 3, 4]]) assert not np.may_share_memory(msg_out.data, in_dat) -def test_slicer_gen_drop_dim(): +def test_slicer_drop_dim(): n_times = 50 n_chans = 10 in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans) @@ -77,12 +77,12 @@ def test_slicer_gen_drop_dim(): "time": AxisArray.TimeAxis(fs=100.0, offset=0.1), "ch": AxisArray.CoordinateAxis(data=np.array([f"Ch{_}" for _ in range(n_chans)]), dims=["ch"]), }, - key="test_slicer_gen_drop_dim", + key="test_slicer_drop_dim", ) backup = [copy.deepcopy(msg_in)] - gen = slicer(selection="5", axis="ch") - msg_out = gen.send(msg_in) + xformer = SlicerTransformer(SlicerSettings(selection="5", axis="ch")) + msg_out = xformer(msg_in) assert_messages_equal([msg_in], backup) assert msg_out.data.shape == (n_times,) assert np.array_equal(msg_out.data, msg_in.data[:, 5]) @@ -108,9 +108,8 @@ def test_slicer_label(selection: str): ) backup = [copy.deepcopy(msg_in)] - gen = slicer(selection=selection, axis="ch") - # gen = slicer(selection=":3", axis="ch") - msg_out = gen.send(msg_in) + xformer = SlicerTransformer(SlicerSettings(selection=selection, axis="ch")) + msg_out = xformer(msg_in) assert_messages_equal([msg_in], backup) assert msg_out.data.shape == (n_times, 3) assert np.array_equal(msg_out.data, msg_in.data[:, :3]) diff --git a/tests/unit/test_spectrogram.py b/tests/unit/test_spectrogram.py index 14ba4192..26cd8e60 100644 --- a/tests/unit/test_spectrogram.py +++ b/tests/unit/test_spectrogram.py @@ -68,7 +68,7 @@ def test_spectrogram(): output=SpectralOutput.POSITIVE, ) - results = [proc.send(msg) for msg in messages] + results = [proc(msg) for msg in messages] results = [_ for _ in results if _.data.size] # Drop empty messages assert_messages_equal(messages, backup) diff --git a/tests/unit/test_spectrum.py b/tests/unit/test_spectrum.py index 97302905..73e12031 100644 --- a/tests/unit/test_spectrum.py +++ b/tests/unit/test_spectrum.py @@ -9,8 +9,9 @@ from ezmsg.sigproc.spectrum import ( SpectralOutput, SpectralTransform, + SpectrumSettings, + SpectrumTransformer, WindowFunction, - spectrum, ) from tests.helpers.util import ( assert_messages_equal, @@ -71,8 +72,8 @@ def test_spectrum_gen_multiwin(window: WindowFunction, transform: SpectralTransf input_multiwin = AxisArray.concatenate(*messages, dim="win") input_multiwin.axes["win"] = AxisArray.TimeAxis(offset=0, fs=1 / win_step_dur) - gen = spectrum(axis="time", window=window, transform=transform, output=output) - result = gen.send(input_multiwin) + proc = SpectrumTransformer(SpectrumSettings(axis="time", window=window, transform=transform, output=output)) + result = proc(input_multiwin) # _debug_plot_welch(input_multiwin, result, welch_db=True) assert isinstance(result, AxisArray) assert "time" not in result.dims @@ -115,8 +116,8 @@ def test_spectrum_gen(window: WindowFunction, transform: SpectralTransform, outp ) backup = [copy.deepcopy(_) for _ in messages] - gen = spectrum(axis="time", window=window, transform=transform, output=output) - results = [gen.send(msg) for msg in messages] + proc = SpectrumTransformer(SpectrumSettings(axis="time", window=window, transform=transform, output=output)) + results = [proc(msg) for msg in messages] assert_messages_equal(messages, backup) @@ -142,16 +143,18 @@ def test_spectrum_vs_sps_fft(complex: bool): ) nfft = 1 << (messages[0].data.shape[0] - 1).bit_length() # nextpow2 - gen = spectrum( - axis="time", - window=WindowFunction.NONE, - transform=SpectralTransform.RAW_COMPLEX if complex else SpectralTransform.REAL, - output=SpectralOutput.FULL if complex else SpectralOutput.POSITIVE, - norm="backward", - do_fftshift=False, - nfft=nfft, + proc = SpectrumTransformer( + SpectrumSettings( + axis="time", + window=WindowFunction.NONE, + transform=SpectralTransform.RAW_COMPLEX if complex else SpectralTransform.REAL, + output=SpectralOutput.FULL if complex else SpectralOutput.POSITIVE, + norm="backward", + do_fftshift=False, + nfft=nfft, + ) ) - results = [gen.send(msg) for msg in messages] + results = [proc(msg) for msg in messages] test_spec = results[0].data if complex: sp_res = sp_fft.fft(messages[0].data, n=nfft, axis=0) diff --git a/tests/unit/test_wavelets.py b/tests/unit/test_wavelets.py index e0e469fa..7f040e22 100644 --- a/tests/unit/test_wavelets.py +++ b/tests/unit/test_wavelets.py @@ -2,7 +2,7 @@ import pywt from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.sigproc.wavelets import MinPhaseMode, cwt +from ezmsg.sigproc.wavelets import CWTSettings, CWTTransformer, MinPhaseMode from tests.helpers.util import gaussian, make_chirp @@ -91,16 +91,18 @@ def test_cwt(): expected = np.swapaxes(expected, 0, 1) # Prep filterbank - gen = cwt( - frequencies=frequencies, - wavelet=wavelet, - min_phase=MinPhaseMode.HOMOMORPHIC, - axis="time", + proc = CWTTransformer( + CWTSettings( + frequencies=frequencies, + wavelet=wavelet, + min_phase=MinPhaseMode.HOMOMORPHIC, + axis="time", + ) ) # Pass the messages - out_messages = [gen.send(in_messages[0])] - out_messages += [gen.send(msg_in) for msg_in in in_messages[1:]] + out_messages = [proc(in_messages[0])] + out_messages += [proc(msg_in) for msg_in in in_messages[1:]] result = AxisArray.concatenate(*out_messages, dim="time") assert result.key == "test_cwt"