Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/guides/explanations/sigproc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ Often, all that is required is the following (e.g., for a custom transformer):

import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit


class CustomTransformerSettings(ez.Settings):
Expand Down
2 changes: 1 addition & 1 deletion docs/source/guides/tutorials/signalprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ Add the following import statements to the top of the `downsample.py` file:
)
import ezmsg.core as ez

from ezmsg.sigproc.base import (
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
Expand Down
24 changes: 23 additions & 1 deletion src/ezmsg/sigproc/affinetransform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
"""Affine transformations via matrix multiplication: y = Ax or y = Ax + B.

For full matrix transformations where channels are mixed (off-diagonal weights),
use :obj:`AffineTransformTransformer` or the `AffineTransform` unit.

For simple per-channel scaling and offset (diagonal weights only), use
:obj:`LinearTransformTransformer` from :mod:`ezmsg.sigproc.linear` instead,
which is more efficient as it avoids matrix multiplication.
"""

import os
from pathlib import Path

Expand All @@ -17,7 +27,6 @@
class AffineTransformSettings(ez.Settings):
"""
Settings for :obj:`AffineTransform`.
See :obj:`affine_transform` for argument details.
"""

weights: np.ndarray | str | Path
Expand All @@ -39,6 +48,19 @@ class AffineTransformState:
class AffineTransformTransformer(
BaseStatefulTransformer[AffineTransformSettings, AxisArray, AxisArray, AffineTransformState]
):
"""Apply affine transformation via matrix multiplication: y = Ax or y = Ax + B.

Use this transformer when you need full matrix transformations that mix
channels (off-diagonal weights), such as spatial filters or projections.

For simple per-channel scaling and offset where each output channel depends
only on its corresponding input channel (diagonal weight matrix), use
:obj:`LinearTransformTransformer` instead, which is more efficient.

The weights matrix can include an offset row (stacked as [A|B]) where the
input is automatically augmented with a column of ones to compute y = Ax + B.
"""

def __call__(self, message: AxisArray) -> AxisArray:
# Override __call__ so we can shortcut if weights are None.
if self.settings.weights is None or (
Expand Down
2 changes: 1 addition & 1 deletion src/ezmsg/sigproc/butterworthzerophase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import ezmsg.core as ez
import numpy as np
import scipy.signal
from ezmsg.baseproc import SettingsType
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.base import SettingsType
from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
from ezmsg.sigproc.filter import (
BACoeffs,
Expand Down
7 changes: 3 additions & 4 deletions src/ezmsg/sigproc/denormalize.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.base import (
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace


class DenormalizeSettings(ez.Settings):
Expand Down
2 changes: 1 addition & 1 deletion src/ezmsg/sigproc/detrend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import scipy.signal as sps
from ezmsg.baseproc import BaseTransformerUnit
from ezmsg.util.messages.axisarray import AxisArray, replace

from ezmsg.sigproc.base import BaseTransformerUnit
from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer


Expand Down
7 changes: 3 additions & 4 deletions src/ezmsg/sigproc/diff.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.base import (
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.messages.util import replace


class DiffSettings(ez.Settings):
Expand Down
3 changes: 1 addition & 2 deletions src/ezmsg/sigproc/extract_axis.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import ezmsg.core as ez
import numpy as np
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
from ezmsg.util.messages.axisarray import AxisArray, replace

from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit


class ExtractAxisSettings(ez.Settings):
axis: str = "freq"
Expand Down
7 changes: 3 additions & 4 deletions src/ezmsg/sigproc/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
import numpy as np
import numpy.typing as npt
import scipy.signal
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.base import (
from ezmsg.baseproc import (
BaseConsumerUnit,
BaseStatefulTransformer,
BaseTransformerUnit,
SettingsType,
TransformerType,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/ezmsg/sigproc/fir_hilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import ezmsg.core as ez
import numpy as np
import scipy.signal as sps
from ezmsg.baseproc import BaseStatefulTransformer, processor_state
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state
from ezmsg.sigproc.filter import (
BACoeffs,
BaseFilterByDesignTransformerUnit,
Expand Down
118 changes: 118 additions & 0 deletions src/ezmsg/sigproc/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Apply a linear transformation: output = scale * input + offset.

Supports per-element scale and offset along a specified axis.
Uses Array API for compatibility with numpy, cupy, torch, etc.

For full matrix transformations, use :obj:`AffineTransformTransformer` instead.
"""

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from array_api_compat import get_namespace
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .base import BaseStatefulTransformer, BaseTransformerUnit, processor_state


class LinearTransformSettings(ez.Settings):
scale: float | list[float] | npt.ArrayLike = 1.0
"""Scale factor(s). Can be a scalar (applied to all elements) or an array
matching the size of the specified axis for per-element scaling."""

offset: float | list[float] | npt.ArrayLike = 0.0
"""Offset value(s). Can be a scalar (applied to all elements) or an array
matching the size of the specified axis for per-element offset."""

axis: str | None = None
"""Axis along which to apply per-element scale/offset. If None, scalar
scale/offset are broadcast to all elements."""


@processor_state
class LinearTransformState:
scale: npt.NDArray = None
"""Prepared scale array for broadcasting."""

offset: npt.NDArray = None
"""Prepared offset array for broadcasting."""


class LinearTransformTransformer(
BaseStatefulTransformer[LinearTransformSettings, AxisArray, AxisArray, LinearTransformState]
):
"""Apply linear transformation: output = scale * input + offset.

This transformer is optimized for element-wise linear operations with
optional per-channel (or per-axis) coefficients. For full matrix
transformations, use :obj:`AffineTransformTransformer` instead.

Examples:
# Uniform scaling and offset
>>> transformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0))

# Per-channel scaling (e.g., for 3-channel data along "ch" axis)
>>> transformer = LinearTransformTransformer(LinearTransformSettings(
... scale=[0.5, 1.0, 2.0],
... offset=[0.0, 0.1, 0.2],
... axis="ch"
... ))
"""

def _hash_message(self, message: AxisArray) -> int:
"""Hash based on shape and axis to detect when broadcast shapes need recalculation."""
axis = self.settings.axis
if axis is not None:
axis_idx = message.get_axis_idx(axis)
return hash((message.data.ndim, axis_idx, message.data.shape[axis_idx]))
return hash(message.data.ndim)

def _reset_state(self, message: AxisArray) -> None:
"""Prepare scale/offset arrays with proper broadcast shapes."""
xp = get_namespace(message.data)
ndim = message.data.ndim

scale = self.settings.scale
offset = self.settings.offset

# Convert settings to arrays
if isinstance(scale, (list, np.ndarray)):
scale = xp.asarray(scale, dtype=xp.float64)
else:
# Scalar: create a 0-d array
scale = xp.asarray(float(scale), dtype=xp.float64)

if isinstance(offset, (list, np.ndarray)):
offset = xp.asarray(offset, dtype=xp.float64)
else:
# Scalar: create a 0-d array
offset = xp.asarray(float(offset), dtype=xp.float64)

# If axis is specified and we have 1-d arrays, reshape for proper broadcasting
if self.settings.axis is not None and ndim > 0:
axis_idx = message.get_axis_idx(self.settings.axis)

if scale.ndim == 1:
# Create shape for broadcasting: all 1s except at axis_idx
broadcast_shape = [1] * ndim
broadcast_shape[axis_idx] = scale.shape[0]
scale = xp.reshape(scale, broadcast_shape)

if offset.ndim == 1:
broadcast_shape = [1] * ndim
broadcast_shape[axis_idx] = offset.shape[0]
offset = xp.reshape(offset, broadcast_shape)

self._state.scale = scale
self._state.offset = offset

def _process(self, message: AxisArray) -> AxisArray:
result = message.data * self._state.scale + self._state.offset
return replace(message, data=result)


class LinearTransform(BaseTransformerUnit[LinearTransformSettings, AxisArray, AxisArray, LinearTransformTransformer]):
"""Unit wrapper for LinearTransformTransformer."""

SETTINGS = LinearTransformSettings
8 changes: 4 additions & 4 deletions src/ezmsg/sigproc/rollingscaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.base import (
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.sampler import SampleMessage


Expand Down
11 changes: 8 additions & 3 deletions tests/unit/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from ezmsg.util.messages.axisarray import AxisArray
from frozendict import frozendict

from ezmsg.sigproc.activation import ACTIVATIONS, ActivationFunction, activation
from ezmsg.sigproc.activation import (
ACTIVATIONS,
ActivationFunction,
ActivationSettings,
ActivationTransformer,
)


@pytest.mark.parametrize("function", [_ for _ in ActivationFunction] + ActivationFunction.options())
Expand All @@ -25,8 +30,8 @@ def msg_generator():
)
yield msg

proc = activation(function=function)
out_msgs = [proc.send(_) for _ in msg_generator()]
proc = ActivationTransformer(ActivationSettings(function=function))
out_msgs = [proc(_) for _ in msg_generator()]
out_dat = AxisArray.concatenate(*out_msgs, dim="time").data

if function in ACTIVATIONS:
Expand Down
Loading