diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 29ced722..374657f5 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -171,6 +171,7 @@ Medical imaging :toctree: generated/ CT2D + MRI2D Solvers diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index 64695ae9..957f5413 100755 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -602,6 +602,10 @@ Medical: - |:white_check_mark:| - |:white_check_mark:| - |:white_check_mark:| + * - :class:`pylops.medical.mri.MRI2D` + - |:white_check_mark:| + - |:red_circle:| + - |:white_check_mark:| .. warning:: diff --git a/examples/plot_avo.py b/examples/plot_avo.py index a3a30636..a2f121b2 100755 --- a/examples/plot_avo.py +++ b/examples/plot_avo.py @@ -1,6 +1,6 @@ r""" AVO modelling -=================== +============= This example shows how to create pre-stack angle gathers using the :py:class:`pylops.avo.avo.AVOLinearModelling` operator. """ diff --git a/examples/plot_mri.py b/examples/plot_mri.py new file mode 100755 index 00000000..608bdbed --- /dev/null +++ b/examples/plot_mri.py @@ -0,0 +1,104 @@ +r""" +MRI modelling +============= +This example shows how to use the :py:class:`pylops.medical.mri.MRI2D` operator +to create K-space undersampled MRI data. +""" +import matplotlib.pyplot as plt +import numpy as np + +import pylops + +plt.close("all") +np.random.seed(0) + +############################################################################### +# Let"s start by loading the Shepp-Logan phantom model. +x = np.load("../testdata/optimization/shepp_logan_phantom.npy") +x = x / x.max() +nx, ny = x.shape + +############################################################################### +# Next, we create a mask to simulate undersampling in K-space and apply it to +# the phantom model. + +# Passing mask as array +mask = np.zeros((nx, ny)) +mask[:, np.random.randint(0, ny, 2 * ny // 3)] = 1 +mask[:, ny // 2 - 20 : ny // 2 + 10] = 1 + +Mop = pylops.medical.MRI2D(dims=(nx, ny), mask=mask) + +d = Mop @ x +x_adj = Mop.H @ d + +fig, axs = plt.subplots(1, 3, figsize=(12, 5)) +axs[0].imshow(x, cmap="gray", vmin=0, vmax=1) +axs[0].set_title("Original Image") +axs[1].imshow(np.abs(d), cmap="jet", vmin=0, vmax=1) +axs[1].set_title("K-space Data") +axs[2].imshow(x_adj.real, cmap="gray", vmin=0, vmax=1) +axs[2].set_title("Adjoint Reconstruction") +fig.tight_layout() + +############################################################################### +# Alternatively, we can create the same mask by specifying a sampling pattern +# using the ``mask`` keyword argument. Here, we create a ``vertical-reg`` mask +# that samples K-space lines in the vertical direction with a regular pattern. + +# Vertical uniform with center +Mop = pylops.medical.MRI2D( + dims=(nx, ny), mask="vertical-reg", nlines=ny // 2, perc_center=0.0 +) + +d = Mop @ x +x_adj = (Mop.H @ d).reshape(nx, ny) + +fig, axs = plt.subplots(1, 3, figsize=(12, 5)) +axs[0].imshow(x, cmap="gray", vmin=0, vmax=1) +axs[0].set_title("Original Image") +axs[1].imshow(np.abs(Mop.ROp.H @ d).reshape(nx, ny), cmap="jet", vmin=0, vmax=1) +axs[1].set_title("K-space Data") +axs[2].imshow(x_adj.real, cmap="gray", vmin=0, vmax=1) +axs[2].set_title("Adjoint Reconstruction") +fig.tight_layout() + +############################################################################### +# Similarly, we can create a ``vertical-uni`` mask that randomly samples +# K-space lines in the vertical direction. + +# Vertical uniform with center +Mop = pylops.medical.MRI2D( + dims=(nx, ny), mask="vertical-uni", nlines=40, perc_center=0.1 +) + +d = Mop @ x +x_adj = (Mop.H @ d).reshape(nx, ny) + +fig, axs = plt.subplots(1, 3, figsize=(12, 5)) +axs[0].imshow(x, cmap="gray", vmin=0, vmax=1) +axs[0].set_title("Original Image") +axs[1].imshow(np.abs(Mop.ROp.H @ d).reshape(nx, ny), cmap="jet", vmin=0, vmax=1) +axs[1].set_title("K-space Data") +axs[2].imshow(x_adj.real, cmap="gray", vmin=0, vmax=1) +axs[2].set_title("Adjoint Reconstruction") +fig.tight_layout() + +############################################################################### +# Finally, we can create a sampling pattern with radial lines using the +# ``radial-uni`` (or ``radial-reg``) option. + +# Radial uniform +Mop = pylops.medical.MRI2D(dims=(nx, ny), mask="radial-uni", nlines=40) + +d = Mop @ x +x_adj = (Mop.H @ d).reshape(nx, ny) + +fig, axs = plt.subplots(1, 3, figsize=(12, 5)) +axs[0].imshow(x, cmap="gray", vmin=0, vmax=1) +axs[0].set_title("Original Image") +axs[1].imshow(np.abs(Mop.ROp.H @ d).reshape(nx, ny), cmap="jet", vmin=0, vmax=1) +axs[1].set_title("K-space Data") +axs[2].imshow(x_adj.real, cmap="gray", vmin=0, vmax=1) +axs[2].set_title("Adjoint Reconstruction") +fig.tight_layout() diff --git a/pylops/medical/__init__.py b/pylops/medical/__init__.py index b11b398f..2771aac3 100644 --- a/pylops/medical/__init__.py +++ b/pylops/medical/__init__.py @@ -8,11 +8,12 @@ A list of operators present in pylops.medical: CT2D 2D Computerized Tomography. + MRI2D 2D Magnetic Resonance Imaging. """ from .ct import * +from .mri import * -__all__ = [ - "CT2D", -] + +__all__ = ["CT2D", "MRI2D"] diff --git a/pylops/medical/mri.py b/pylops/medical/mri.py new file mode 100644 index 00000000..a9900fd3 --- /dev/null +++ b/pylops/medical/mri.py @@ -0,0 +1,279 @@ +__all__ = [ + "MRI2D", +] + +import warnings +from typing import Literal, Optional, Union + +import numpy as np + +from pylops import LinearOperator +from pylops.basicoperators import Diagonal, Restriction +from pylops.signalprocessing import FFT2D, Bilinear +from pylops.utils.backend import get_module +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray + + +class MRI2D(LinearOperator): + r"""2D Magnetic Resonance Imaging + + Apply 2D Magnetic Resonance Imaging operator to obtain a k-space data (i.e., + undersampled Fourier representation of the model). + + Parameters + ---------- + dims : :obj:`list` or :obj:`int` + Number of samples for each dimension. Must be 2-dimensional and of size + :math:`n_y \times n_x` + mask : :obj:`str` or :obj:`numpy.ndarray` + Mask to be applied in the Fourier domain: + + - :obj:`numpy.ndarray`: a 2-dimensional array of size :math:`n_y \times n_x` + with 1 in the selected locations; + - ``vertical-reg``: mask with vertical lines (regularly sampled around the + second dimension); + - ``vertical-uni``: mask with vertical lines (irregularly sampled around the + second dimension, with lines drawn from a uniform distribution); + - ``radial-reg``: mask with radial lines (regularly sampled around the + :math:`-\pi/\pi` angles); + - ``radial-uni``: mask with radial lines (irregularly sampled around the + :math:`-\pi/\pi` angles, with angles drawn from a uniform distribution); + nlines: :obj:`str` + Number of lines in the k-space. + perc_center: :obj:`float` + Percentage of total lines to retain in the center. + engine : :obj:`str`, optional + Engine used for computation (``numpy`` or ``jax``). + fft_engine : :obj:`str`, optional + Engine used for fft computation (``numpy`` or ``scipy`` or ``mkl_fft``). + dtype : :obj:`str`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + **kwargs_fft + Arbitrary keyword arguments to be passed to the selected fft method + + Attributes + ---------- + mask : :obj:`numpy.ndarray` + Mask applied in the Fourier domain. + ROp : :obj:`pylops.Restriction` or :obj:`pylops.Diagonal` or :obj:`pylops.signalprocessing.Bilinear` + Operator that applies the mask in the Fourier domain. + shape : :obj:`tuple` + Operator shape + explicit : :obj:`bool` + Operator contains a matrix that can be solved + explicitly (``True``) or not (``False``) + + Raises + ------ + ValueError + If ``mask`` is not one of the accepted strings or a numpy array. + ValueError + If ``fft_engine`` is neither ``numpy``, ``fftw``, nor ``scipy``. + ValueError + If ``perc_center`` is greater than 0 when using ``vertical-reg`` mask. + + Notes + ----- + The MRI2D operator applies 2-dimensional Fourier transform to the model, + followed by a subsampling with a given ``mask``: + + .. math:: + \mathbf{d} = \mathbf{R} \mathbf{F}_{k} \mathbf{m} + + where :math:`\mathbf{F}_{k}` is the 2-dimensional Fourier transform and + :math:`\mathbf{R}` is the mask. + + """ + + def __init__( + self, + dims: InputDimsLike, + mask: Union[ + Literal["vertical-reg", "vertical-uni", "radial-reg", "radial-uni"], NDArray + ], + nlines: Optional[int] = None, + perc_center: float = 0.1, + engine: Literal["numpy", "jax"] = "numpy", + fft_engine: Literal["numpy", "scipy", "mkl_fft"] = "numpy", + dtype: DTypeLike = "complex128", + name: str = "M", + **kwargs_fft, + ) -> None: + self.dims = dims + self._mask_type = mask if isinstance(mask, str) else "mask" + self.engine = engine + self.fft_engine = fft_engine + + # Validate inputs + if engine == "jax" and fft_engine != "numpy": + warnings.warn("When engine='jax', fft_engine is forced to 'numpy'") + self.fft_engine = "numpy" + if isinstance(mask, str) and mask not in ( + "vertical-reg", + "vertical-uni", + "radial-reg", + "radial-uni", + ): + raise ValueError( + "mask must be a numpy array, 'vertical-reg', 'vertical-uni', 'radial-reg', or 'radial-uni'" + ) + if self.fft_engine not in ["numpy", "scipy", "mkl_fft"]: + raise ValueError("fft_engine must be 'numpy', 'scipy', or 'mkl_fft'") + if isinstance(mask, str) and mask == "vertical-reg" and perc_center > 0.0: + raise ValueError("perc_center must be 0.0 when using 'vertical-reg' mask") + + if self._mask_type == "mask": + self.mask = mask + elif "vertical" in self._mask_type: + self.mask = self._vertical_mask( + dims, + nlines, + perc_center, + uniform=True if "reg" in self._mask_type else False, + ) + elif "radial" in self._mask_type: + self.mask = self._radial_mask( + dims, nlines, uniform=True if "reg" in self._mask_type else False + ) + + # Convert mask to appropriate backend + ncp = get_module(self.engine) + self.mask = ncp.asarray(self.mask) + + # Create operator + self.ROp, Op = self._calc_op( + dims=dims, + mask_type=mask if isinstance(mask, str) else "mask", + mask=self.mask, + fft_engine=self.fft_engine, + dtype=dtype, + **kwargs_fft, + ) + super().__init__(Op=Op, name=name) + + @staticmethod + def _vertical_mask( + dims: InputDimsLike, nlines: int, perc_center: float, uniform: bool = True + ) -> NDArray: + """Create vertical mask""" + nlines_center = int(perc_center * dims[1]) + if (nlines + nlines_center) > dims[1]: + raise ValueError( + "nlines and perc_center produce a number of lines " + "greater than the total number of lines of the k-space" + f"({nlines + nlines_center}>{dims[1]})" + ) + + if nlines_center == 0: + # No lines from the center + if uniform: + step = dims[1] // nlines + mask = np.arange(0, dims[1], step)[:nlines] + else: + rng = np.random.default_rng() + mask = rng.choice(np.arange(dims[1]), nlines, replace=False) + else: + # Lines taken from the center + istart_center = dims[1] // 2 - nlines_center // 2 + iend_center = dims[1] // 2 + nlines_center // 2 + (nlines_center % 2) + ilines_center = np.arange(istart_center, iend_center) + + # Other lines + if uniform: + nlines_left = nlines // 2 + nlines % 2 + step_left = istart_center // nlines_left + ilines_left = np.arange(0, istart_center, step_left)[:nlines_left] + nlines_right = nlines // 2 + step_right = (dims[1] - iend_center) // nlines_left + ilines_right = np.arange(iend_center, dims[1], step_right)[ + :nlines_right + ] + mask = np.sort(np.hstack((ilines_left, ilines_center, ilines_right))) + else: + rng = np.random.default_rng() + ilines_other = np.hstack( + (np.arange(0, istart_center), np.arange(iend_center, dims[1])) + ) + ilines_other = rng.choice(ilines_other, nlines, replace=False) + mask = np.sort(np.hstack((ilines_center, ilines_other))) + return mask + + @staticmethod + def _radial_mask(dims: InputDimsLike, nlines: int, uniform: bool = True) -> NDArray: + """Create radial mask""" + npoints_per_line = dims[1] - 1 + + # Define angles + if uniform: + thetas = np.linspace(0, np.pi, nlines, endpoint=False) + else: + rng = np.random.default_rng() + thetas = rng.uniform(-np.pi, np.pi, nlines) + + # Create lines + lines = [] + for theta in thetas: + if theta == np.pi / 2: + # Create vertical line + xline = np.zeros(npoints_per_line) + yline = np.linspace( + -dims[1] // 2 + 1, dims[1] // 2 - 1, npoints_per_line, endpoint=True + ) + elif np.tan(theta) >= 0: + # Create lines for positive angles + xmax = min(dims[1] // 2, (dims[0] // 2) / np.tan(theta)) + xline = np.linspace( + -xmax, + min(xmax, dims[0] // 2 - 1 - (dims[0] + 1) % 2), + npoints_per_line, + endpoint=True, + ) + yline = np.tan(theta) * xline + else: + # Create lines for negative angles + xmin = max(-dims[1] // 2 + 1, (dims[0] // 2) / np.tan(theta)) + xline = np.linspace( + xmin, min(-xmin, dims[0] // 2 - 1), npoints_per_line, endpoint=True + ) + yline = np.tan(theta) * xline + xline, yline = xline + dims[0] // 2, yline + dims[1] // 2 + lines.append(np.vstack((xline, yline))) + mask = np.concatenate(lines, axis=1) + mask = mask[:, mask[0] < dims[0] - 1] + mask = mask[:, mask[1] < dims[1] - 1] + mask = np.unique(mask, axis=1) + return mask + + def _matvec(self, x: NDArray) -> NDArray: + return super()._matvec(x) + + def _rmatvec(self, x: NDArray) -> NDArray: + return super()._rmatvec(x) + + @staticmethod + def _calc_op( + dims: InputDimsLike, + mask_type: "str", + mask: NDArray, + fft_engine: float, + dtype: DTypeLike, + **kwargs_fft, + ): + """Calculate MRI operator""" + fop = FFT2D( + dims, + nffts=dims, + fftshift_after=True, + engine=fft_engine, + dtype=dtype, + **kwargs_fft, + ) + if mask_type == "mask": + rop = Diagonal(mask, dtype=dtype) + elif "vertical" in mask_type: + rop = Restriction(dims, mask, axis=-1, forceflat=True, dtype=dtype) + elif "radial" in mask_type: + rop = Bilinear(mask, dims, dtype=dtype) + return rop, rop @ fop diff --git a/pytests/test_mri.py b/pytests/test_mri.py new file mode 100644 index 00000000..739bb188 --- /dev/null +++ b/pytests/test_mri.py @@ -0,0 +1,352 @@ +import os + +import numpy as np +import pytest + +from pylops.medical import MRI2D +from pylops.utils import dottest, mkl_fft_enabled + +par1 = { + "ny": 32, + "nx": 64, + "dtype": "complex128", + "fft_engine": "numpy", +} # even input, complex dtype, numpy engine +par2 = { + "ny": 32, + "nx": 64, + "dtype": "complex128", + "fft_engine": "scipy", +} # even input, complex64 dtype, scipy engine +par3 = { + "ny": 32, + "nx": 64, + "dtype": "complex128", + "fft_engine": "mkl_fft", +} # even input, complex dtype, mkl_fft engine +par4 = { + "ny": 33, + "nx": 65, + "dtype": "complex128", + "fft_engine": "numpy", +} # odd input, complex64 dtype, numpy engine +par5 = { + "ny": 33, + "nx": 65, + "dtype": "complex128", + "fft_engine": "scipy", +} # even input, complex dtype, scipy engine +par6 = { + "ny": 33, + "nx": 65, + "dtype": "complex128", + "fft_engine": "mkl_fft", +} # odd input, complex64 dtype, scipy engine + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +def test_MRI2D_invalid_mask(): + """Test MRI2D operator with invalid mask string""" + with pytest.raises(ValueError, match="mask must be"): + MRI2D( + dims=(32, 64), + mask="invalid-mask", + nlines=16, + dtype="complex128", + ) + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +def test_MRI2D_invalid_engine(): + """Test MRI2D operator with invalid engine""" + mask = np.zeros((32, 64), dtype="complex128") + mask[::2, ::2] = 1.0 + + with pytest.raises(ValueError, match="engine must be"): + MRI2D( + dims=(32, 64), + mask=mask, + fft_engine="invalid-engine", + dtype="complex128", + ) + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +def test_MRI2D_vertical_reg_invalid_perc_center(): + """Test MRI2D operator with vertical-reg mask and non-zero perc_center""" + with pytest.raises(ValueError, match="perc_center must be 0.0"): + MRI2D( + dims=(32, 64), + mask="vertical-reg", + nlines=16, + perc_center=0.1, + dtype="complex128", + ) + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +def test_MRI2D_vertical_mask_invalid_nlines(): + """Test MRI2D operator with vertical mask and invalid nlines""" + with pytest.raises(ValueError, match="nlines and perc_center"): + MRI2D( + dims=(32, 64), + mask="vertical-uni", + nlines=60, + perc_center=0.5, + dtype="complex128", + ) + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +def test_MRI2D_mask_array(par): + """Dot-test and forward/adjoint for MRI2D operator with numpy array mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") + np.random.seed(10) + + # Create a random mask + mask = np.zeros((par["ny"], par["nx"]), dtype=bool) + nselected = int(par["ny"] * par["nx"] * 0.3) + indices = np.random.choice(par["ny"] * par["nx"], nselected, replace=False) + mask.flat[indices] = True + mask = mask.astype(par["dtype"]) + + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask=mask, + fft_engine=par["fft_engine"], + dtype=par["dtype"], + ) + + # For Diagonal mask, output size is same as input size + assert dottest( + Mop, + par["ny"] * par["nx"], + par["ny"] * par["nx"], + complexflag=2, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + y = Mop * x.ravel() + xadj = Mop.H * y + + assert y.shape[0] == par["ny"] * par["nx"] + assert xadj.shape[0] == par["ny"] * par["nx"] + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +def test_MRI2D_vertical_reg(par): + """Dot-test and forward/adjoint for MRI2D operator with vertical-reg mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") + np.random.seed(10) + + nlines = 16 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-reg", + nlines=nlines, + perc_center=0.0, + fft_engine=par["fft_engine"], + dtype=par["dtype"], + ) + + assert dottest( + Mop, + par["ny"] * nlines, + par["ny"] * par["nx"], + complexflag=2, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + y = Mop * x.ravel() + xadj = Mop.H * y + + assert y.shape[0] == par["ny"] * nlines + assert xadj.shape[0] == par["ny"] * par["nx"] + assert len(Mop.mask) == nlines + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +def test_MRI2D_vertical_mask_regularity(par): + """Test that vertical-reg mask produces regularly spaced lines""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") + np.random.seed(10) + + nlines = 8 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-reg", + nlines=nlines, + perc_center=0.0, + fft_engine=par["fft_engine"], + dtype=par["dtype"], + ) + + mask_indices = Mop.mask + # Check that indices are regularly spaced + if len(mask_indices) > 1: + steps = np.diff(np.sort(mask_indices)) + # All steps should be approximately equal (within rounding) + assert np.allclose(steps, steps[0], atol=1) + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +def test_MRI2D_vertical_uni(par): + """Dot-test and forward/adjoint for MRI2D operator with vertical-uni mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") + np.random.seed(10) + + nlines = 16 + perc_center = 0.1 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-uni", + nlines=nlines, + perc_center=perc_center, + fft_engine=par["fft_engine"], + dtype=par["dtype"], + ) + + assert dottest( + Mop, + par["ny"] * (nlines + int(perc_center * par["nx"])), + par["ny"] * par["nx"], + complexflag=2, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + y = Mop * x.ravel() + xadj = Mop.H * y + + nlines_total = nlines + int(perc_center * par["nx"]) + assert y.shape[0] == par["ny"] * nlines_total + assert xadj.shape[0] == par["ny"] * par["nx"] + assert len(Mop.mask) == nlines_total + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +def test_MRI2D_vertical_uni_no_center(par): + """Test MRI2D operator with vertical mask and no center lines""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") + np.random.seed(10) + + nlines = 16 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="vertical-uni", + nlines=nlines, + perc_center=0.0, + fft_engine=par["fft_engine"], + dtype=par["dtype"], + ) + + assert len(Mop.mask) == nlines + assert dottest( + Mop, + par["ny"] * nlines, + par["ny"] * par["nx"], + complexflag=2, + ) + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +def test_MRI2D_radial_reg(par): + """Dot-test and forward/adjoint for MRI2D operator with radial-reg mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") + np.random.seed(10) + + nlines = 8 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="radial-reg", + nlines=nlines, + fft_engine=par["fft_engine"], + dtype=par["dtype"], + ) + + # For radial masks, output size depends on the number of points in the mask + npoints = Mop.mask.shape[1] + + assert dottest( + Mop, + npoints, + par["ny"] * par["nx"], + complexflag=2, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + y = Mop * x + xadj = Mop.H * y + + assert y.size == npoints + assert xadj.shape == (par["ny"], par["nx"]) + assert Mop.mask.shape[0] == 2 # x and y coordinates + + +@pytest.mark.skipif( + int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" +) +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +def test_MRI2D_radial_uni(par): + """Dot-test and forward/adjoint for MRI2D operator with radial-uni mask""" + if par["fft_engine"] == "mkl_fft" and not mkl_fft_enabled: + pytest.skip("mkl_fft is not installed") + np.random.seed(10) + + nlines = 8 + Mop = MRI2D( + dims=(par["ny"], par["nx"]), + mask="radial-uni", + nlines=nlines, + fft_engine=par["fft_engine"], + dtype=par["dtype"], + ) + + # For radial masks, output size depends on the number of points in the mask + npoints = Mop.mask.shape[1] + + assert dottest( + Mop, + npoints, + par["ny"] * par["nx"], + complexflag=2, + ) + + x = np.random.normal(0, 1, (par["ny"], par["nx"])) + y = Mop * x + xadj = Mop.H * y + + assert y.size == npoints + assert xadj.shape == (par["ny"], par["nx"]) + assert Mop.mask.shape[0] == 2 # x and y coordinates