diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 431dedf87..442526fa5 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,3 +1,11 @@ +# isort: off +from .rotation import ( + compute_fastrotate_interp_tables, + fastrotate, + sp_rotate, +) + +# isort: on from .image import ( BasisImage, BispecImage, diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 8212033f7..071437a1e 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -10,6 +10,7 @@ import aspire.sinogram import aspire.volume +from aspire.image import fastrotate, sp_rotate from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import ( @@ -158,6 +159,8 @@ class Image: ".tif": load_tiff, ".tiff": load_tiff, } + # Available image rotation functions + rotation_methods = {"fastrotate": fastrotate, "scipy": sp_rotate} def __init__(self, data, pixel_size=None, dtype=None): """ @@ -635,8 +638,81 @@ def filter(self, filter): original_stack_shape ) - def rotate(self): - raise NotImplementedError + def rotate(self, theta, method="scipy", mask=1, **kwargs): + """ + Return `Image` rotated by `theta` radians using `method`. + + Optionally applies `mask`. Note that some methods may + introduce edge artifacts, in which case users may consider + using a tighter mask (eg 0.9) or a combination of pad-crop. + + Any additional kwargs will be passed to `method`. + + :param theta: Scalar or array of length `n_images` + :param mask: Optional scalar or array mask matching `Image` shape. + Scalar will create a circular mask of prescribed radius `(0,1]`. + Array mask will be applied via elementwise multiplication. + `None` disables masking. + :param method: Optionally specify a rotation method. + Defaults to `scipy`. + :return: `Image` containing rotated image data. + """ + + original_stack_shape = self.stack_shape + im = self.stack_reshape(-1) + + # Resolve rotation method + if method not in self.rotation_methods: + raise NotImplementedError( + f"Requested `Image.rotation` method={method} not found." + f" Select from {self.rotation_methods.keys()}" + ) + # Assign the rotation method's function + # Any rotation method is expected to handle image data as a 2D array or 3D array (single stack axis). + rotation_function = self.rotation_methods[method] + + # Handle both scalar and arrays of rotation angles. + # `theta` arrays are checked to match length of images when stacks axis are flattened. + theta = np.array(theta).flatten() + if len(theta) == 1: + im = rotation_function(im._data, theta, **kwargs) + elif len(theta) == im.n_images: + rot_im = np.empty_like(im._data) + for i in range(im.n_images): + rot_im[i] = rotation_function(im._data[i], theta[i], **kwargs) + im = rot_im + else: + raise RuntimeError( + f"Length of `theta` {len(theta)} and `Image` data {im.n_images} inconsistent." + ) + + # Masking, scalar case + if mask is not None: + if np.size(mask) == 1: + # Confirm `mask` value is a sane radius + if not (0 < mask <= 1): + raise ValueError( + f"Mask radius must be scalar between (0,1]. Received {mask}" + ) + # Construct a boolean `mask` to apply in next code block as a 2D `mask` + mask = ( + grid_2d(im.shape[-1], normalized=True, dtype=np.float64)["r"] < mask + ) + mask = mask.astype(im.dtype) + + # Masking, 2D case + # Confirm `mask` size is consistent + if mask.shape == im.shape[-2:]: + im = im * mask[None, :, :] + else: + raise RuntimeError( + f"Shape of `mask` {mask.shape} inconsistent with `Image` data shape {im.shape[-2:]}" + ) + + # Restore original stack shape and metadata. + return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( + original_stack_shape + ) def save(self, mrcs_filepath, overwrite=None): """ diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py new file mode 100644 index 000000000..706248110 --- /dev/null +++ b/src/aspire/image/rotation.py @@ -0,0 +1,251 @@ +import numpy as np +from scipy import ndimage + +from aspire.numeric import fft, xp +from aspire.utils import complex_type + + +def _pre_rotate(theta): + """ + Given `theta` radians return nearest rotation of pi/2 + required to place angle within [-pi/4,pi/4) and the residual + rotation in radians. + + :param theta: Rotation in radians + :returns: + - Residual angle in radians + - Number of pi/2 rotations + """ + + theta = np.mod(theta, 2 * np.pi) + + # 0 < pi/4 + rots = 0 + residual = theta + + if theta >= np.pi / 4 and theta < 3 * np.pi / 4: + rots = 1 + residual = theta - np.pi / 2 + elif theta >= 3 * np.pi / 4 and theta < 5 * np.pi / 4: + rots = 2 + residual = theta - np.pi + elif theta >= 5 * np.pi / 4 and theta < 7 * np.pi / 4: + rots = 3 + residual = theta - 3 * np.pi / 2 + elif theta >= 7 * np.pi / 4 and theta < 2 * np.pi: + rots = 0 + residual = theta - 2 * np.pi + + return residual, rots + + +def _shift_center(n): + """ + Given `n` pixels return center pixel and shift amount, 0 or 1/2. + + :param n: Number of pixels + :returns: + - center pixel + - shift amount + """ + if n % 2 == 0: + c = n // 2 # center + s = 1 / 2 # shift + else: + c = n // 2 + s = 0 + + return c, s + + +def compute_fastrotate_interp_tables(theta, nx, ny): + """ + Retuns iterpolation tables as tuple M = (Mx, My, rots). + + :param theta: angle in radians + :param nx: Number pixels first axis + :param ny: Number pixels second axis + """ + theta, mult90 = _pre_rotate(theta) + + # Reverse rotation, Yaroslavsky rotated CW + theta = -theta + + cy, sy = _shift_center(ny) + cx, sx = _shift_center(nx) + + # Floating point epsilon + eps = np.finfo(np.float64).eps + + # Precompute Y interpolation tables + My = np.zeros((nx, ny), dtype=np.complex128) + r = np.arange(cy + 1, dtype=int) + u = (1 - np.cos(theta)) / np.sin(theta + eps) + alpha1 = 2 * np.pi * 1j * r / ny + + linds = np.arange(ny - 1, cy, -1, dtype=int) + rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) + + Ux = u * (np.arange(nx) - cx + sx + 2) + My[:, r] = np.exp(alpha1[None, :] * Ux[:, None]) + My[:, linds] = My[:, rinds].conj() + + # Precompute X interpolation tables + Mx = np.zeros((ny, nx), dtype=np.complex128) + r = np.arange(cx + 1, dtype=int) + u = -np.sin(theta) + alpha2 = 2 * np.pi * 1j * r / nx + + linds = np.arange(nx - 1, cx, -1, dtype=int) + rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) + + Uy = u * (np.arange(ny) - cy + sy + 2) + Mx[:, r] = np.exp(alpha2[None, :] * Uy[:, None]) + Mx[:, linds] = Mx[:, rinds].conj() + + # After building, transpose to (nx, ny). + Mx = Mx.T + + return Mx, My, mult90 + + +# The following helper utilities are written to work with +# `img` data of dimension 2 or more where the data is expected to be +# in the (-2,-1) dimensions with any other dims as stack axes. +def _rot90(img): + """Rotate image array by 90 degrees.""" + # stack broadcast of flipud(img.T) + return xp.flip(xp.swapaxes(img, -1, -2), axis=-2) + + +def _rot180(img): + """Rotate image array by 180 degrees.""" + # stack broadcast of flipud(fliplr) + return xp.flip(img, axis=(-1, -2)) + + +def _rot270(img): + """Rotate image array by 270 degrees.""" + # stack broadcast of fliplr(img.T) + return xp.flip(xp.swapaxes(img, -1, -2), axis=-1) + + +def fastrotate(images, theta, M=None): + """ + Rotate `images` array by `theta` radians ccw using shearing algorithm. + + Note that this algorithm may have artifacts near the rotation boundary + and will have artifacts outside the rotation boundary. + Users can avoid these by zero padding the input image then + cropping the rotated image and/or masking. + + For reference and notes: + `https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/common/fastrotate/fastrotate.m` + + :param images: (n , px, px) array of image data + :param theta: Rotation angle in radians. + Note when `M` is supplied, `theta` must be `None`. + :param M: Optional precomputed shearing table. + Provided by `M=compute_fastrotate_interp_tables(theta, px, px)`. + Note when `M` is supplied, `theta` must be `None`. + :return: (n, px, px) array of rotated image data + """ + + # Make a stack of 1 + if images.ndim == 2: + images = images[None, :, :] + + n, px0, px1 = images.shape + assert px0 == px1, "Currently only implemented for square images." + + if M is None: + M = compute_fastrotate_interp_tables(theta, px0, px1) + elif theta is not None: + raise RuntimeError( + "`theta` must be `None` when supplying `M`." + " M is precomputed for a specific `theta`." + ) + Mx, My, Mrots = M + + # Cast interp tables to match precision of `images` + Mx = xp.asarray(Mx, complex_type(images.dtype)) + My = xp.asarray(My, complex_type(images.dtype)) + + # Determine if `images` data was provided on host (np.darray) + _host = isinstance(images, np.ndarray) + + # Copy image array to device if needed + images = xp.asarray(images) + + # Pre rotate by multiples of 90 (pi/2) + if Mrots == 1: + images = _rot90(images) + elif Mrots == 2: + images = _rot180(images) + elif Mrots == 3: + images = _rot270(images) + + # Shear 1 + img_k = fft.fft(images, axis=-1) + img_k = img_k * My + images = fft.ifft(img_k, axis=-1).real + + # Shear 2 + img_k = fft.fft(images, axis=-2) + img_k = img_k * Mx + images = fft.ifft(img_k, axis=-2).real + + # Shear 3 + img_k = fft.fft(images, axis=-1) + img_k = img_k * My + images = fft.ifft(img_k, axis=-1).real + + # Return to host if input was provided on host + if _host: + images = xp.asnumpy(images) + + return images + + +def sp_rotate(img, theta, **kwargs): + """ + Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. + + Converts `theta` from radian to degrees. + Defines stack/image axes and reshape behavior. + Image data is expected to be in last two axes in all cases. + + Additional kwargs will be passed through. + See scipy.ndimage.rotate + + :param img: Array of image data shape (L,L) or (...,L, L) + :param theta: Rotation in ccw radians. + :return: Array representing rotated `img`. + """ + + # Store original shape + original_shape = img.shape + # Image data is expected to be in last two axis in all cases + # Flatten, converts all inputs to consistent 3D shape (single stack axis). + img = img.reshape(-1, *img.shape[-2:]) + + # Scipy accepts a single scalar theta in degrees. + # Handle array of thetas and scalar case by expanding to flat array of img.shape + # Flatten all inputs + theta = np.rad2deg(np.array(theta)).reshape(-1) + # Expand single scalar input + if np.size(theta) == 1: + theta = np.full(img.shape[0], theta, img.dtype) + # Check we have an array matching `img`, both should be len(n) + if theta.shape[0] != img.shape[0]: + raise RuntimeError("Inconsistent `theta` and `img` shapes.") + + # Create result array and rotate images via loop + result = np.empty_like(img) + for i in range(img.shape[0]): + result[i] = ndimage.rotate( + img[i], theta[i], reshape=False, axes=(-2, -1), **kwargs + ) + + # Restore original shape + return result.reshape(*original_shape) diff --git a/tests/test_image.py b/tests/test_image.py index 70f5906a6..009be5236 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -11,8 +11,8 @@ from pytest import raises from scipy.datasets import face -from aspire.image import Image -from aspire.utils import Rotation, powerset, utest_tolerance +from aspire.image import Image, compute_fastrotate_interp_tables, fastrotate, sp_rotate +from aspire.utils import Rotation, gaussian_2d, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup from .test_utils import matplotlib_dry_run @@ -564,3 +564,152 @@ def test_save_load_pixel_size(get_images, dtype): np.testing.assert_almost_equal( im2.pixel_size, im.pixel_size, err_msg="Image pixel_size incorrect save-load" ) + + +@pytest.fixture( + params=Image.rotation_methods, ids=lambda x: f"method={x}", scope="module" +) +def rotation_method(request): + return request.param + + +def test_image_rotate(dtype, rotation_method): + """ + Compare image rotations against rotated gaussian blobs. + """ + + L = 129 # Test image size in pixels + num_test_angles = 42 + # Create mask, used to zero edge artifacts + mask = grid_2d(L, normalized=True)["r"] < 0.9 + + def _gen_image(angle, L, n=1, K=10): + """ + Generate `n` `L-by-L` image arrays, + each constructed by a sequence of `K` gaussian blobs, + and reference images with the blob centers rotated by `angle`. + + Return tuple of unrotated and rotated image arrays (n-by-L-by-L). + + :param angle: rotation angle + :param L: size (L-by-L) in pixels + :param K: Number of blobs + :return: + - Array of unrotated data (float64) + - Array of rotated data (float64) + """ + + im = np.zeros((n, L, L), dtype=np.float64) + rotated_im = np.zeros_like(im) + + centers = np.random.randint(-L // 4, L // 4, size=(n, 10, 2)) + sigmas = np.full((n, K, 2), L / 10, dtype=np.float64) + + # Rotate the gaussian specifications + R = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) + rotated_centers = centers @ R + + # Construct each image independently + for i in range(n): + for center, sigma in zip(centers[i], sigmas[i]): + im[i] = im[i] + gaussian_2d(L, center, sigma, dtype=np.float64) + + for center, sigma in zip(rotated_centers[i], sigmas[i]): + rotated_im[i] = rotated_im[i] + gaussian_2d( + L, center, sigma, dtype=np.float64 + ) + + return im, rotated_im + + # Test over a variety of angles `theta` + for theta in np.linspace(0, 2 * np.pi, num_test_angles): + # Generate images and reference (`theta`) rotated images + im, ref = _gen_image(theta, L, n=3) + im = Image(im.astype(dtype, copy=False)) + + # Rotate using `Image`'s `rotation_method` + im_rot = im.rotate(theta, method=rotation_method) + + # Mask off boundary artifacts + masked_diff = (im_rot - ref) * mask + + # Compute L1 error of masked diff, per image + L1_error = np.mean(np.abs(masked_diff), axis=(-1, -2)) + np.testing.assert_array_less( + L1_error, + 0.1, + err_msg=f"{L} pixels using {rotation_method} @ {theta} radians", + ) + + +def test_sp_rotate_inputs(dtype): + """ + Smoke test various input combinations to the scipy rotation wrapper. + """ + + imgs = np.zeros((6, 8, 8), dtype=dtype) + thetas = np.arange(6, dtype=dtype) + theta = thetas[0] # scalar + + # # These are the only supported calls admitted by the function doc. + # singleton, scalar + _ = sp_rotate(imgs[0], theta) + # stack, scalar + _ = sp_rotate(imgs, theta) + + # # These happen to also work with the code, so were put under test. + # # We're not advertising them, as there really isn't a good use + # # case for this wrapper code outside of the internal wrapping + # # application. + # singleton, single element array + _ = sp_rotate(imgs[0], thetas[0:1]) + # stack, single element array + _ = sp_rotate(imgs, thetas[0:1]) + # stack, stack + _ = sp_rotate(imgs, thetas) + # md-stack, md-stack + _ = sp_rotate(imgs.reshape(2, 3, 8, 8), thetas.reshape(2, 3, 1)) + _ = sp_rotate(imgs.reshape(2, 3, 8, 8), thetas.reshape(2, 3)) + + +def test_fastrotate_inputs(dtype): + """ + Smoke test various input combinations to `fastrotate`. + """ + + imgs = np.zeros((6, 8, 8), dtype=dtype) + theta = 42 + + # # These are the supported calls + # singleton, scalar + _ = fastrotate(imgs[0], theta) + # stack, scalar + _ = fastrotate(imgs, theta) + + # # These can also remain under test, but are not advertised. + # stack, single element array + _ = fastrotate(imgs, np.array(theta)) + # singleton, single element array + _ = fastrotate(imgs[0], np.array(theta)) + + +def test_fastrotate_M_arg(dtype): + """ + Smoke test precomputed `M` input to `fastrotate`. + """ + + imgs = np.random.randn(6, 8, 8).astype(dtype) + theta = np.random.uniform(0, 2 * np.pi) + + # Precompute M + M = compute_fastrotate_interp_tables(theta, *imgs.shape[-2:]) + + # Call with theta None + im_rot_M = fastrotate(imgs, None, M=M) + # Compare to calling withou `M` + im_rot = fastrotate(imgs, theta) + np.testing.assert_allclose(im_rot_M, im_rot) + + # Call with theta, should raise + with raises(RuntimeError, match=r".*`theta` must be `None`.*"): + _ = fastrotate(imgs, theta, M=M)