From 88c4a5cbdef47837112702f80d1ecc462d298857 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Tue, 20 Jan 2026 22:34:42 +0000 Subject: [PATCH 1/7] feat: added mri module --- pylops/medical/mri.py | 256 +++++++++++++++++++++++++ pytests/test_mri.py | 432 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 688 insertions(+) create mode 100644 pylops/medical/mri.py create mode 100644 pytests/test_mri.py diff --git a/pylops/medical/mri.py b/pylops/medical/mri.py new file mode 100644 index 00000000..6b3ec0c2 --- /dev/null +++ b/pylops/medical/mri.py @@ -0,0 +1,256 @@ +__all__ = [ + "MRI2D", +] + +from typing import Literal, Optional, Union + +import numpy as np + +from pylops import LinearOperator +from pylops.basicoperators import Diagonal, Restriction +from pylops.signalprocessing import FFT2D, Bilinear +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray + + +class MRI2D(LinearOperator): + r"""2D Magnetic Resonance Imaging + + Apply 2D Magnetic Resonance Imaging operator to obtain a k-space data (i.e., + undersampled Fourier representation of the model). + + Parameters + ---------- + dims : :obj:`list` or :obj:`int` + Number of samples for each dimension. Must be 2-dimensional and of size + :math:`n_y \times n_x` + mask : :obj:`str` or :obj:`numpy.ndarray` + Mask to be applied in the Fourier domain: + + - :obj:`numpy.ndarray`: a 2-dimensional array of size :math:`n_y \times n_x` + with 1 in the selected locations; + - ``vertical-reg``: mask with vertical lines (regularly sampled around the + second dimension); + - ``vertical-uni``: mask with vertical lines (irregularly sampled around the + second dimension, with lines drawn from a uniform distribution); + - ``radial-reg``: mask with radial lines (regularly sampled around the + :math:`-\pi/\pi` angles); + - ``radial-uni``: mask with radial lines (irregularly sampled around the + :math:`-\pi/\pi` angles, with angles drawn from a uniform distribution); + nlines: :obj:`str` + Number of lines in the k-space. + perc_center: :obj:`float` + Percentage of total lines to retain in the center. + engine : :obj:`str`, optional + Engine used for fft computation (``numpy`` or ``fftw``) + dtype : :obj:`str`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + **kwargs_fft + Arbitrary keyword arguments to be passed to the selected fft method + + Attributes + ---------- + mask : :obj:`numpy.ndarray` + Mask applied in the Fourier domain. + ROp : :obj:`pylops.Restriction` or :obj:`pylops.Diagonal` or :obj:`pylops.signalprocessing.Bilinear` + Operator that applies the mask in the Fourier domain. + shape : :obj:`tuple` + Operator shape + explicit : :obj:`bool` + Operator contains a matrix that can be solved + explicitly (``True``) or not (``False``) + + Notes + ----- + The MRI2D operator applies 2-dimensional Fourier transform to the model, + followed by a subsampling with a given ``mask``: + + .. math:: + \mathbf{d} = \mathbf{R} \mathbf{F}_{k} \mathbf{m} + + where :math:`\mathbf{F}_{k}` is the 2-dimensional Fourier transform and + :math:`\mathbf{R}` is the mask. + + """ + + def __init__( + self, + dims: InputDimsLike, + mask: Union[ + Literal["vertical-reg", "vertical-uni", "radial-reg", "radial-uni"], NDArray + ], + nlines: Optional[int] = None, + perc_center: float = 0.1, + engine: str = "numpy", + dtype: DTypeLike = "complex128", + name: str = "M", + **kwargs_fft, + ) -> None: + self.dims = dims + self._mask_type = mask if isinstance(mask, str) else "mask" + self.engine = engine + + # Validate inputs + if isinstance(mask, str) and mask not in ( + "vertical-reg", + "vertical-uni", + "radial-reg", + "radial-uni", + ): + raise ValueError( + "mask must be a numpy array, 'vertical-reg', 'vertical-uni', 'radial-reg', or 'radial-uni'" + ) + if isinstance(mask, str) and mask == "vertical-reg" and perc_center > 0.0: + raise ValueError("perc_center must be 0.0 when using 'vertical-reg' mask") + + if self.engine not in ["numpy", "scipy", "mkl_fft"]: + raise ValueError("engine must be 'numpy', 'scipy', or 'mkl_fft'") + + if self._mask_type == "mask": + self.mask = mask + elif "vertical" in self._mask_type: + self.mask = self._vertical_mask( + dims, + nlines, + perc_center, + uniform=True if "reg" in self._mask_type else False, + ) + elif "radial" in self._mask_type: + self.mask = self._radial_mask( + dims, nlines, uniform=True if "reg" in self._mask_type else False + ) + self.ROp, Op = self._calc_op( + dims=dims, + mask_type=mask if isinstance(mask, str) else "mask", + mask=self.mask, + fft_engine=engine, + dtype=dtype, + **kwargs_fft, + ) + super().__init__(Op=Op, name=name) + + @staticmethod + def _vertical_mask( + dims: InputDimsLike, nlines: int, perc_center: float, uniform: bool = True + ) -> NDArray: + """Create vertical mask""" + nlines_center = int(perc_center * dims[1]) + if (nlines + nlines_center) > dims[1]: + raise ValueError( + "nlines and perc_center produce a number of lines " + "greater than the total number of lines of the k-space" + f"({nlines + nlines_center}>{dims[1]})" + ) + + if nlines_center == 0: + # No lines from the center + if uniform: + step = dims[1] // nlines + mask = np.arange(0, dims[1], step)[:nlines] + else: + rng = np.random.default_rng() + mask = rng.choice(np.arange(dims[1]), nlines, replace=False) + else: + # Lines taken from the center + istart_center = dims[1] // 2 - nlines_center // 2 + iend_center = dims[1] // 2 + nlines_center // 2 + (nlines_center % 2) + ilines_center = np.arange(istart_center, iend_center) + + # Other lines + if uniform: + nlines_left = nlines // 2 + nlines % 2 + step_left = istart_center // nlines_left + ilines_left = np.arange(0, istart_center, step_left)[:nlines_left] + nlines_right = nlines // 2 + step_right = (dims[1] - iend_center) // nlines_left + ilines_right = np.arange(iend_center, dims[1], step_right)[ + :nlines_right + ] + mask = np.sort(np.hstack((ilines_left, ilines_center, ilines_right))) + else: + rng = np.random.default_rng() + ilines_other = np.hstack( + (np.arange(0, istart_center), np.arange(iend_center, dims[1])) + ) + ilines_other = rng.choice(ilines_other, nlines, replace=False) + mask = np.sort(np.hstack((ilines_center, ilines_other))) + return mask + + @staticmethod + def _radial_mask(dims: InputDimsLike, nlines: int, uniform: bool = True) -> NDArray: + """Create radial mask""" + npoints_per_line = dims[1] - 1 + + # Define angles + if uniform: + thetas = np.linspace(0, np.pi, nlines, endpoint=False) + else: + rng = np.random.default_rng() + thetas = rng.uniform(-np.pi, np.pi, nlines) + + # Create lines + lines = [] + for theta in thetas: + if theta == np.pi / 2: + # Create vertical line + xline = np.zeros(npoints_per_line) + yline = np.linspace( + -dims[1] // 2 + 1, dims[1] // 2 - 1, npoints_per_line, endpoint=True + ) + elif np.tan(theta) >= 0: + # Create lines for positive angles + xmax = min(dims[1] // 2, (dims[0] // 2) / np.tan(theta)) + xline = np.linspace( + -xmax, + min(xmax, dims[0] // 2 - 1 - (dims[0] + 1) % 2), + npoints_per_line, + endpoint=True, + ) + yline = np.tan(theta) * xline + else: + # Create lines for negative angles + xmin = max(-dims[1] // 2 + 1, (dims[0] // 2) / np.tan(theta)) + xline = np.linspace( + xmin, min(-xmin, dims[0] // 2 - 1), npoints_per_line, endpoint=True + ) + yline = np.tan(theta) * xline + xline, yline = xline + dims[0] // 2, yline + dims[1] // 2 + lines.append(np.vstack((xline, yline))) + mask = np.concatenate(lines, axis=1) + mask = mask[:, mask[0] < dims[0] - 1] + mask = mask[:, mask[1] < dims[1] - 1] + mask = np.unique(mask, axis=1) + return mask + + def _matvec(self, x: NDArray) -> NDArray: + return super()._matvec(x) + + def _rmatvec(self, x: NDArray) -> NDArray: + return super()._rmatvec(x) + + @staticmethod + def _calc_op( + dims: InputDimsLike, + mask_type: "str", + mask: NDArray, + fft_engine: float, + dtype: DTypeLike, + **kwargs_fft, + ): + """Calculate MRI operator""" + fop = FFT2D( + dims, + nffts=dims, + fftshift_after=True, + engine=fft_engine, + dtype=dtype, + **kwargs_fft, + ) + if mask_type == "mask": + rop = Diagonal(mask, dtype=dtype) + elif "vertical" in mask_type: + rop = Restriction(dims, mask, axis=-1, forceflat=True, dtype=dtype) + elif "radial" in mask_type: + rop = Bilinear(mask, dims, dtype=dtype) + return rop, rop @ fop diff --git a/pytests/test_mri.py b/pytests/test_mri.py new file mode 100644 index 00000000..4e605ccf --- /dev/null +++ b/pytests/test_mri.py @@ -0,0 +1,432 @@ +import os + +if int(os.environ.get("TEST_CUPY_PYLOPS", 0)): + import cupy as np + from cupy.testing import assert_array_almost_equal, assert_array_equal + + backend = "cupy" +else: + import numpy as np + from numpy.testing import assert_array_almost_equal, assert_array_equal + + backend = "numpy" +import pytest + +from pylops.medical import MRI2D +from pylops.utils import dottest, mkl_fft_enabled + +par1 = { + "ny": 32, + "nx": 64, + "imag": 0, + "dtype": "complex128", + "engine": "numpy", +} # real input, complex dtype, numpy engine +par2 = { + "ny": 32, + "nx": 64, + "imag": 1j, + "dtype": "complex128", + "engine": "numpy", +} # complex input, complex dtype, numpy engine +par3 = { + "ny": 32, + "nx": 64, + "imag": 0, + "dtype": "complex64", + "engine": "numpy", +} # real input, complex64 dtype, numpy engine +par4 = { + "ny": 32, + "nx": 64, + "imag": 1j, + "dtype": "complex64", + "engine": "scipy", +} # complex input, complex64 dtype, scipy engine + +np.random.seed(10) + + +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)]) +def test_MRI2D_mask_array(par): + """Dot-test and forward/adjoint for MRI2D operator with numpy array mask""" + np.random.seed(10) + + # Create a random mask + mask = np.zeros((par["ny"], par["nx"]), dtype=bool) + nselected = int(par["ny"] * par["nx"] * 0.3) + indices = np.random.choice(par["ny"] * par["nx"], nselected, replace=False) + mask.flat[indices] = True + mask = mask.astype(par["dtype"]) + + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask=mask, + engine=par["engine"], + dtype=par["dtype"], + ) + + assert dottest( + Mop, + nselected, + par["ny"] * par["nx"], + complexflag=0 if par["imag"] == 0 else 3, + backend=backend, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( + 0, 1, (par["ny"], par["nx"]) + ) + y = Mop * x.ravel() + xadj = Mop.H * y + + assert y.shape[0] == nselected + assert xadj.shape[0] == par["ny"] * par["nx"] + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_vertical_reg(par): + """Dot-test and forward/adjoint for MRI2D operator with vertical-reg mask""" + np.random.seed(10) + + nlines = 16 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-reg", + nlines=nlines, + perc_center=0.0, + engine=par["engine"], + dtype=par["dtype"], + ) + + assert dottest( + Mop, + par["ny"] * nlines, + par["ny"] * par["nx"], + complexflag=0 if par["imag"] == 0 else 3, + backend=backend, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( + 0, 1, (par["ny"], par["nx"]) + ) + y = Mop * x.ravel() + xadj = Mop.H * y + + assert y.shape[0] == par["ny"] * nlines + assert xadj.shape[0] == par["ny"] * par["nx"] + assert len(Mop.mask) == nlines + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_vertical_uni(par): + """Dot-test and forward/adjoint for MRI2D operator with vertical-uni mask""" + np.random.seed(10) + + nlines = 16 + perc_center = 0.1 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-uni", + nlines=nlines, + perc_center=perc_center, + engine=par["engine"], + dtype=par["dtype"], + ) + + assert dottest( + Mop, + par["ny"] * (nlines + int(perc_center * par["nx"])), + par["ny"] * par["nx"], + complexflag=0 if par["imag"] == 0 else 3, + backend=backend, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( + 0, 1, (par["ny"], par["nx"]) + ) + y = Mop * x.ravel() + xadj = Mop.H * y + + nlines_total = nlines + int(perc_center * par["nx"]) + assert y.shape[0] == par["ny"] * nlines_total + assert xadj.shape[0] == par["ny"] * par["nx"] + assert len(Mop.mask) == nlines_total + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_radial_reg(par): + """Dot-test and forward/adjoint for MRI2D operator with radial-reg mask""" + np.random.seed(10) + + nlines = 8 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="radial-reg", + nlines=nlines, + engine=par["engine"], + dtype=par["dtype"], + ) + + # For radial masks, output size depends on the number of points in the mask + npoints = Mop.mask.shape[1] + + assert dottest( + Mop, + npoints, + par["ny"] * par["nx"], + complexflag=0 if par["imag"] == 0 else 3, + backend=backend, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( + 0, 1, (par["ny"], par["nx"]) + ) + y = Mop * x.ravel() + xadj = Mop.H * y + + assert y.shape[0] == npoints + assert xadj.shape[0] == par["ny"] * par["nx"] + assert Mop.mask.shape[0] == 2 # x and y coordinates + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_radial_uni(par): + """Dot-test and forward/adjoint for MRI2D operator with radial-uni mask""" + np.random.seed(10) + + nlines = 8 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="radial-uni", + nlines=nlines, + engine=par["engine"], + dtype=par["dtype"], + ) + + # For radial masks, output size depends on the number of points in the mask + npoints = Mop.mask.shape[1] + + assert dottest( + Mop, + npoints, + par["ny"] * par["nx"], + complexflag=0 if par["imag"] == 0 else 3, + backend=backend, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( + 0, 1, (par["ny"], par["nx"]) + ) + y = Mop * x.ravel() + xadj = Mop.H * y + + assert y.shape[0] == npoints + assert xadj.shape[0] == par["ny"] * par["nx"] + assert Mop.mask.shape[0] == 2 # x and y coordinates + + +def test_MRI2D_invalid_mask(): + """Test MRI2D operator with invalid mask string""" + with pytest.raises(ValueError, match="mask must be"): + MRI2D( + dims=(32, 64), + mask="invalid-mask", + nlines=16, + dtype="complex128", + ) + + +def test_MRI2D_invalid_engine(): + """Test MRI2D operator with invalid engine""" + mask = np.zeros((32, 64), dtype="complex128") + mask[::2, ::2] = 1.0 + + with pytest.raises(ValueError, match="engine must be"): + MRI2D( + dims=(32, 64), + mask=mask, + engine="invalid-engine", + dtype="complex128", + ) + + +def test_MRI2D_vertical_reg_invalid_perc_center(): + """Test MRI2D operator with vertical-reg mask and non-zero perc_center""" + with pytest.raises(ValueError, match="perc_center must be 0.0"): + MRI2D( + dims=(32, 64), + mask="vertical-reg", + nlines=16, + perc_center=0.1, + dtype="complex128", + ) + + +def test_MRI2D_vertical_mask_invalid_nlines(): + """Test MRI2D operator with vertical mask and invalid nlines""" + with pytest.raises(ValueError, match="nlines and perc_center"): + MRI2D( + dims=(32, 64), + mask="vertical-uni", + nlines=60, + perc_center=0.5, + dtype="complex128", + ) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_vertical_mask_no_center(par): + """Test MRI2D operator with vertical mask and no center lines""" + np.random.seed(10) + + nlines = 16 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-uni", + nlines=nlines, + perc_center=0.0, + engine=par["engine"], + dtype=par["dtype"], + ) + + assert len(Mop.mask) == nlines + assert dottest( + Mop, + par["ny"] * nlines, + par["ny"] * par["nx"], + complexflag=0 if par["imag"] == 0 else 3, + backend=backend, + ) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_attributes(par): + """Test MRI2D operator attributes""" + np.random.seed(10) + + mask = np.zeros((par["ny"], par["nx"]), dtype=bool) + mask[::2, ::2] = True + mask = mask.astype(par["dtype"]) + + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask=mask, + engine=par["engine"], + dtype=par["dtype"], + name="TestMRI", + ) + + assert Mop.dims == (par["ny"], par["nx"]) + assert Mop.engine == par["engine"] + assert Mop.name == "TestMRI" + assert hasattr(Mop, "mask") + assert hasattr(Mop, "ROp") + assert Mop.shape[1] == par["ny"] * par["nx"] + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_non_contiguous_input(par): + """Test MRI2D operator with non-contiguous input""" + np.random.seed(10) + + mask = np.zeros((par["ny"], par["nx"]), dtype=bool) + mask[::2, ::2] = True + mask = mask.astype(par["dtype"]) + + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask=mask, + engine=par["engine"], + dtype=par["dtype"], + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( + 0, 1, (par["ny"], par["nx"]) + ) + x_noncontig = x[:, ::-1] # non-contiguous view + + y = Mop * x_noncontig.ravel() + assert not np.allclose(y, 0.0) + + +@pytest.mark.skipif(not mkl_fft_enabled(), reason="MKL FFT not available") +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_mkl_engine(par): + """Test MRI2D operator with MKL FFT engine""" + np.random.seed(10) + + mask = np.zeros((par["ny"], par["nx"]), dtype=bool) + mask[::2, ::2] = True + mask = mask.astype(par["dtype"]) + + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask=mask, + engine="mkl_fft", + dtype=par["dtype"], + ) + + assert dottest( + Mop, + np.sum(mask), + par["ny"] * par["nx"], + complexflag=0 if par["imag"] == 0 else 3, + backend=backend, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( + 0, 1, (par["ny"], par["nx"]) + ) + y = Mop * x.ravel() + xadj = Mop.H * y + + assert y.shape[0] == np.sum(mask) + assert xadj.shape[0] == par["ny"] * par["nx"] + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_vertical_mask_regularity(par): + """Test that vertical-reg mask produces regularly spaced lines""" + np.random.seed(10) + + nlines = 8 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-reg", + nlines=nlines, + perc_center=0.0, + engine=par["engine"], + dtype=par["dtype"], + ) + + mask_indices = Mop.mask + # Check that indices are regularly spaced + if len(mask_indices) > 1: + steps = np.diff(np.sort(mask_indices)) + # All steps should be approximately equal (within rounding) + assert np.allclose(steps, steps[0], atol=1) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_radial_mask_shape(par): + """Test that radial mask has correct shape""" + np.random.seed(10) + + nlines = 8 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="radial-reg", + nlines=nlines, + engine=par["engine"], + dtype=par["dtype"], + ) + + # Radial mask should be 2 x npoints array + assert Mop.mask.shape[0] == 2 + assert Mop.mask.shape[1] > 0 + # All points should be within bounds + assert np.all(Mop.mask[0] >= 0) + assert np.all(Mop.mask[0] < par["ny"]) + assert np.all(Mop.mask[1] >= 0) + assert np.all(Mop.mask[1] < par["nx"]) From 68e58880d1204e108b6959039785da4350d81428 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Tue, 20 Jan 2026 23:07:30 +0000 Subject: [PATCH 2/7] minor: update test_mri --- pytests/test_mri.py | 336 ++++++++++++++++---------------------------- 1 file changed, 121 insertions(+), 215 deletions(-) diff --git a/pytests/test_mri.py b/pytests/test_mri.py index 4e605ccf..d60c7bb5 100644 --- a/pytests/test_mri.py +++ b/pytests/test_mri.py @@ -18,36 +18,75 @@ par1 = { "ny": 32, "nx": 64, - "imag": 0, "dtype": "complex128", "engine": "numpy", } # real input, complex dtype, numpy engine par2 = { "ny": 32, "nx": 64, - "imag": 1j, - "dtype": "complex128", - "engine": "numpy", -} # complex input, complex dtype, numpy engine -par3 = { - "ny": 32, - "nx": 64, - "imag": 0, "dtype": "complex64", "engine": "numpy", } # real input, complex64 dtype, numpy engine -par4 = { +par3 = { "ny": 32, "nx": 64, - "imag": 1j, "dtype": "complex64", "engine": "scipy", -} # complex input, complex64 dtype, scipy engine +} # real input, complex64 dtype, scipy engine np.random.seed(10) -@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)]) +def test_MRI2D_invalid_mask(): + """Test MRI2D operator with invalid mask string""" + with pytest.raises(ValueError, match="mask must be"): + MRI2D( + dims=(32, 64), + mask="invalid-mask", + nlines=16, + dtype="complex128", + ) + + +def test_MRI2D_invalid_engine(): + """Test MRI2D operator with invalid engine""" + mask = np.zeros((32, 64), dtype="complex128") + mask[::2, ::2] = 1.0 + + with pytest.raises(ValueError, match="engine must be"): + MRI2D( + dims=(32, 64), + mask=mask, + engine="invalid-engine", + dtype="complex128", + ) + + +def test_MRI2D_vertical_reg_invalid_perc_center(): + """Test MRI2D operator with vertical-reg mask and non-zero perc_center""" + with pytest.raises(ValueError, match="perc_center must be 0.0"): + MRI2D( + dims=(32, 64), + mask="vertical-reg", + nlines=16, + perc_center=0.1, + dtype="complex128", + ) + + +def test_MRI2D_vertical_mask_invalid_nlines(): + """Test MRI2D operator with vertical mask and invalid nlines""" + with pytest.raises(ValueError, match="nlines and perc_center"): + MRI2D( + dims=(32, 64), + mask="vertical-uni", + nlines=60, + perc_center=0.5, + dtype="complex128", + ) + + +@pytest.mark.parametrize("par", [(par1), (par2), (par3)]) def test_MRI2D_mask_array(par): """Dot-test and forward/adjoint for MRI2D operator with numpy array mask""" np.random.seed(10) @@ -66,21 +105,20 @@ def test_MRI2D_mask_array(par): dtype=par["dtype"], ) + # For Diagonal mask, output size is same as input size assert dottest( Mop, - nselected, par["ny"] * par["nx"], - complexflag=0 if par["imag"] == 0 else 3, + par["ny"] * par["nx"], + complexflag=2, backend=backend, ) - x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( - 0, 1, (par["ny"], par["nx"]) - ) + x = np.random.normal(0, 1, (par["ny"], par["nx"])) y = Mop * x.ravel() xadj = Mop.H * y - assert y.shape[0] == nselected + assert y.shape[0] == par["ny"] * par["nx"] assert xadj.shape[0] == par["ny"] * par["nx"] @@ -103,13 +141,11 @@ def test_MRI2D_vertical_reg(par): Mop, par["ny"] * nlines, par["ny"] * par["nx"], - complexflag=0 if par["imag"] == 0 else 3, + complexflag=2, backend=backend, ) - x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( - 0, 1, (par["ny"], par["nx"]) - ) + x = np.random.normal(0, 1, (par["ny"], par["nx"])) y = Mop * x.ravel() xadj = Mop.H * y @@ -118,6 +154,29 @@ def test_MRI2D_vertical_reg(par): assert len(Mop.mask) == nlines +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_MRI2D_vertical_mask_regularity(par): + """Test that vertical-reg mask produces regularly spaced lines""" + np.random.seed(10) + + nlines = 8 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-reg", + nlines=nlines, + perc_center=0.0, + engine=par["engine"], + dtype=par["dtype"], + ) + + mask_indices = Mop.mask + # Check that indices are regularly spaced + if len(mask_indices) > 1: + steps = np.diff(np.sort(mask_indices)) + # All steps should be approximately equal (within rounding) + assert np.allclose(steps, steps[0], atol=1) + + @pytest.mark.parametrize("par", [(par1), (par2)]) def test_MRI2D_vertical_uni(par): """Dot-test and forward/adjoint for MRI2D operator with vertical-uni mask""" @@ -138,13 +197,11 @@ def test_MRI2D_vertical_uni(par): Mop, par["ny"] * (nlines + int(perc_center * par["nx"])), par["ny"] * par["nx"], - complexflag=0 if par["imag"] == 0 else 3, + complexflag=2, backend=backend, ) - x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( - 0, 1, (par["ny"], par["nx"]) - ) + x = np.random.normal(0, 1, (par["ny"], par["nx"])) y = Mop * x.ravel() xadj = Mop.H * y @@ -155,50 +212,39 @@ def test_MRI2D_vertical_uni(par): @pytest.mark.parametrize("par", [(par1), (par2)]) -def test_MRI2D_radial_reg(par): - """Dot-test and forward/adjoint for MRI2D operator with radial-reg mask""" +def test_MRI2D_vertical_uni_no_center(par): + """Test MRI2D operator with vertical mask and no center lines""" np.random.seed(10) - nlines = 8 + nlines = 16 Mop = MRI2D( dims=(par["ny"], par["nx"]), - mask="radial-reg", + mask="vertical-uni", nlines=nlines, + perc_center=0.0, engine=par["engine"], dtype=par["dtype"], ) - # For radial masks, output size depends on the number of points in the mask - npoints = Mop.mask.shape[1] - + assert len(Mop.mask) == nlines assert dottest( Mop, - npoints, + par["ny"] * nlines, par["ny"] * par["nx"], - complexflag=0 if par["imag"] == 0 else 3, + complexflag=2, backend=backend, ) - x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( - 0, 1, (par["ny"], par["nx"]) - ) - y = Mop * x.ravel() - xadj = Mop.H * y - - assert y.shape[0] == npoints - assert xadj.shape[0] == par["ny"] * par["nx"] - assert Mop.mask.shape[0] == 2 # x and y coordinates - @pytest.mark.parametrize("par", [(par1), (par2)]) -def test_MRI2D_radial_uni(par): - """Dot-test and forward/adjoint for MRI2D operator with radial-uni mask""" +def test_MRI2D_radial_reg(par): + """Dot-test and forward/adjoint for MRI2D operator with radial-reg mask""" np.random.seed(10) nlines = 8 Mop = MRI2D( dims=(par["ny"], par["nx"]), - mask="radial-uni", + mask="radial-reg", nlines=nlines, engine=par["engine"], dtype=par["dtype"], @@ -211,146 +257,54 @@ def test_MRI2D_radial_uni(par): Mop, npoints, par["ny"] * par["nx"], - complexflag=0 if par["imag"] == 0 else 3, + complexflag=2, backend=backend, ) - x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( - 0, 1, (par["ny"], par["nx"]) - ) - y = Mop * x.ravel() + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + y = Mop * x xadj = Mop.H * y - assert y.shape[0] == npoints - assert xadj.shape[0] == par["ny"] * par["nx"] + assert y.size == npoints + assert xadj.shape == (par["ny"], par["nx"]) assert Mop.mask.shape[0] == 2 # x and y coordinates -def test_MRI2D_invalid_mask(): - """Test MRI2D operator with invalid mask string""" - with pytest.raises(ValueError, match="mask must be"): - MRI2D( - dims=(32, 64), - mask="invalid-mask", - nlines=16, - dtype="complex128", - ) - - -def test_MRI2D_invalid_engine(): - """Test MRI2D operator with invalid engine""" - mask = np.zeros((32, 64), dtype="complex128") - mask[::2, ::2] = 1.0 - - with pytest.raises(ValueError, match="engine must be"): - MRI2D( - dims=(32, 64), - mask=mask, - engine="invalid-engine", - dtype="complex128", - ) - - -def test_MRI2D_vertical_reg_invalid_perc_center(): - """Test MRI2D operator with vertical-reg mask and non-zero perc_center""" - with pytest.raises(ValueError, match="perc_center must be 0.0"): - MRI2D( - dims=(32, 64), - mask="vertical-reg", - nlines=16, - perc_center=0.1, - dtype="complex128", - ) - - -def test_MRI2D_vertical_mask_invalid_nlines(): - """Test MRI2D operator with vertical mask and invalid nlines""" - with pytest.raises(ValueError, match="nlines and perc_center"): - MRI2D( - dims=(32, 64), - mask="vertical-uni", - nlines=60, - perc_center=0.5, - dtype="complex128", - ) - - @pytest.mark.parametrize("par", [(par1), (par2)]) -def test_MRI2D_vertical_mask_no_center(par): - """Test MRI2D operator with vertical mask and no center lines""" +def test_MRI2D_radial_uni(par): + """Dot-test and forward/adjoint for MRI2D operator with radial-uni mask""" np.random.seed(10) - nlines = 16 + nlines = 8 Mop = MRI2D( dims=(par["ny"], par["nx"]), - mask="vertical-uni", + mask="radial-uni", nlines=nlines, - perc_center=0.0, engine=par["engine"], dtype=par["dtype"], ) - assert len(Mop.mask) == nlines + # For radial masks, output size depends on the number of points in the mask + npoints = Mop.mask.shape[1] + assert dottest( Mop, - par["ny"] * nlines, + npoints, par["ny"] * par["nx"], - complexflag=0 if par["imag"] == 0 else 3, + complexflag=2, backend=backend, ) + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + y = Mop * x + xadj = Mop.H * y -@pytest.mark.parametrize("par", [(par1), (par2)]) -def test_MRI2D_attributes(par): - """Test MRI2D operator attributes""" - np.random.seed(10) - - mask = np.zeros((par["ny"], par["nx"]), dtype=bool) - mask[::2, ::2] = True - mask = mask.astype(par["dtype"]) - - Mop = MRI2D( - dims=(par["ny"], par["nx"]), - mask=mask, - engine=par["engine"], - dtype=par["dtype"], - name="TestMRI", - ) - - assert Mop.dims == (par["ny"], par["nx"]) - assert Mop.engine == par["engine"] - assert Mop.name == "TestMRI" - assert hasattr(Mop, "mask") - assert hasattr(Mop, "ROp") - assert Mop.shape[1] == par["ny"] * par["nx"] - - -@pytest.mark.parametrize("par", [(par1), (par2)]) -def test_MRI2D_non_contiguous_input(par): - """Test MRI2D operator with non-contiguous input""" - np.random.seed(10) - - mask = np.zeros((par["ny"], par["nx"]), dtype=bool) - mask[::2, ::2] = True - mask = mask.astype(par["dtype"]) - - Mop = MRI2D( - dims=(par["ny"], par["nx"]), - mask=mask, - engine=par["engine"], - dtype=par["dtype"], - ) - - x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( - 0, 1, (par["ny"], par["nx"]) - ) - x_noncontig = x[:, ::-1] # non-contiguous view - - y = Mop * x_noncontig.ravel() - assert not np.allclose(y, 0.0) + assert y.size == npoints + assert xadj.shape == (par["ny"], par["nx"]) + assert Mop.mask.shape[0] == 2 # x and y coordinates -@pytest.mark.skipif(not mkl_fft_enabled(), reason="MKL FFT not available") +@pytest.mark.skipif(not mkl_fft_enabled, reason="MKL FFT not available") @pytest.mark.parametrize("par", [(par1), (par2)]) def test_MRI2D_mkl_engine(par): """Test MRI2D operator with MKL FFT engine""" @@ -367,66 +321,18 @@ def test_MRI2D_mkl_engine(par): dtype=par["dtype"], ) + # For Diagonal mask, output size is same as input size assert dottest( Mop, - np.sum(mask), par["ny"] * par["nx"], - complexflag=0 if par["imag"] == 0 else 3, + par["ny"] * par["nx"], + complexflag=2, backend=backend, ) - x = np.random.normal(0, 1, (par["ny"], par["nx"])) + par["imag"] * np.random.normal( - 0, 1, (par["ny"], par["nx"]) - ) + x = np.random.normal(0, 1, (par["ny"], par["nx"])) y = Mop * x.ravel() xadj = Mop.H * y - assert y.shape[0] == np.sum(mask) + assert y.shape[0] == par["ny"] * par["nx"] assert xadj.shape[0] == par["ny"] * par["nx"] - - -@pytest.mark.parametrize("par", [(par1), (par2)]) -def test_MRI2D_vertical_mask_regularity(par): - """Test that vertical-reg mask produces regularly spaced lines""" - np.random.seed(10) - - nlines = 8 - Mop = MRI2D( - dims=(par["ny"], par["nx"]), - mask="vertical-reg", - nlines=nlines, - perc_center=0.0, - engine=par["engine"], - dtype=par["dtype"], - ) - - mask_indices = Mop.mask - # Check that indices are regularly spaced - if len(mask_indices) > 1: - steps = np.diff(np.sort(mask_indices)) - # All steps should be approximately equal (within rounding) - assert np.allclose(steps, steps[0], atol=1) - - -@pytest.mark.parametrize("par", [(par1), (par2)]) -def test_MRI2D_radial_mask_shape(par): - """Test that radial mask has correct shape""" - np.random.seed(10) - - nlines = 8 - Mop = MRI2D( - dims=(par["ny"], par["nx"]), - mask="radial-reg", - nlines=nlines, - engine=par["engine"], - dtype=par["dtype"], - ) - - # Radial mask should be 2 x npoints array - assert Mop.mask.shape[0] == 2 - assert Mop.mask.shape[1] > 0 - # All points should be within bounds - assert np.all(Mop.mask[0] >= 0) - assert np.all(Mop.mask[0] < par["ny"]) - assert np.all(Mop.mask[1] >= 0) - assert np.all(Mop.mask[1] < par["nx"]) From 32e8b60b9c0253a4f9cebf1d1c9cb86d2f550dba Mon Sep 17 00:00:00 2001 From: mrava87 Date: Wed, 21 Jan 2026 08:28:15 +0000 Subject: [PATCH 3/7] minor: add mri to init file --- pylops/medical/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pylops/medical/__init__.py b/pylops/medical/__init__.py index b11b398f..2771aac3 100644 --- a/pylops/medical/__init__.py +++ b/pylops/medical/__init__.py @@ -8,11 +8,12 @@ A list of operators present in pylops.medical: CT2D 2D Computerized Tomography. + MRI2D 2D Magnetic Resonance Imaging. """ from .ct import * +from .mri import * -__all__ = [ - "CT2D", -] + +__all__ = ["CT2D", "MRI2D"] From 313531e8d3509a7a8ef68b434e7c8cf186af986c Mon Sep 17 00:00:00 2001 From: mrava87 Date: Wed, 21 Jan 2026 08:48:25 +0000 Subject: [PATCH 4/7] minor: fix mri tests --- pytests/test_mri.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/pytests/test_mri.py b/pytests/test_mri.py index d60c7bb5..b84a5def 100644 --- a/pytests/test_mri.py +++ b/pytests/test_mri.py @@ -20,21 +20,13 @@ "nx": 64, "dtype": "complex128", "engine": "numpy", -} # real input, complex dtype, numpy engine +} # even input, complex dtype, numpy engine par2 = { - "ny": 32, - "nx": 64, - "dtype": "complex64", - "engine": "numpy", -} # real input, complex64 dtype, numpy engine -par3 = { - "ny": 32, - "nx": 64, - "dtype": "complex64", + "ny": 33, + "nx": 65, + "dtype": "complex128", "engine": "scipy", -} # real input, complex64 dtype, scipy engine - -np.random.seed(10) +} # odd input, complex64 dtype, scipy engine def test_MRI2D_invalid_mask(): @@ -86,7 +78,7 @@ def test_MRI2D_vertical_mask_invalid_nlines(): ) -@pytest.mark.parametrize("par", [(par1), (par2), (par3)]) +@pytest.mark.parametrize("par", [(par1), (par2)]) def test_MRI2D_mask_array(par): """Dot-test and forward/adjoint for MRI2D operator with numpy array mask""" np.random.seed(10) From 29a547cb4bd04e74708b1227a0633ad737431c8e Mon Sep 17 00:00:00 2001 From: mrava87 Date: Wed, 21 Jan 2026 21:37:58 +0000 Subject: [PATCH 5/7] feat: fix MRI2D to work with different backends --- pylops/medical/mri.py | 24 +++++++-- pytests/test_mri.py | 114 +++++++++++++++++++++++------------------- 2 files changed, 83 insertions(+), 55 deletions(-) diff --git a/pylops/medical/mri.py b/pylops/medical/mri.py index 6b3ec0c2..9e301c93 100644 --- a/pylops/medical/mri.py +++ b/pylops/medical/mri.py @@ -2,6 +2,7 @@ "MRI2D", ] +import warnings from typing import Literal, Optional, Union import numpy as np @@ -9,6 +10,7 @@ from pylops import LinearOperator from pylops.basicoperators import Diagonal, Restriction from pylops.signalprocessing import FFT2D, Bilinear +from pylops.utils.backend import get_module from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -41,7 +43,10 @@ class MRI2D(LinearOperator): perc_center: :obj:`float` Percentage of total lines to retain in the center. engine : :obj:`str`, optional - Engine used for fft computation (``numpy`` or ``fftw``) + Engine used for computation (``numpy`` or ``cupy`` or ``jax``). + fft_engine : :obj:`str`, optional + Engine used for fft computation (``numpy`` or ``scipy`` or ``mkl_fft``). + If ``engine='cupy'``, fft_engine is forced to ``'numpy'``. dtype : :obj:`str`, optional Type of elements in input array. name : :obj:`str`, optional @@ -83,6 +88,7 @@ def __init__( nlines: Optional[int] = None, perc_center: float = 0.1, engine: str = "numpy", + fft_engine: str = "numpy", dtype: DTypeLike = "complex128", name: str = "M", **kwargs_fft, @@ -90,8 +96,12 @@ def __init__( self.dims = dims self._mask_type = mask if isinstance(mask, str) else "mask" self.engine = engine + self.fft_engine = fft_engine # Validate inputs + if engine != "numpy" and fft_engine != "numpy": + warnings.warn(f"When engine='{engine}', fft_engine is forced to 'numpy'") + self.fft_engine = "numpy" if isinstance(mask, str) and mask not in ( "vertical-reg", "vertical-uni", @@ -104,8 +114,8 @@ def __init__( if isinstance(mask, str) and mask == "vertical-reg" and perc_center > 0.0: raise ValueError("perc_center must be 0.0 when using 'vertical-reg' mask") - if self.engine not in ["numpy", "scipy", "mkl_fft"]: - raise ValueError("engine must be 'numpy', 'scipy', or 'mkl_fft'") + if self.fft_engine not in ["numpy", "scipy", "mkl_fft"]: + raise ValueError("fft_engine must be 'numpy', 'scipy', or 'mkl_fft'") if self._mask_type == "mask": self.mask = mask @@ -120,11 +130,17 @@ def __init__( self.mask = self._radial_mask( dims, nlines, uniform=True if "reg" in self._mask_type else False ) + + # Convert mask to appropriate backend + ncp = get_module(self.engine) + self.mask = ncp.asarray(self.mask) + + # Create operator self.ROp, Op = self._calc_op( dims=dims, mask_type=mask if isinstance(mask, str) else "mask", mask=self.mask, - fft_engine=engine, + fft_engine=self.fft_engine, dtype=dtype, **kwargs_fft, ) diff --git a/pytests/test_mri.py b/pytests/test_mri.py index b84a5def..acdd59b7 100644 --- a/pytests/test_mri.py +++ b/pytests/test_mri.py @@ -19,13 +19,37 @@ "ny": 32, "nx": 64, "dtype": "complex128", - "engine": "numpy", + "fft_engine": "numpy", } # even input, complex dtype, numpy engine par2 = { + "ny": 32, + "nx": 64, + "dtype": "complex128", + "fft_engine": "scipy", +} # even input, complex64 dtype, scipy engine +par3 = { + "ny": 32, + "nx": 64, + "dtype": "complex128", + "fft_engine": "mkl_fft", +} # even input, complex dtype, mkl_fft engine +par4 = { + "ny": 33, + "nx": 65, + "dtype": "complex128", + "fft_engine": "numpy", +} # odd input, complex64 dtype, numpy engine +par5 = { + "ny": 33, + "nx": 65, + "dtype": "complex128", + "fft_engine": "scipy", +} # even input, complex dtype, scipy engine +par6 = { "ny": 33, "nx": 65, "dtype": "complex128", - "engine": "scipy", + "fft_engine": "mkl_fft", } # odd input, complex64 dtype, scipy engine @@ -49,7 +73,8 @@ def test_MRI2D_invalid_engine(): MRI2D( dims=(32, 64), mask=mask, - engine="invalid-engine", + engine=backend, + fft_engine="invalid-engine", dtype="complex128", ) @@ -78,9 +103,11 @@ def test_MRI2D_vertical_mask_invalid_nlines(): ) -@pytest.mark.parametrize("par", [(par1), (par2)]) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_mask_array(par): """Dot-test and forward/adjoint for MRI2D operator with numpy array mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") np.random.seed(10) # Create a random mask @@ -93,7 +120,8 @@ def test_MRI2D_mask_array(par): Mop = MRI2D( dims=(par["ny"], par["nx"]), mask=mask, - engine=par["engine"], + engine=backend, + fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -114,9 +142,11 @@ def test_MRI2D_mask_array(par): assert xadj.shape[0] == par["ny"] * par["nx"] -@pytest.mark.parametrize("par", [(par1), (par2)]) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_vertical_reg(par): """Dot-test and forward/adjoint for MRI2D operator with vertical-reg mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") np.random.seed(10) nlines = 16 @@ -125,7 +155,8 @@ def test_MRI2D_vertical_reg(par): mask="vertical-reg", nlines=nlines, perc_center=0.0, - engine=par["engine"], + engine=backend, + fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -146,9 +177,11 @@ def test_MRI2D_vertical_reg(par): assert len(Mop.mask) == nlines -@pytest.mark.parametrize("par", [(par1), (par2)]) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_vertical_mask_regularity(par): """Test that vertical-reg mask produces regularly spaced lines""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") np.random.seed(10) nlines = 8 @@ -157,7 +190,8 @@ def test_MRI2D_vertical_mask_regularity(par): mask="vertical-reg", nlines=nlines, perc_center=0.0, - engine=par["engine"], + engine=backend, + fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -169,9 +203,11 @@ def test_MRI2D_vertical_mask_regularity(par): assert np.allclose(steps, steps[0], atol=1) -@pytest.mark.parametrize("par", [(par1), (par2)]) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_vertical_uni(par): """Dot-test and forward/adjoint for MRI2D operator with vertical-uni mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") np.random.seed(10) nlines = 16 @@ -181,7 +217,8 @@ def test_MRI2D_vertical_uni(par): mask="vertical-uni", nlines=nlines, perc_center=perc_center, - engine=par["engine"], + engine=backend, + fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -203,9 +240,11 @@ def test_MRI2D_vertical_uni(par): assert len(Mop.mask) == nlines_total -@pytest.mark.parametrize("par", [(par1), (par2)]) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_vertical_uni_no_center(par): """Test MRI2D operator with vertical mask and no center lines""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") np.random.seed(10) nlines = 16 @@ -214,7 +253,8 @@ def test_MRI2D_vertical_uni_no_center(par): mask="vertical-uni", nlines=nlines, perc_center=0.0, - engine=par["engine"], + engine=backend, + fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -228,9 +268,11 @@ def test_MRI2D_vertical_uni_no_center(par): ) -@pytest.mark.parametrize("par", [(par1), (par2)]) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_radial_reg(par): """Dot-test and forward/adjoint for MRI2D operator with radial-reg mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") np.random.seed(10) nlines = 8 @@ -238,7 +280,8 @@ def test_MRI2D_radial_reg(par): dims=(par["ny"], par["nx"]), mask="radial-reg", nlines=nlines, - engine=par["engine"], + engine=backend, + fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -262,9 +305,11 @@ def test_MRI2D_radial_reg(par): assert Mop.mask.shape[0] == 2 # x and y coordinates -@pytest.mark.parametrize("par", [(par1), (par2)]) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_radial_uni(par): """Dot-test and forward/adjoint for MRI2D operator with radial-uni mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") np.random.seed(10) nlines = 8 @@ -272,7 +317,8 @@ def test_MRI2D_radial_uni(par): dims=(par["ny"], par["nx"]), mask="radial-uni", nlines=nlines, - engine=par["engine"], + engine=backend, + fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -294,37 +340,3 @@ def test_MRI2D_radial_uni(par): assert y.size == npoints assert xadj.shape == (par["ny"], par["nx"]) assert Mop.mask.shape[0] == 2 # x and y coordinates - - -@pytest.mark.skipif(not mkl_fft_enabled, reason="MKL FFT not available") -@pytest.mark.parametrize("par", [(par1), (par2)]) -def test_MRI2D_mkl_engine(par): - """Test MRI2D operator with MKL FFT engine""" - np.random.seed(10) - - mask = np.zeros((par["ny"], par["nx"]), dtype=bool) - mask[::2, ::2] = True - mask = mask.astype(par["dtype"]) - - Mop = MRI2D( - dims=(par["ny"], par["nx"]), - mask=mask, - engine="mkl_fft", - dtype=par["dtype"], - ) - - # For Diagonal mask, output size is same as input size - assert dottest( - Mop, - par["ny"] * par["nx"], - par["ny"] * par["nx"], - complexflag=2, - backend=backend, - ) - - x = np.random.normal(0, 1, (par["ny"], par["nx"])) - y = Mop * x.ravel() - xadj = Mop.H * y - - assert y.shape[0] == par["ny"] * par["nx"] - assert xadj.shape[0] == par["ny"] * par["nx"] From 9cd37d7f378195bbd52dbbd00144028b48c9fdc0 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Wed, 21 Jan 2026 21:56:49 +0000 Subject: [PATCH 6/7] minor: skip tests for cupy --- pylops/medical/mri.py | 25 ++++++++++++------- pytests/test_mri.py | 58 +++++++++++++++++++++++++------------------ 2 files changed, 50 insertions(+), 33 deletions(-) diff --git a/pylops/medical/mri.py b/pylops/medical/mri.py index 9e301c93..0fba67d5 100644 --- a/pylops/medical/mri.py +++ b/pylops/medical/mri.py @@ -43,10 +43,9 @@ class MRI2D(LinearOperator): perc_center: :obj:`float` Percentage of total lines to retain in the center. engine : :obj:`str`, optional - Engine used for computation (``numpy`` or ``cupy`` or ``jax``). + Engine used for computation (``numpy`` or ``jax``). fft_engine : :obj:`str`, optional Engine used for fft computation (``numpy`` or ``scipy`` or ``mkl_fft``). - If ``engine='cupy'``, fft_engine is forced to ``'numpy'``. dtype : :obj:`str`, optional Type of elements in input array. name : :obj:`str`, optional @@ -66,6 +65,15 @@ class MRI2D(LinearOperator): Operator contains a matrix that can be solved explicitly (``True``) or not (``False``) + Raises + ------ + ValueError + If ``mask`` is not one of the accepted strings or a numpy array. + ValueError + If ``fft_engine`` is neither ``numpy``, ``fftw``, nor ``scipy``. + ValueError + If ``perc_center`` is greater than 0 when using ``vertical-reg`` mask. + Notes ----- The MRI2D operator applies 2-dimensional Fourier transform to the model, @@ -87,8 +95,8 @@ def __init__( ], nlines: Optional[int] = None, perc_center: float = 0.1, - engine: str = "numpy", - fft_engine: str = "numpy", + engine: Literal["numpy", "jax"] = "numpy", + fft_engine: Literal["numpy", "scipy", "mkl_fft"] = "numpy", dtype: DTypeLike = "complex128", name: str = "M", **kwargs_fft, @@ -99,8 +107,8 @@ def __init__( self.fft_engine = fft_engine # Validate inputs - if engine != "numpy" and fft_engine != "numpy": - warnings.warn(f"When engine='{engine}', fft_engine is forced to 'numpy'") + if engine == "jax" and fft_engine != "numpy": + warnings.warn(f"When engine='jax', fft_engine is forced to 'numpy'") self.fft_engine = "numpy" if isinstance(mask, str) and mask not in ( "vertical-reg", @@ -111,11 +119,10 @@ def __init__( raise ValueError( "mask must be a numpy array, 'vertical-reg', 'vertical-uni', 'radial-reg', or 'radial-uni'" ) - if isinstance(mask, str) and mask == "vertical-reg" and perc_center > 0.0: - raise ValueError("perc_center must be 0.0 when using 'vertical-reg' mask") - if self.fft_engine not in ["numpy", "scipy", "mkl_fft"]: raise ValueError("fft_engine must be 'numpy', 'scipy', or 'mkl_fft'") + if isinstance(mask, str) and mask == "vertical-reg" and perc_center > 0.0: + raise ValueError("perc_center must be 0.0 when using 'vertical-reg' mask") if self._mask_type == "mask": self.mask = mask diff --git a/pytests/test_mri.py b/pytests/test_mri.py index acdd59b7..739bb188 100644 --- a/pytests/test_mri.py +++ b/pytests/test_mri.py @@ -1,15 +1,6 @@ import os -if int(os.environ.get("TEST_CUPY_PYLOPS", 0)): - import cupy as np - from cupy.testing import assert_array_almost_equal, assert_array_equal - - backend = "cupy" -else: - import numpy as np - from numpy.testing import assert_array_almost_equal, assert_array_equal - - backend = "numpy" +import numpy as np import pytest from pylops.medical import MRI2D @@ -53,6 +44,9 @@ } # odd input, complex64 dtype, scipy engine +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) def test_MRI2D_invalid_mask(): """Test MRI2D operator with invalid mask string""" with pytest.raises(ValueError, match="mask must be"): @@ -64,6 +58,9 @@ def test_MRI2D_invalid_mask(): ) +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) def test_MRI2D_invalid_engine(): """Test MRI2D operator with invalid engine""" mask = np.zeros((32, 64), dtype="complex128") @@ -73,12 +70,14 @@ def test_MRI2D_invalid_engine(): MRI2D( dims=(32, 64), mask=mask, - engine=backend, fft_engine="invalid-engine", dtype="complex128", ) +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) def test_MRI2D_vertical_reg_invalid_perc_center(): """Test MRI2D operator with vertical-reg mask and non-zero perc_center""" with pytest.raises(ValueError, match="perc_center must be 0.0"): @@ -91,6 +90,9 @@ def test_MRI2D_vertical_reg_invalid_perc_center(): ) +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) def test_MRI2D_vertical_mask_invalid_nlines(): """Test MRI2D operator with vertical mask and invalid nlines""" with pytest.raises(ValueError, match="nlines and perc_center"): @@ -103,6 +105,9 @@ def test_MRI2D_vertical_mask_invalid_nlines(): ) +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_mask_array(par): """Dot-test and forward/adjoint for MRI2D operator with numpy array mask""" @@ -120,7 +125,6 @@ def test_MRI2D_mask_array(par): Mop = MRI2D( dims=(par["ny"], par["nx"]), mask=mask, - engine=backend, fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -131,7 +135,6 @@ def test_MRI2D_mask_array(par): par["ny"] * par["nx"], par["ny"] * par["nx"], complexflag=2, - backend=backend, ) x = np.random.normal(0, 1, (par["ny"], par["nx"])) @@ -142,6 +145,9 @@ def test_MRI2D_mask_array(par): assert xadj.shape[0] == par["ny"] * par["nx"] +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_vertical_reg(par): """Dot-test and forward/adjoint for MRI2D operator with vertical-reg mask""" @@ -155,7 +161,6 @@ def test_MRI2D_vertical_reg(par): mask="vertical-reg", nlines=nlines, perc_center=0.0, - engine=backend, fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -165,7 +170,6 @@ def test_MRI2D_vertical_reg(par): par["ny"] * nlines, par["ny"] * par["nx"], complexflag=2, - backend=backend, ) x = np.random.normal(0, 1, (par["ny"], par["nx"])) @@ -177,6 +181,9 @@ def test_MRI2D_vertical_reg(par): assert len(Mop.mask) == nlines +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_vertical_mask_regularity(par): """Test that vertical-reg mask produces regularly spaced lines""" @@ -190,7 +197,6 @@ def test_MRI2D_vertical_mask_regularity(par): mask="vertical-reg", nlines=nlines, perc_center=0.0, - engine=backend, fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -203,6 +209,9 @@ def test_MRI2D_vertical_mask_regularity(par): assert np.allclose(steps, steps[0], atol=1) +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_vertical_uni(par): """Dot-test and forward/adjoint for MRI2D operator with vertical-uni mask""" @@ -217,7 +226,6 @@ def test_MRI2D_vertical_uni(par): mask="vertical-uni", nlines=nlines, perc_center=perc_center, - engine=backend, fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -227,7 +235,6 @@ def test_MRI2D_vertical_uni(par): par["ny"] * (nlines + int(perc_center * par["nx"])), par["ny"] * par["nx"], complexflag=2, - backend=backend, ) x = np.random.normal(0, 1, (par["ny"], par["nx"])) @@ -240,6 +247,9 @@ def test_MRI2D_vertical_uni(par): assert len(Mop.mask) == nlines_total +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_vertical_uni_no_center(par): """Test MRI2D operator with vertical mask and no center lines""" @@ -253,7 +263,6 @@ def test_MRI2D_vertical_uni_no_center(par): mask="vertical-uni", nlines=nlines, perc_center=0.0, - engine=backend, fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -264,10 +273,12 @@ def test_MRI2D_vertical_uni_no_center(par): par["ny"] * nlines, par["ny"] * par["nx"], complexflag=2, - backend=backend, ) +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_radial_reg(par): """Dot-test and forward/adjoint for MRI2D operator with radial-reg mask""" @@ -280,7 +291,6 @@ def test_MRI2D_radial_reg(par): dims=(par["ny"], par["nx"]), mask="radial-reg", nlines=nlines, - engine=backend, fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -293,7 +303,6 @@ def test_MRI2D_radial_reg(par): npoints, par["ny"] * par["nx"], complexflag=2, - backend=backend, ) x = np.random.normal(0, 1, (par["ny"], par["nx"])) @@ -305,6 +314,9 @@ def test_MRI2D_radial_reg(par): assert Mop.mask.shape[0] == 2 # x and y coordinates +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) def test_MRI2D_radial_uni(par): """Dot-test and forward/adjoint for MRI2D operator with radial-uni mask""" @@ -317,7 +329,6 @@ def test_MRI2D_radial_uni(par): dims=(par["ny"], par["nx"]), mask="radial-uni", nlines=nlines, - engine=backend, fft_engine=par["fft_engine"], dtype=par["dtype"], ) @@ -330,7 +341,6 @@ def test_MRI2D_radial_uni(par): npoints, par["ny"] * par["nx"], complexflag=2, - backend=backend, ) x = np.random.normal(0, 1, (par["ny"], par["nx"])) From dd22efb71800f5c1e0ff9a6aa74b26488d267a7f Mon Sep 17 00:00:00 2001 From: mrava87 Date: Wed, 21 Jan 2026 22:16:14 +0000 Subject: [PATCH 7/7] doc: added example with MRI2D --- docs/source/api/index.rst | 1 + docs/source/gpu.rst | 4 ++ examples/plot_avo.py | 2 +- examples/plot_mri.py | 104 ++++++++++++++++++++++++++++++++++++++ pylops/medical/mri.py | 2 +- 5 files changed, 111 insertions(+), 2 deletions(-) create mode 100755 examples/plot_mri.py diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 29ced722..374657f5 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -171,6 +171,7 @@ Medical imaging :toctree: generated/ CT2D + MRI2D Solvers diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index 64695ae9..957f5413 100755 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -602,6 +602,10 @@ Medical: - |:white_check_mark:| - |:white_check_mark:| - |:white_check_mark:| + * - :class:`pylops.medical.mri.MRI2D` + - |:white_check_mark:| + - |:red_circle:| + - |:white_check_mark:| .. warning:: diff --git a/examples/plot_avo.py b/examples/plot_avo.py index a3a30636..a2f121b2 100755 --- a/examples/plot_avo.py +++ b/examples/plot_avo.py @@ -1,6 +1,6 @@ r""" AVO modelling -=================== +============= This example shows how to create pre-stack angle gathers using the :py:class:`pylops.avo.avo.AVOLinearModelling` operator. """ diff --git a/examples/plot_mri.py b/examples/plot_mri.py new file mode 100755 index 00000000..608bdbed --- /dev/null +++ b/examples/plot_mri.py @@ -0,0 +1,104 @@ +r""" +MRI modelling +============= +This example shows how to use the :py:class:`pylops.medical.mri.MRI2D` operator +to create K-space undersampled MRI data. +""" +import matplotlib.pyplot as plt +import numpy as np + +import pylops + +plt.close("all") +np.random.seed(0) + +############################################################################### +# Let"s start by loading the Shepp-Logan phantom model. +x = np.load("../testdata/optimization/shepp_logan_phantom.npy") +x = x / x.max() +nx, ny = x.shape + +############################################################################### +# Next, we create a mask to simulate undersampling in K-space and apply it to +# the phantom model. + +# Passing mask as array +mask = np.zeros((nx, ny)) +mask[:, np.random.randint(0, ny, 2 * ny // 3)] = 1 +mask[:, ny // 2 - 20 : ny // 2 + 10] = 1 + +Mop = pylops.medical.MRI2D(dims=(nx, ny), mask=mask) + +d = Mop @ x +x_adj = Mop.H @ d + +fig, axs = plt.subplots(1, 3, figsize=(12, 5)) +axs[0].imshow(x, cmap="gray", vmin=0, vmax=1) +axs[0].set_title("Original Image") +axs[1].imshow(np.abs(d), cmap="jet", vmin=0, vmax=1) +axs[1].set_title("K-space Data") +axs[2].imshow(x_adj.real, cmap="gray", vmin=0, vmax=1) +axs[2].set_title("Adjoint Reconstruction") +fig.tight_layout() + +############################################################################### +# Alternatively, we can create the same mask by specifying a sampling pattern +# using the ``mask`` keyword argument. Here, we create a ``vertical-reg`` mask +# that samples K-space lines in the vertical direction with a regular pattern. + +# Vertical uniform with center +Mop = pylops.medical.MRI2D( + dims=(nx, ny), mask="vertical-reg", nlines=ny // 2, perc_center=0.0 +) + +d = Mop @ x +x_adj = (Mop.H @ d).reshape(nx, ny) + +fig, axs = plt.subplots(1, 3, figsize=(12, 5)) +axs[0].imshow(x, cmap="gray", vmin=0, vmax=1) +axs[0].set_title("Original Image") +axs[1].imshow(np.abs(Mop.ROp.H @ d).reshape(nx, ny), cmap="jet", vmin=0, vmax=1) +axs[1].set_title("K-space Data") +axs[2].imshow(x_adj.real, cmap="gray", vmin=0, vmax=1) +axs[2].set_title("Adjoint Reconstruction") +fig.tight_layout() + +############################################################################### +# Similarly, we can create a ``vertical-uni`` mask that randomly samples +# K-space lines in the vertical direction. + +# Vertical uniform with center +Mop = pylops.medical.MRI2D( + dims=(nx, ny), mask="vertical-uni", nlines=40, perc_center=0.1 +) + +d = Mop @ x +x_adj = (Mop.H @ d).reshape(nx, ny) + +fig, axs = plt.subplots(1, 3, figsize=(12, 5)) +axs[0].imshow(x, cmap="gray", vmin=0, vmax=1) +axs[0].set_title("Original Image") +axs[1].imshow(np.abs(Mop.ROp.H @ d).reshape(nx, ny), cmap="jet", vmin=0, vmax=1) +axs[1].set_title("K-space Data") +axs[2].imshow(x_adj.real, cmap="gray", vmin=0, vmax=1) +axs[2].set_title("Adjoint Reconstruction") +fig.tight_layout() + +############################################################################### +# Finally, we can create a sampling pattern with radial lines using the +# ``radial-uni`` (or ``radial-reg``) option. + +# Radial uniform +Mop = pylops.medical.MRI2D(dims=(nx, ny), mask="radial-uni", nlines=40) + +d = Mop @ x +x_adj = (Mop.H @ d).reshape(nx, ny) + +fig, axs = plt.subplots(1, 3, figsize=(12, 5)) +axs[0].imshow(x, cmap="gray", vmin=0, vmax=1) +axs[0].set_title("Original Image") +axs[1].imshow(np.abs(Mop.ROp.H @ d).reshape(nx, ny), cmap="jet", vmin=0, vmax=1) +axs[1].set_title("K-space Data") +axs[2].imshow(x_adj.real, cmap="gray", vmin=0, vmax=1) +axs[2].set_title("Adjoint Reconstruction") +fig.tight_layout() diff --git a/pylops/medical/mri.py b/pylops/medical/mri.py index 0fba67d5..a9900fd3 100644 --- a/pylops/medical/mri.py +++ b/pylops/medical/mri.py @@ -108,7 +108,7 @@ def __init__( # Validate inputs if engine == "jax" and fft_engine != "numpy": - warnings.warn(f"When engine='jax', fft_engine is forced to 'numpy'") + warnings.warn("When engine='jax', fft_engine is forced to 'numpy'") self.fft_engine = "numpy" if isinstance(mask, str) and mask not in ( "vertical-reg",