From d2d1b6a3b5e685a7de1dba24d358d004fc28e4b3 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 09:05:32 -0500 Subject: [PATCH 01/21] init add faasrot --- src/aspire/image/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 431dedf87a..5a5a007e6b 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -16,3 +16,4 @@ SigmaRejectionImageStacker, WinsorizedImageStacker, ) +from .faasrot.py import faasrot From 2677f962a13e3ebc3281502f86d7c6059a780632 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 09:14:38 -0500 Subject: [PATCH 02/21] cp faasrot dbg code --- src/aspire/image/__init__.py | 2 +- src/aspire/image/faasrot.py | 173 +++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 src/aspire/image/faasrot.py diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 5a5a007e6b..6240be96da 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,3 +1,4 @@ +from .faasrot.py import faasrot from .image import ( BasisImage, BispecImage, @@ -16,4 +17,3 @@ SigmaRejectionImageStacker, WinsorizedImageStacker, ) -from .faasrot.py import faasrot diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py new file mode 100644 index 0000000000..00ad94879b --- /dev/null +++ b/src/aspire/image/faasrot.py @@ -0,0 +1,173 @@ +import numpy as np + +from aspire.numeric import xp + + +def _pre_rotate(theta): + """ + Given angle `theta` (degrees) return nearest rotation of 90 + degrees required to place angle within [-45,45) and residual + rotation in degrees. + """ + + theta = np.mod(theta, 360) + + # 0 < 45 + rot90 = 0 + residual = theta + + if theta >= 45 and theta < 135: + rot90 = 1 + residual = theta - 90 + elif theta >= 135 and theta < 225: + rot90 = 2 + residual = theta - 180 + elif theta >= 215 and theta < 315: + rot90 = 3 + residual = theta - 270 + elif theta >= 315 and theta < 360: + rot90 = 0 + residual = theta - 360 + + return residual, rot90 + + +def _shift_center(n): + """ + Given `n` pixels return center pixel and shift amount, 0 or 1/2. + """ + if n % 2 == 0: + c = n // 2 # center + s = 1 / 2 # shift + else: + c = n // 2 + s = 0 + + return c, s + + +def _pre_compute(theta, nx, ny): + """ + Retuns M = (Mx, My, rot90) + """ + theta, mult90 = _pre_rotate(theta) + + theta = np.pi * theta / 180 + theta = -theta # Yaroslavsky rotated CW + + cy, sy = _shift_center(ny) + cx, sx = _shift_center(nx) + + # Floating point epsilon + eps = np.finfo(np.float64).eps + + # Precompute Y interpolation tables + My = np.zeros((nx, ny), dtype=np.complex128) + r = np.arange(cy + 1, dtype=int) + u = (1 - np.cos(theta)) / np.sin(theta + eps) + # print("u", u) + alpha1 = 2 * np.pi * 1j * r / ny + + # print("alpha1", alpha1) + + linds = np.arange(ny - 1, cy, -1, dtype=int) + # print('aaa', ny-1, cy, -1) + rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) + # print(linds,rinds) + # This can be broadcast, but leaving loop since would be close to CUDA... + for x in range(nx): + Ux = u * (x - cx + sx + 2) + # print("Ux",Ux) + My[x, r] = np.exp(alpha1 * Ux) + My[x, linds] = np.conj(My[x, rinds]) + + # Precompute X interpolation tables + Mx = np.zeros((ny, nx), dtype=np.complex128) + r = np.arange(cx + 1, dtype=int) + u = -np.sin(theta) + alpha2 = 2 * np.pi * 1j * r / nx + + linds = np.arange(nx - 1, cx, -1, dtype=int) + rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) + # This can be broadcast, but leaving loop since would be close to CUDA... + for y in range(ny): + Uy = u * (y - cy + sy + 2) + Mx[y, r] = np.exp(alpha2 * Uy) + Mx[y, linds] = np.conj(Mx[y, rinds]) + + # After building, transpose to (nx, ny). + Mx = Mx.T + + return Mx, My, mult90 + + +def _rot90(img): + return np.flipud(img.T) + + +def _rot180(img): + return np.flipud(np.fliplr(img)) + + +def _rot270(img): + return np.fliplr(img.T) + + +def faastrotate(images, theta, M=None): + + # Make a stack of 1 + if images.ndim == 2: + images = images[None, :, :] + + n, px0, px1 = images.shape + assert px0 == px1, "Currently only implemented for square images." + + if M is None: + M = _pre_compute(theta, px0, px1) + Mx, My, Mrot90 = M + + result = np.empty((n, px0, px1), dtype=np.float64) + + for i in range(n): + + img = images[i] + + # Pre rotate by multiples of 90 + if Mrot90 == 1: + img = _rot90(img) + elif Mrot90 == 2: + img = _rot180(img) + elif Mrot90 == 3: + img = _rot270(img) + + # Shear 1 + img_k = np.fft.fft(img, axis=-1) + # okay print("\nfft1(img_k):\n", img_k,"\n") + print("\nMy:\n", My, "\n") + img_k = img_k * My + print("\nmult (img_k):\n", img_k, "\n") # okay + + # for _i in range(16): + # #print(f'A[{_i}].x = {img_k.flatten()[_i].real};') + # #print(f'A[{_i}].y = {img_k.flatten()[_i].imag};') + # print(f'FA[{_i}] = {img_k.flatten()[_i]};') + + # breakpoint() + result[i] = np.real(np.fft.ifft(img_k, axis=-1)) + print("\nstage1\n", result[i] * 4, "\n") + + # Shear 2 + img_k = np.fft.fft(result[i], axis=0) + img_k = img_k * Mx + result[i] = np.real(np.fft.ifft(img_k, axis=0)) + + print("\nstage2\n", result * 4 * 4) + + # Shear 3 + img_k = np.fft.fft(result[i], axis=-1) + img_k = img_k * My + result[i] = np.real(np.fft.ifft(img_k, axis=-1)) + + print("\nstage3\n", result * 4 * 4 * 4, "\n") + + return result From 7a3d22d6655ad759a2d435e1572373477a9c5503 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 10:42:07 -0500 Subject: [PATCH 03/21] initial faasrot code/test add [skip ci] --- src/aspire/image/__init__.py | 2 +- src/aspire/image/faasrot.py | 34 +++++++++++----------------------- src/aspire/image/image.py | 15 +++++++++++++-- tests/test_image.py | 24 +++++++++++++++++++++++- 4 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 6240be96da..264656ff56 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,4 +1,4 @@ -from .faasrot.py import faasrot +from .faasrot import faasrotate from .image import ( BasisImage, BispecImage, diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py index 00ad94879b..0b6a7ab71b 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/faasrot.py @@ -65,19 +65,13 @@ def _pre_compute(theta, nx, ny): My = np.zeros((nx, ny), dtype=np.complex128) r = np.arange(cy + 1, dtype=int) u = (1 - np.cos(theta)) / np.sin(theta + eps) - # print("u", u) alpha1 = 2 * np.pi * 1j * r / ny - # print("alpha1", alpha1) - linds = np.arange(ny - 1, cy, -1, dtype=int) - # print('aaa', ny-1, cy, -1) rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) - # print(linds,rinds) # This can be broadcast, but leaving loop since would be close to CUDA... for x in range(nx): Ux = u * (x - cx + sx + 2) - # print("Ux",Ux) My[x, r] = np.exp(alpha1 * Ux) My[x, linds] = np.conj(My[x, rinds]) @@ -113,7 +107,17 @@ def _rot270(img): return np.fliplr(img.T) -def faastrotate(images, theta, M=None): +def faasrotate(images, theta, M=None): + """ + Rotate `images` array by `theta` radians ccw. + + :param images: (n , px, px) array of image data + :param theta: rotation angle in radians + :param M: optional precomputed shearing table + :return: (n, px, px) array fo rotated image data + """ + # Convert to degrees + theta = np.rad2deg(theta) # Make a stack of 1 if images.ndim == 2: @@ -142,32 +146,16 @@ def faastrotate(images, theta, M=None): # Shear 1 img_k = np.fft.fft(img, axis=-1) - # okay print("\nfft1(img_k):\n", img_k,"\n") - print("\nMy:\n", My, "\n") img_k = img_k * My - print("\nmult (img_k):\n", img_k, "\n") # okay - - # for _i in range(16): - # #print(f'A[{_i}].x = {img_k.flatten()[_i].real};') - # #print(f'A[{_i}].y = {img_k.flatten()[_i].imag};') - # print(f'FA[{_i}] = {img_k.flatten()[_i]};') - - # breakpoint() result[i] = np.real(np.fft.ifft(img_k, axis=-1)) - print("\nstage1\n", result[i] * 4, "\n") # Shear 2 img_k = np.fft.fft(result[i], axis=0) img_k = img_k * Mx result[i] = np.real(np.fft.ifft(img_k, axis=0)) - print("\nstage2\n", result * 4 * 4) - # Shear 3 img_k = np.fft.fft(result[i], axis=-1) img_k = img_k * My - result[i] = np.real(np.fft.ifft(img_k, axis=-1)) - - print("\nstage3\n", result * 4 * 4 * 4, "\n") return result diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 8212033f72..2f1ed6012f 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -10,6 +10,7 @@ import aspire.sinogram import aspire.volume +from aspire.image import faasrotate from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import ( @@ -635,8 +636,18 @@ def filter(self, filter): original_stack_shape ) - def rotate(self): - raise NotImplementedError + def rotate(self, theta): + """ + Rotate by `theta` radians. + """ + original_stack_shape = self.stack_shape + im = self.stack_reshape(-1) + + im = faasrotate(im._data, theta) + + return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( + original_stack_shape + ) def save(self, mrcs_filepath, overwrite=None): """ diff --git a/tests/test_image.py b/tests/test_image.py index 70f5906a64..f788654d93 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -10,9 +10,10 @@ from PIL import Image as PILImage from pytest import raises from scipy.datasets import face +from scipy.ndimage import rotate from aspire.image import Image -from aspire.utils import Rotation, powerset, utest_tolerance +from aspire.utils import Rotation, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup from .test_utils import matplotlib_dry_run @@ -564,3 +565,24 @@ def test_save_load_pixel_size(get_images, dtype): np.testing.assert_almost_equal( im2.pixel_size, im.pixel_size, err_msg="Image pixel_size incorrect save-load" ) + + +def test_faasrotate(get_images, dtype): + im_np, im = get_images + + mask = grid_2d(im_np.shape[-1])["r"] < 1 + + for theta in np.linspace(0, 2 * np.pi, 100): + im_rot = im.rotate(theta) + + # reference to scipy + ref = rotate( + im_np, + np.rad2deg(theta), + reshape=False, + ) + + # mask off ears + masked_diff = (im_rot - ref) * mask + + np.testing.assert_allclose(masked_diff, 0, atol=1e-7) From 018ed132a273ee6eb8a0a5fcd61346a6754ccb88 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 11:38:58 -0500 Subject: [PATCH 04/21] dbg --- src/aspire/image/faasrot.py | 5 +++-- tests/test_image.py | 10 +++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py index 0b6a7ab71b..9c7a5f465b 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/faasrot.py @@ -150,12 +150,13 @@ def faasrotate(images, theta, M=None): result[i] = np.real(np.fft.ifft(img_k, axis=-1)) # Shear 2 - img_k = np.fft.fft(result[i], axis=0) + img_k = np.fft.fft(result[i], axis=-2) img_k = img_k * Mx - result[i] = np.real(np.fft.ifft(img_k, axis=0)) + result[i] = np.real(np.fft.ifft(img_k, axis=-2)) # Shear 3 img_k = np.fft.fft(result[i], axis=-1) img_k = img_k * My + result[i] = np.real(np.fft.ifft(img_k, axis=-1)) return result diff --git a/tests/test_image.py b/tests/test_image.py index f788654d93..608c83bbab 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -572,7 +572,8 @@ def test_faasrotate(get_images, dtype): mask = grid_2d(im_np.shape[-1])["r"] < 1 - for theta in np.linspace(0, 2 * np.pi, 100): + #for theta in np.linspace(0, 2 * np.pi, 100): + for theta in [np.pi/4]: im_rot = im.rotate(theta) # reference to scipy @@ -580,8 +581,15 @@ def test_faasrotate(get_images, dtype): im_np, np.rad2deg(theta), reshape=False, + axes=(-1,-2), ) + peek = np.empty((3, *im_np.shape[-2:])) + peek[0] = im_np + peek[1] = im_rot + peek[2] = ref + Image(peek).show() + # mask off ears masked_diff = (im_rot - ref) * mask From 17f5c8c0c95711224965755b73ff81a641403767 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 14:34:18 -0500 Subject: [PATCH 05/21] add smoke test and xp [skip ci] --- src/aspire/image/faasrot.py | 70 +++++++++++++++++++------------------ tests/test_image.py | 30 ++++++++++------ 2 files changed, 56 insertions(+), 44 deletions(-) diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py index 9c7a5f465b..9e46c87afe 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/faasrot.py @@ -1,15 +1,18 @@ import numpy as np -from aspire.numeric import xp +from aspire.numeric import fft, xp def _pre_rotate(theta): """ - Given angle `theta` (degrees) return nearest rotation of 90 - degrees required to place angle within [-45,45) and residual - rotation in degrees. + Given angle `theta` (radians) return nearest rotation of 90 + degrees required to place angle within [-45,45) degrees and residual + rotation (radians). """ + # todo + theta = np.rad2deg(theta) + theta = np.mod(theta, 360) # 0 < 45 @@ -29,7 +32,7 @@ def _pre_rotate(theta): rot90 = 0 residual = theta - 360 - return residual, rot90 + return np.deg2rad(residual), rot90 def _shift_center(n): @@ -49,10 +52,11 @@ def _shift_center(n): def _pre_compute(theta, nx, ny): """ Retuns M = (Mx, My, rot90) + + :param theta: angle in radians """ theta, mult90 = _pre_rotate(theta) - theta = np.pi * theta / 180 theta = -theta # Yaroslavsky rotated CW cy, sy = _shift_center(ny) @@ -62,32 +66,32 @@ def _pre_compute(theta, nx, ny): eps = np.finfo(np.float64).eps # Precompute Y interpolation tables - My = np.zeros((nx, ny), dtype=np.complex128) - r = np.arange(cy + 1, dtype=int) - u = (1 - np.cos(theta)) / np.sin(theta + eps) - alpha1 = 2 * np.pi * 1j * r / ny + My = xp.zeros((nx, ny), dtype=xp.complex128) + r = xp.arange(cy + 1, dtype=int) + u = (1 - xp.cos(theta)) / xp.sin(theta + eps) + alpha1 = 2 * xp.pi * 1j * r / ny - linds = np.arange(ny - 1, cy, -1, dtype=int) - rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) + linds = xp.arange(ny - 1, cy, -1, dtype=int) + rinds = xp.arange(1, cy - 2 * sy + 1, dtype=int) # This can be broadcast, but leaving loop since would be close to CUDA... for x in range(nx): Ux = u * (x - cx + sx + 2) - My[x, r] = np.exp(alpha1 * Ux) - My[x, linds] = np.conj(My[x, rinds]) + My[x, r] = xp.exp(alpha1 * Ux) + My[x, linds] = My[x, rinds].conj() # Precompute X interpolation tables - Mx = np.zeros((ny, nx), dtype=np.complex128) - r = np.arange(cx + 1, dtype=int) - u = -np.sin(theta) - alpha2 = 2 * np.pi * 1j * r / nx + Mx = xp.zeros((ny, nx), dtype=xp.complex128) + r = xp.arange(cx + 1, dtype=int) + u = -xp.sin(theta) + alpha2 = 2 * xp.pi * 1j * r / nx - linds = np.arange(nx - 1, cx, -1, dtype=int) - rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) + linds = xp.arange(nx - 1, cx, -1, dtype=int) + rinds = xp.arange(1, cx - 2 * sx + 1, dtype=int) # This can be broadcast, but leaving loop since would be close to CUDA... for y in range(ny): Uy = u * (y - cy + sy + 2) - Mx[y, r] = np.exp(alpha2 * Uy) - Mx[y, linds] = np.conj(Mx[y, rinds]) + Mx[y, r] = xp.exp(alpha2 * Uy) + Mx[y, linds] = Mx[y, rinds].conj() # After building, transpose to (nx, ny). Mx = Mx.T @@ -96,15 +100,15 @@ def _pre_compute(theta, nx, ny): def _rot90(img): - return np.flipud(img.T) + return xp.flipud(img.T) def _rot180(img): - return np.flipud(np.fliplr(img)) + return xp.flipud(xp.fliplr(img)) def _rot270(img): - return np.fliplr(img.T) + return xp.fliplr(img.T) def faasrotate(images, theta, M=None): @@ -116,8 +120,6 @@ def faasrotate(images, theta, M=None): :param M: optional precomputed shearing table :return: (n, px, px) array fo rotated image data """ - # Convert to degrees - theta = np.rad2deg(theta) # Make a stack of 1 if images.ndim == 2: @@ -134,7 +136,7 @@ def faasrotate(images, theta, M=None): for i in range(n): - img = images[i] + img = xp.asarray(images[i]) # Pre rotate by multiples of 90 if Mrot90 == 1: @@ -145,18 +147,18 @@ def faasrotate(images, theta, M=None): img = _rot270(img) # Shear 1 - img_k = np.fft.fft(img, axis=-1) + img_k = fft.fft(img, axis=-1) img_k = img_k * My - result[i] = np.real(np.fft.ifft(img_k, axis=-1)) + result[i] = fft.ifft(img_k, axis=-1).real # Shear 2 - img_k = np.fft.fft(result[i], axis=-2) + img_k = fft.fft(result[i], axis=-2) img_k = img_k * Mx - result[i] = np.real(np.fft.ifft(img_k, axis=-2)) + result[i] = fft.ifft(img_k, axis=-2).real # Shear 3 - img_k = np.fft.fft(result[i], axis=-1) + img_k = fft.fft(result[i], axis=-1) img_k = img_k * My - result[i] = np.real(np.fft.ifft(img_k, axis=-1)) + result[i] = fft.ifft(img_k, axis=-1).real return result diff --git a/tests/test_image.py b/tests/test_image.py index 608c83bbab..0d4e9cbc7f 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -570,10 +570,10 @@ def test_save_load_pixel_size(get_images, dtype): def test_faasrotate(get_images, dtype): im_np, im = get_images - mask = grid_2d(im_np.shape[-1])["r"] < 1 + mask = grid_2d(im_np.shape[-1])["r"] < 0.9 - #for theta in np.linspace(0, 2 * np.pi, 100): - for theta in [np.pi/4]: + for theta in np.linspace(0, 2 * np.pi, 100): + # for theta in [np.pi/4]: im_rot = im.rotate(theta) # reference to scipy @@ -581,16 +581,26 @@ def test_faasrotate(get_images, dtype): im_np, np.rad2deg(theta), reshape=False, - axes=(-1,-2), + axes=(-1, -2), ) - peek = np.empty((3, *im_np.shape[-2:])) - peek[0] = im_np - peek[1] = im_rot - peek[2] = ref - Image(peek).show() + # peek = np.empty((5, *im_np.shape[-2:])) + # peek[0] = im_np + # peek[1] = im_rot + # peek[2] = ref + # peek[3] = im_rot - ref + + # # print('origin', np.sum(np.abs(im_np*mask))) + # # print('im_rot', np.sum(np.abs(im_rot*mask))) + # # print('ref', np.sum(np.abs(ref*mask))) # mask off ears masked_diff = (im_rot - ref) * mask - np.testing.assert_allclose(masked_diff, 0, atol=1e-7) + # #masked_diff[:,mask] = masked_diff.asnumpy()[:,mask] / ref[:,mask] + # #peek[4] = np.nan_to_num(masked_diff) + # peek[4] = masked_diff + # Image(peek*mask).show() + + # mean masked pixel value is ~0.5, so this is ~2% + np.testing.assert_allclose(masked_diff, 0, atol=1) From 57a02679274c8d8f0a0e0e44b4200e4ee88e5295 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 14:43:24 -0500 Subject: [PATCH 06/21] xp cleanup [skip ci] --- src/aspire/image/faasrot.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py index 9e46c87afe..b2ab00a48e 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/faasrot.py @@ -66,31 +66,31 @@ def _pre_compute(theta, nx, ny): eps = np.finfo(np.float64).eps # Precompute Y interpolation tables - My = xp.zeros((nx, ny), dtype=xp.complex128) - r = xp.arange(cy + 1, dtype=int) - u = (1 - xp.cos(theta)) / xp.sin(theta + eps) - alpha1 = 2 * xp.pi * 1j * r / ny + My = np.zeros((nx, ny), dtype=np.complex128) + r = np.arange(cy + 1, dtype=int) + u = (1 - np.cos(theta)) / np.sin(theta + eps) + alpha1 = 2 * np.pi * 1j * r / ny - linds = xp.arange(ny - 1, cy, -1, dtype=int) - rinds = xp.arange(1, cy - 2 * sy + 1, dtype=int) + linds = np.arange(ny - 1, cy, -1, dtype=int) + rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) # This can be broadcast, but leaving loop since would be close to CUDA... for x in range(nx): Ux = u * (x - cx + sx + 2) - My[x, r] = xp.exp(alpha1 * Ux) + My[x, r] = np.exp(alpha1 * Ux) My[x, linds] = My[x, rinds].conj() # Precompute X interpolation tables - Mx = xp.zeros((ny, nx), dtype=xp.complex128) - r = xp.arange(cx + 1, dtype=int) - u = -xp.sin(theta) - alpha2 = 2 * xp.pi * 1j * r / nx + Mx = np.zeros((ny, nx), dtype=np.complex128) + r = np.arange(cx + 1, dtype=int) + u = -np.sin(theta) + alpha2 = 2 * np.pi * 1j * r / nx - linds = xp.arange(nx - 1, cx, -1, dtype=int) - rinds = xp.arange(1, cx - 2 * sx + 1, dtype=int) + linds = np.arange(nx - 1, cx, -1, dtype=int) + rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) # This can be broadcast, but leaving loop since would be close to CUDA... for y in range(ny): Uy = u * (y - cy + sy + 2) - Mx[y, r] = xp.exp(alpha2 * Uy) + Mx[y, r] = np.exp(alpha2 * Uy) Mx[y, linds] = Mx[y, rinds].conj() # After building, transpose to (nx, ny). @@ -132,8 +132,9 @@ def faasrotate(images, theta, M=None): M = _pre_compute(theta, px0, px1) Mx, My, Mrot90 = M - result = np.empty((n, px0, px1), dtype=np.float64) + Mx, My = xp.asarray(Mx), xp.asarray(My) + result = xp.empty((n, px0, px1), dtype=np.float64) for i in range(n): img = xp.asarray(images[i]) @@ -161,4 +162,4 @@ def faasrotate(images, theta, M=None): img_k = img_k * My result[i] = fft.ifft(img_k, axis=-1).real - return result + return xp.asnumpy(result) From d7457639b718e1490e4a9ae0da48022fa4c3d6c2 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 Dec 2025 15:49:44 -0500 Subject: [PATCH 07/21] cleanup and extension --- src/aspire/image/__init__.py | 2 +- .../image/{faasrot.py => fastrotate.py} | 20 ++++-- src/aspire/image/image.py | 61 ++++++++++++++++++- 3 files changed, 74 insertions(+), 9 deletions(-) rename src/aspire/image/{faasrot.py => fastrotate.py} (83%) diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 264656ff56..ff01f2d553 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,4 +1,4 @@ -from .faasrot import faasrotate +from .fastrotate import compute_fastrotate_interp_tables, fastrotate from .image import ( BasisImage, BispecImage, diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/fastrotate.py similarity index 83% rename from src/aspire/image/faasrot.py rename to src/aspire/image/fastrotate.py index b2ab00a48e..52092e735e 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/fastrotate.py @@ -49,11 +49,13 @@ def _shift_center(n): return c, s -def _pre_compute(theta, nx, ny): +def compute_fastrotate_interp_tables(theta, nx, ny): """ Retuns M = (Mx, My, rot90) :param theta: angle in radians + :param nx: Number pixels first axis + :param ny: Number pixels second axis """ theta, mult90 = _pre_rotate(theta) @@ -111,14 +113,22 @@ def _rot270(img): return xp.fliplr(img.T) -def faasrotate(images, theta, M=None): +def fastrotate(images, theta, M=None): """ - Rotate `images` array by `theta` radians ccw. + Rotate `images` array by `theta` radians ccw using shearing algorithm. + + Note that this algorithm may have artifacts near the rotation boundary + and will have artifacts outside the rotation boundary. + Users can avoid these by zero padding the input image then + cropping the rotated image and/or masking. + + For reference and notes: + `https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/common/fastrotate/fastrotate.m` :param images: (n , px, px) array of image data :param theta: rotation angle in radians :param M: optional precomputed shearing table - :return: (n, px, px) array fo rotated image data + :return: (n, px, px) array of rotated image data """ # Make a stack of 1 @@ -129,7 +139,7 @@ def faasrotate(images, theta, M=None): assert px0 == px1, "Currently only implemented for square images." if M is None: - M = _pre_compute(theta, px0, px1) + M = compute_fastrotate_interp_tables(theta, px0, px1) Mx, My, Mrot90 = M Mx, My = xp.asarray(Mx), xp.asarray(My) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 2f1ed6012f..9bdae9f7f5 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -10,7 +10,7 @@ import aspire.sinogram import aspire.volume -from aspire.image import faasrotate +from aspire.image import fastrotate from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import ( @@ -159,6 +159,8 @@ class Image: ".tif": load_tiff, ".tiff": load_tiff, } + # Available image rotation functions + rotation_methods = {"fastrotate": fastrotate} def __init__(self, data, pixel_size=None, dtype=None): """ @@ -636,15 +638,68 @@ def filter(self, filter): original_stack_shape ) - def rotate(self, theta): + def rotate(self, theta, method="fastrotate", mask=1): """ Rotate by `theta` radians. + + :param theta: Scalar or array of length `n_images` + :param mask: Optional scalar or array mask matching `Image` shape. + Scalar will create a circular mask of prescribed radius `(0,1]`. + Array mask will be applied via elementwise multiplication. + `None` disables masking. + :returns: `Image` containing Rotated image data. """ original_stack_shape = self.stack_shape im = self.stack_reshape(-1) - im = faasrotate(im._data, theta) + # Resolve rotation method + if method not in self.rotation_methods: + raise NotImplementedError( + f"Requested `Image.rotation` method={method} not found." + f" Select from {self.rotation_methods.keys()}" + ) + # otherwise, assign the function + rotation_function = self.rotation_methods[method] + + # Handle both scalar and arrays of rotation angles. + # `theta` arrays are checked to match length of images when stacks axis are flattened. + theta = np.array(theta).flatten() + if len(theta) == 1: + im = rotation_function(im._data, theta) + elif len(theta) == im.n_images: + rot_im = np.empty_like(im._data) + for i in range(im.n_images): + rot_im[i] = rotation_function(im._data[i], theta[i]) + im = rot_im + else: + raise RuntimeError( + f"Length of `theta` {len(theta)} and `Image` data {im.n_images} inconsistent." + ) + + # Masking, scalar case + if mask is not None: + if np.size(mask) == 1: + # Confirm `mask` value is a sane radius + if not (0 < mask <= 1): + raise ValueError( + f"Mask radius must be scalar between (0,1]. Received {mask}" + ) + # Construct a boolean `mask` to apply in next code block as a 2D `mask` + mask = ( + grid_2d(im.shape[-1], normalized=True, dtype=np.float64)["r"] < mask + ) + mask = mask.astype(im.dtype) + + # Masking, 2D case + # Confirm `mask` size is consistent + if mask.shape == im.shape[-2:]: + im = im * mask[None, :, :] + else: + raise RuntimeError( + f"Shape of `mask` {mask.shape} inconsistent with `Image` data shape {im.shape[-2:]}" + ) + # Restore original stack shape and metadata. return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( original_stack_shape ) From c3caac0a6ed8b3ee2006858742a9ce754229dc4b Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 11:07:59 -0500 Subject: [PATCH 08/21] polishing fastrotate --- src/aspire/image/fastrotate.py | 144 +++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 62 deletions(-) diff --git a/src/aspire/image/fastrotate.py b/src/aspire/image/fastrotate.py index 52092e735e..37eaa4e1b8 100644 --- a/src/aspire/image/fastrotate.py +++ b/src/aspire/image/fastrotate.py @@ -5,39 +5,46 @@ def _pre_rotate(theta): """ - Given angle `theta` (radians) return nearest rotation of 90 - degrees required to place angle within [-45,45) degrees and residual - rotation (radians). + Given `theta` radians return nearest rotation of pi/2 + required to place angle within [-pi/4,pi/4) and the residual + rotation in radians. + + :param theta: Rotation in radians + :returns: + - Residual angle in radians + - Number of pi/2 rotations """ - # todo - theta = np.rad2deg(theta) + theta = np.mod(theta, 2 * np.pi) - theta = np.mod(theta, 360) - - # 0 < 45 - rot90 = 0 + # 0 < pi/4 + rots = 0 residual = theta - if theta >= 45 and theta < 135: - rot90 = 1 - residual = theta - 90 - elif theta >= 135 and theta < 225: - rot90 = 2 - residual = theta - 180 - elif theta >= 215 and theta < 315: - rot90 = 3 - residual = theta - 270 - elif theta >= 315 and theta < 360: - rot90 = 0 - residual = theta - 360 + if theta >= np.pi / 4 and theta < 3 * np.pi / 4: + rots = 1 + residual = theta - np.pi / 2 + elif theta >= 3 * np.pi / 4 and theta < 5 * np.pi / 4: + rots = 2 + residual = theta - np.pi + elif theta >= 5 * np.pi / 4 and theta < 7 * np.pi / 4: + rots = 3 + residual = theta - 3 * np.pi / 2 + elif theta >= 7 * np.pi / 4 and theta < 2 * np.pi: + rots = 0 + residual = theta - 2 * np.pi - return np.deg2rad(residual), rot90 + return residual, rots def _shift_center(n): """ Given `n` pixels return center pixel and shift amount, 0 or 1/2. + + :param n: Number of pixels + :returns: + - center pixel + - shift amount """ if n % 2 == 0: c = n // 2 # center @@ -51,7 +58,7 @@ def _shift_center(n): def compute_fastrotate_interp_tables(theta, nx, ny): """ - Retuns M = (Mx, My, rot90) + Retuns iterpolation tables as tuple M = (Mx, My, rots). :param theta: angle in radians :param nx: Number pixels first axis @@ -59,7 +66,8 @@ def compute_fastrotate_interp_tables(theta, nx, ny): """ theta, mult90 = _pre_rotate(theta) - theta = -theta # Yaroslavsky rotated CW + # Reverse rotation, Yaroslavsky rotated CW + theta = -theta cy, sy = _shift_center(ny) cx, sx = _shift_center(nx) @@ -75,11 +83,10 @@ def compute_fastrotate_interp_tables(theta, nx, ny): linds = np.arange(ny - 1, cy, -1, dtype=int) rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) - # This can be broadcast, but leaving loop since would be close to CUDA... - for x in range(nx): - Ux = u * (x - cx + sx + 2) - My[x, r] = np.exp(alpha1 * Ux) - My[x, linds] = My[x, rinds].conj() + + Ux = u * (np.arange(nx) - cx + sx + 2) + My[:, r] = np.exp(alpha1[None, :] * Ux[:, None]) + My[:, linds] = My[:, rinds].conj() # Precompute X interpolation tables Mx = np.zeros((ny, nx), dtype=np.complex128) @@ -89,11 +96,10 @@ def compute_fastrotate_interp_tables(theta, nx, ny): linds = np.arange(nx - 1, cx, -1, dtype=int) rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) - # This can be broadcast, but leaving loop since would be close to CUDA... - for y in range(ny): - Uy = u * (y - cy + sy + 2) - Mx[y, r] = np.exp(alpha2 * Uy) - Mx[y, linds] = Mx[y, rinds].conj() + + Uy = u * (np.arange(ny) - cy + sy + 2) + Mx[:, r] = np.exp(alpha2[None, :] * Uy[:, None]) + Mx[:, linds] = Mx[:, rinds].conj() # After building, transpose to (nx, ny). Mx = Mx.T @@ -101,16 +107,25 @@ def compute_fastrotate_interp_tables(theta, nx, ny): return Mx, My, mult90 +# The following helper utilities are written to work with +# `img` data of dimension 2 or more where the data is expected to be +# in the (-2,-1) dimensions with any other dims as stack axes. def _rot90(img): - return xp.flipud(img.T) + """Rotate image array by 90 degrees.""" + # stack broadcast of flipud(img.T) + return xp.flip(xp.swapaxes(img, -1, -2), axis=-2) def _rot180(img): - return xp.flipud(xp.fliplr(img)) + """Rotate image array by 180 degrees.""" + # stack broadcast of flipud(fliplr) + return xp.flip(img, axis=(-1, -2)) def _rot270(img): - return xp.fliplr(img.T) + """Rotate image array by 90 degrees.""" + # stack broadcast of fliplr(img.T) + return xp.flip(xp.swapaxes(img, -1, -2), axis=-1) def fastrotate(images, theta, M=None): @@ -140,36 +155,41 @@ def fastrotate(images, theta, M=None): if M is None: M = compute_fastrotate_interp_tables(theta, px0, px1) - Mx, My, Mrot90 = M + Mx, My, Mrots = M + + Mx, My = xp.asarray(Mx, dtype=images.dtype), xp.asarray(My, dtype=images.dtype) - Mx, My = xp.asarray(Mx), xp.asarray(My) + # Store if `images` data was provide on host (np.darray) + _host = isinstance(images, np.ndarray) - result = xp.empty((n, px0, px1), dtype=np.float64) - for i in range(n): + # If needed copy image array to device + images = xp.asarray(images) - img = xp.asarray(images[i]) + # Pre rotate by multiples of 90 (pi/2) + if Mrots == 1: + images = _rot90(images) + elif Mrots == 2: + images = _rot180(images) + elif Mrots == 3: + images = _rot270(images) - # Pre rotate by multiples of 90 - if Mrot90 == 1: - img = _rot90(img) - elif Mrot90 == 2: - img = _rot180(img) - elif Mrot90 == 3: - img = _rot270(img) + # Shear 1 + img_k = fft.fft(images, axis=-1) + img_k = img_k * My + images = fft.ifft(img_k, axis=-1).real - # Shear 1 - img_k = fft.fft(img, axis=-1) - img_k = img_k * My - result[i] = fft.ifft(img_k, axis=-1).real + # Shear 2 + img_k = fft.fft(images, axis=-2) + img_k = img_k * Mx + images = fft.ifft(img_k, axis=-2).real - # Shear 2 - img_k = fft.fft(result[i], axis=-2) - img_k = img_k * Mx - result[i] = fft.ifft(img_k, axis=-2).real + # Shear 3 + img_k = fft.fft(images, axis=-1) + img_k = img_k * My + images = fft.ifft(img_k, axis=-1).real - # Shear 3 - img_k = fft.fft(result[i], axis=-1) - img_k = img_k * My - result[i] = fft.ifft(img_k, axis=-1).real + # Return to host if needed + if _host: + images = xp.asnumpy(images) - return xp.asnumpy(result) + return images From a4f3075a272180e3bc045a057dd3b404656f2b26 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 11:20:56 -0500 Subject: [PATCH 09/21] first pass polish Image.rotate --- src/aspire/image/image.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 9bdae9f7f5..c9734b72e8 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -638,17 +638,25 @@ def filter(self, filter): original_stack_shape ) - def rotate(self, theta, method="fastrotate", mask=1): + def rotate(self, theta, method="fastrotate", mask=1, **kwargs): """ - Rotate by `theta` radians. + Return `Image` rotated by `theta` radians using `method`. + + Optionally applies `mask`. Note that some methods may + introduce edge artifacts, in which case users may consider + using a tighter mask (eg 0.9) or a combination of pad-crop. + + Any additional kwargs will be passed to `method`. :param theta: Scalar or array of length `n_images` :param mask: Optional scalar or array mask matching `Image` shape. Scalar will create a circular mask of prescribed radius `(0,1]`. Array mask will be applied via elementwise multiplication. `None` disables masking. - :returns: `Image` containing Rotated image data. + :param method: Optionally specify a rotation method. + :return: `Image` containing rotated image data. """ + original_stack_shape = self.stack_shape im = self.stack_reshape(-1) @@ -658,18 +666,19 @@ def rotate(self, theta, method="fastrotate", mask=1): f"Requested `Image.rotation` method={method} not found." f" Select from {self.rotation_methods.keys()}" ) - # otherwise, assign the function + # Assign the rotation method's function + # Any rotation method is expected to handle image data as a 2D array or 3D array (single stack axis). rotation_function = self.rotation_methods[method] # Handle both scalar and arrays of rotation angles. # `theta` arrays are checked to match length of images when stacks axis are flattened. theta = np.array(theta).flatten() if len(theta) == 1: - im = rotation_function(im._data, theta) + im = rotation_function(im._data, theta, **kwargs) elif len(theta) == im.n_images: rot_im = np.empty_like(im._data) for i in range(im.n_images): - rot_im[i] = rotation_function(im._data[i], theta[i]) + rot_im[i] = rotation_function(im._data[i], theta[i], **kwargs) im = rot_im else: raise RuntimeError( From 829a0400fa33bc04e83d20835d0b3cbb42eca152 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 11:54:36 -0500 Subject: [PATCH 10/21] initial add scipy.ndimage.rotate [skip ci] --- src/aspire/image/__init__.py | 9 ++++++++- src/aspire/image/image.py | 4 ++-- src/aspire/image/{fastrotate.py => rotation.py} | 17 +++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) rename src/aspire/image/{fastrotate.py => rotation.py} (90%) diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index ff01f2d553..442526fa54 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,4 +1,11 @@ -from .fastrotate import compute_fastrotate_interp_tables, fastrotate +# isort: off +from .rotation import ( + compute_fastrotate_interp_tables, + fastrotate, + sp_rotate, +) + +# isort: on from .image import ( BasisImage, BispecImage, diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index c9734b72e8..9ba54a66a4 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -10,7 +10,7 @@ import aspire.sinogram import aspire.volume -from aspire.image import fastrotate +from aspire.image import fastrotate, sp_rotate from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import ( @@ -160,7 +160,7 @@ class Image: ".tiff": load_tiff, } # Available image rotation functions - rotation_methods = {"fastrotate": fastrotate} + rotation_methods = {"fastrotate": fastrotate, "scipy": sp_rotate} def __init__(self, data, pixel_size=None, dtype=None): """ diff --git a/src/aspire/image/fastrotate.py b/src/aspire/image/rotation.py similarity index 90% rename from src/aspire/image/fastrotate.py rename to src/aspire/image/rotation.py index 37eaa4e1b8..268d73cc43 100644 --- a/src/aspire/image/fastrotate.py +++ b/src/aspire/image/rotation.py @@ -1,4 +1,5 @@ import numpy as np +from scipy import ndimage from aspire.numeric import fft, xp @@ -193,3 +194,19 @@ def fastrotate(images, theta, M=None): images = xp.asnumpy(images) return images + + +def sp_rotate(im, theta, **kwargs): + """Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. + + Converts `theta` from radian to degrees. + Defines image axes and reshape behavior. + + Additional kwargs will be passed through. + See scipy.ndimage.rotate + + :param im: Array of image data shape (L,L) or (n,L, L) + :param theta: Rotation in ccw radians. + :return: Array representing rotated `im`. + """ + return ndimage.rotate(im, np.rad2deg(theta), reshape=False, axes=(-1, -2), **kwargs) From c946b466584cebccb47cd1bee2eeed74acdafb50 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 14:57:25 -0500 Subject: [PATCH 11/21] extend test, but still need to improve it --- src/aspire/image/rotation.py | 37 +++++++++++++++++++++++---- tests/test_image.py | 48 ++++++++++++++---------------------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index 268d73cc43..2649812f9c 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -196,17 +196,44 @@ def fastrotate(images, theta, M=None): return images -def sp_rotate(im, theta, **kwargs): +def sp_rotate(img, theta, **kwargs): """Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. Converts `theta` from radian to degrees. - Defines image axes and reshape behavior. + Defines stack/image axes and reshape behavior. + Image data is expected to be in last two axes in all cases. Additional kwargs will be passed through. See scipy.ndimage.rotate - :param im: Array of image data shape (L,L) or (n,L, L) + :param img: Array of image data shape (L,L) or (...,L, L) :param theta: Rotation in ccw radians. - :return: Array representing rotated `im`. + :return: Array representing rotated `img`. """ - return ndimage.rotate(im, np.rad2deg(theta), reshape=False, axes=(-1, -2), **kwargs) + + # Store original shape + original_shape = img.shape + # Image data is expected to be in last two axis in all cases + # Flatten, converts all inputs to consistent 3D shape (single stack axis). + img = img.reshape(-1, *img.shape[-2:]) + + # Scipy accepts a single scalar theta in degrees. + # Handle array of thetas and scalar case by expanding to flat array of img.shape + # Flatten all inputs + theta = np.rad2deg(np.array(theta)).reshape(-1, 1) + # Expand scalar input + if theta.shape[0] == 1: + theta = np.full(img.shape[0], theta, img.dtype) + # Check we have an array matching `img` + if theta.shape != img.shape[:1]: + raise RuntimeError("Inconsistent `theta` and `img` shapes.") + + # Create result array and rotate images via loop + result = np.empty_like(img) + for i in range(img.shape[0]): + result[i] = ndimage.rotate( + img[i], theta[i], reshape=False, axes=(-2, -1), **kwargs + ) + + # Restore original shape + return result.reshape(*original_shape) diff --git a/tests/test_image.py b/tests/test_image.py index 0d4e9cbc7f..f178175e3c 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -567,40 +567,28 @@ def test_save_load_pixel_size(get_images, dtype): ) -def test_faasrotate(get_images, dtype): +@pytest.fixture( + params=Image.rotation_methods, ids=lambda x: f"method={x}", scope="module" +) +def rotation_method(request): + return request.param + + +# TODO, lets replace this with an analytic test +def test_image_rotate(get_images, dtype, rotation_method): im_np, im = get_images - mask = grid_2d(im_np.shape[-1])["r"] < 0.9 + # mask = grid_2d(im_np.shape[-1])["r"] < 0.9 for theta in np.linspace(0, 2 * np.pi, 100): # for theta in [np.pi/4]: - im_rot = im.rotate(theta) - - # reference to scipy - ref = rotate( - im_np, - np.rad2deg(theta), - reshape=False, - axes=(-1, -2), - ) - - # peek = np.empty((5, *im_np.shape[-2:])) - # peek[0] = im_np - # peek[1] = im_rot - # peek[2] = ref - # peek[3] = im_rot - ref - - # # print('origin', np.sum(np.abs(im_np*mask))) - # # print('im_rot', np.sum(np.abs(im_rot*mask))) - # # print('ref', np.sum(np.abs(ref*mask))) - - # mask off ears - masked_diff = (im_rot - ref) * mask - - # #masked_diff[:,mask] = masked_diff.asnumpy()[:,mask] / ref[:,mask] - # #peek[4] = np.nan_to_num(masked_diff) - # peek[4] = masked_diff - # Image(peek*mask).show() + im_rot = im.rotate(theta, method=rotation_method) + + # Use manual call to PIL as reference + ref = np.asarray(PILImage.fromarray(im_np[0]).rotate(np.rad2deg(theta))) + + # masked_diff = (im_rot - ref) * mask + diff = im_rot - ref # mean masked pixel value is ~0.5, so this is ~2% - np.testing.assert_allclose(masked_diff, 0, atol=1) + np.testing.assert_allclose(diff, 0, atol=1) From 396be322e9c43090c2b4fa68fb122eb019c76ac4 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 8 Dec 2025 09:56:11 -0500 Subject: [PATCH 12/21] rm unused import --- tests/test_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_image.py b/tests/test_image.py index f178175e3c..a8eea5b8c0 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -10,7 +10,6 @@ from PIL import Image as PILImage from pytest import raises from scipy.datasets import face -from scipy.ndimage import rotate from aspire.image import Image from aspire.utils import Rotation, grid_2d, powerset, utest_tolerance From 7f70d1d8e0441caebc008215a1b47e60b4baab2c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 Dec 2025 07:54:49 -0500 Subject: [PATCH 13/21] more analytic image rot test --- tests/test_image.py | 67 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/tests/test_image.py b/tests/test_image.py index a8eea5b8c0..793bf25245 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -12,7 +12,7 @@ from scipy.datasets import face from aspire.image import Image -from aspire.utils import Rotation, grid_2d, powerset, utest_tolerance +from aspire.utils import Rotation, gaussian_2d, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup from .test_utils import matplotlib_dry_run @@ -573,21 +573,60 @@ def rotation_method(request): return request.param -# TODO, lets replace this with an analytic test -def test_image_rotate(get_images, dtype, rotation_method): - im_np, im = get_images +def test_image_rotate(dtype, rotation_method): + """ + Compare image rotations against rotated gaussian blobs. + """ - # mask = grid_2d(im_np.shape[-1])["r"] < 0.9 + def _gen_image(angle, L, K=10): + """ + Generate a sequence of `K` gaussian blobs rotated by `angle`. - for theta in np.linspace(0, 2 * np.pi, 100): - # for theta in [np.pi/4]: - im_rot = im.rotate(theta, method=rotation_method) + Return tuple of unrotated and rotated image arrays. + + :param angle: rotation angle + :param L: size (L-by-L) in pixels + :param K: Number of blobs + :return: + - Array of unrotated data (float64) + - Array of rotated data (float64) + """ + + im = np.zeros((L, L), dtype=np.float64) + rotated_im = np.zeros_like(im) + + centers = np.random.randint(-L // 4, L // 4, size=(10, 2)) + sigmas = np.full((K, 2), L / 10, dtype=np.float64) + + # Rotate the gaussian specifications + R = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) - # Use manual call to PIL as reference - ref = np.asarray(PILImage.fromarray(im_np[0]).rotate(np.rad2deg(theta))) + rotated_centers = centers @ R + + for center, sigma in zip(centers, sigmas): + im[:] = im[:] + gaussian_2d(L, center, sigma, dtype=np.float64) + + for center, sigma in zip(rotated_centers, sigmas): + rotated_im[:] = rotated_im[:] + gaussian_2d( + L, center, sigma, dtype=np.float64 + ) + + return im, rotated_im + + L = 129 # Test image size in pixels + # Create mask, zeros edge artifacts + mask = grid_2d(L, normalized=True)["r"] < 0.9 + + # for theta in np.linspace(0, 2 * np.pi, 100): + for theta in [np.pi / 4]: + im, ref = _gen_image(theta, L) + im = Image(im.astype(dtype, copy=False)) + + # Rotate using `Image` method + im_rot = im.rotate(theta, method=rotation_method) - # masked_diff = (im_rot - ref) * mask - diff = im_rot - ref + masked_diff = (im_rot - ref) * mask - # mean masked pixel value is ~0.5, so this is ~2% - np.testing.assert_allclose(diff, 0, atol=1) + # Compute L1 error of masked diff + L1_error = np.mean(np.abs(masked_diff)) + np.testing.assert_array_less(L1_error, 1e-6) From 6d185704dffbeceab361e51ccfa9049752b14c79 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 Dec 2025 07:55:00 -0500 Subject: [PATCH 14/21] comment cleanup --- src/aspire/image/rotation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index 2649812f9c..426c28b96c 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -2,6 +2,7 @@ from scipy import ndimage from aspire.numeric import fft, xp +from aspire.utils import complex_type def _pre_rotate(theta): @@ -158,12 +159,14 @@ def fastrotate(images, theta, M=None): M = compute_fastrotate_interp_tables(theta, px0, px1) Mx, My, Mrots = M - Mx, My = xp.asarray(Mx, dtype=images.dtype), xp.asarray(My, dtype=images.dtype) + # Cast interp tables to match precision of `images` + Mx = xp.asarray(Mx, complex_type(images.dtype)) + My = xp.asarray(My, complex_type(images.dtype)) - # Store if `images` data was provide on host (np.darray) + # Determine if `images` data was provided on host (np.darray) _host = isinstance(images, np.ndarray) - # If needed copy image array to device + # Copy image array to device if needed images = xp.asarray(images) # Pre rotate by multiples of 90 (pi/2) @@ -189,7 +192,7 @@ def fastrotate(images, theta, M=None): img_k = img_k * My images = fft.ifft(img_k, axis=-1).real - # Return to host if needed + # Return to host if input was provided on host if _host: images = xp.asnumpy(images) From 0d02899f4535b9918c71e98a83840a6e60fb3c4c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 Dec 2025 09:42:49 -0500 Subject: [PATCH 15/21] cleanup --- tests/test_image.py | 58 ++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/tests/test_image.py b/tests/test_image.py index 793bf25245..7206998a4c 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -578,11 +578,18 @@ def test_image_rotate(dtype, rotation_method): Compare image rotations against rotated gaussian blobs. """ - def _gen_image(angle, L, K=10): + L = 129 # Test image size in pixels + num_test_angles = 42 + # Create mask, used to zero edge artifacts + mask = grid_2d(L, normalized=True)["r"] < 0.9 + + def _gen_image(angle, L, n=1, K=10): """ - Generate a sequence of `K` gaussian blobs rotated by `angle`. + Generate `n` `L-by-L` image arrays, + each constructed by a sequence of `K` gaussian blobs, + and reference images with the blob centers rotated by `angle`. - Return tuple of unrotated and rotated image arrays. + Return tuple of unrotated and rotated image arrays (n-by-L-by-L). :param angle: rotation angle :param L: size (L-by-L) in pixels @@ -592,41 +599,44 @@ def _gen_image(angle, L, K=10): - Array of rotated data (float64) """ - im = np.zeros((L, L), dtype=np.float64) + im = np.zeros((n, L, L), dtype=np.float64) rotated_im = np.zeros_like(im) - centers = np.random.randint(-L // 4, L // 4, size=(10, 2)) - sigmas = np.full((K, 2), L / 10, dtype=np.float64) + centers = np.random.randint(-L // 4, L // 4, size=(n, 10, 2)) + sigmas = np.full((n, K, 2), L / 10, dtype=np.float64) # Rotate the gaussian specifications R = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) - rotated_centers = centers @ R - for center, sigma in zip(centers, sigmas): - im[:] = im[:] + gaussian_2d(L, center, sigma, dtype=np.float64) + # Construct each image independently + for i in range(n): + for center, sigma in zip(centers[i], sigmas[i]): + im[i] = im[i] + gaussian_2d(L, center, sigma, dtype=np.float64) - for center, sigma in zip(rotated_centers, sigmas): - rotated_im[:] = rotated_im[:] + gaussian_2d( - L, center, sigma, dtype=np.float64 - ) + for center, sigma in zip(rotated_centers[i], sigmas[i]): + rotated_im[i] = rotated_im[i] + gaussian_2d( + L, center, sigma, dtype=np.float64 + ) return im, rotated_im - L = 129 # Test image size in pixels - # Create mask, zeros edge artifacts - mask = grid_2d(L, normalized=True)["r"] < 0.9 - - # for theta in np.linspace(0, 2 * np.pi, 100): - for theta in [np.pi / 4]: - im, ref = _gen_image(theta, L) + # Test over a variety of angles `theta` + for theta in np.linspace(0, 2 * np.pi, num_test_angles): + # Generate images and reference (`theta`) rotated images + im, ref = _gen_image(theta, L, n=3) im = Image(im.astype(dtype, copy=False)) - # Rotate using `Image` method + # Rotate using `Image`'s `rotation_method` im_rot = im.rotate(theta, method=rotation_method) + # Mask off boundary artifacts masked_diff = (im_rot - ref) * mask - # Compute L1 error of masked diff - L1_error = np.mean(np.abs(masked_diff)) - np.testing.assert_array_less(L1_error, 1e-6) + # Compute L1 error of masked diff, per image + L1_error = np.mean(np.abs(masked_diff), axis=(-1, -2)) + np.testing.assert_array_less( + L1_error, + 0.1, + err_msg=f"{L} pixels using {rotation_method} @ {theta} radians", + ) From 833c7a02add22c7f17474cd69a2fbe8886d3b363 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 Dec 2025 10:24:56 -0500 Subject: [PATCH 16/21] more cleanup --- src/aspire/image/image.py | 3 ++- src/aspire/image/rotation.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 9ba54a66a4..071437a1ef 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -638,7 +638,7 @@ def filter(self, filter): original_stack_shape ) - def rotate(self, theta, method="fastrotate", mask=1, **kwargs): + def rotate(self, theta, method="scipy", mask=1, **kwargs): """ Return `Image` rotated by `theta` radians using `method`. @@ -654,6 +654,7 @@ def rotate(self, theta, method="fastrotate", mask=1, **kwargs): Array mask will be applied via elementwise multiplication. `None` disables masking. :param method: Optionally specify a rotation method. + Defaults to `scipy`. :return: `Image` containing rotated image data. """ diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index 426c28b96c..043fbfde72 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -125,7 +125,7 @@ def _rot180(img): def _rot270(img): - """Rotate image array by 90 degrees.""" + """Rotate image array by 270 degrees.""" # stack broadcast of fliplr(img.T) return xp.flip(xp.swapaxes(img, -1, -2), axis=-1) @@ -200,7 +200,8 @@ def fastrotate(images, theta, M=None): def sp_rotate(img, theta, **kwargs): - """Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. + """ + Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. Converts `theta` from radian to degrees. Defines stack/image axes and reshape behavior. From 9a1d2dd2f223bb890c80d97323546f00f4b54714 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 10 Dec 2025 10:13:24 -0500 Subject: [PATCH 17/21] attempt to fix theta bcast bug --- src/aspire/image/rotation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index 043fbfde72..acce553c14 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -223,13 +223,13 @@ def sp_rotate(img, theta, **kwargs): # Scipy accepts a single scalar theta in degrees. # Handle array of thetas and scalar case by expanding to flat array of img.shape - # Flatten all inputs + # Flatten all inputs, becomes 2d stack theta = np.rad2deg(np.array(theta)).reshape(-1, 1) - # Expand scalar input - if theta.shape[0] == 1: + # Expand single scalar input + if np.size(theta) == 1: theta = np.full(img.shape[0], theta, img.dtype) - # Check we have an array matching `img` - if theta.shape != img.shape[:1]: + # Check we have an array matching `img`, both should be (n,1) + if theta.shape[0] != img.shape[0]: raise RuntimeError("Inconsistent `theta` and `img` shapes.") # Create result array and rotate images via loop From 7352979cc664ec277e3b3ac69c722a849712417e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 11 Dec 2025 13:57:26 -0500 Subject: [PATCH 18/21] actually fix the bug this time --- src/aspire/image/rotation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index acce553c14..cd8211b98a 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -223,12 +223,12 @@ def sp_rotate(img, theta, **kwargs): # Scipy accepts a single scalar theta in degrees. # Handle array of thetas and scalar case by expanding to flat array of img.shape - # Flatten all inputs, becomes 2d stack - theta = np.rad2deg(np.array(theta)).reshape(-1, 1) + # Flatten all inputs + theta = np.rad2deg(np.array(theta)).reshape(-1) # Expand single scalar input if np.size(theta) == 1: theta = np.full(img.shape[0], theta, img.dtype) - # Check we have an array matching `img`, both should be (n,1) + # Check we have an array matching `img`, both should be len(n) if theta.shape[0] != img.shape[0]: raise RuntimeError("Inconsistent `theta` and `img` shapes.") From be9effef67a7f0537e9a5356413f05d3ecd6509a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 11 Dec 2025 14:21:35 -0500 Subject: [PATCH 19/21] add tests for additional unsupported input cases --- tests/test_image.py | 53 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/test_image.py b/tests/test_image.py index 7206998a4c..ce60a946fe 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -11,7 +11,7 @@ from pytest import raises from scipy.datasets import face -from aspire.image import Image +from aspire.image import Image, fastrotate, sp_rotate from aspire.utils import Rotation, gaussian_2d, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup @@ -640,3 +640,54 @@ def _gen_image(angle, L, n=1, K=10): 0.1, err_msg=f"{L} pixels using {rotation_method} @ {theta} radians", ) + + +def test_sp_rotate_inputs(dtype): + """ + Smoke test various input combinations to the scipy rotation wrapper. + """ + + imgs = np.zeros((6, 8, 8), dtype=dtype) + thetas = np.arange(6, dtype=dtype) + theta = thetas[0] # scalar + + ## These are the only supported calls admitted by the function doc. + # singleton, scalar + _ = sp_rotate(imgs[0], theta) + # stack, scalar + _ = sp_rotate(imgs, theta) + + ## These happen to also work with the code, so were put under test. + ## We're not advertising them, as there really isn't a good use + ## case for this wrapper code outside of the internal wrapping + ## application. + # singleton, single element array + _ = sp_rotate(imgs[0], thetas[0:1]) + # stack, single element array + _ = sp_rotate(imgs, thetas[0:1]) + # stack, stack + _ = sp_rotate(imgs, thetas) + # md-stack, md-stack + _ = sp_rotate(imgs.reshape(2, 3, 8, 8), thetas.reshape(2, 3, 1)) + _ = sp_rotate(imgs.reshape(2, 3, 8, 8), thetas.reshape(2, 3)) + + +def test_fastrotate_inputs(dtype): + """ + Smoke test various input combinations to `fastrotate`. + """ + + imgs = np.zeros((6, 8, 8), dtype=dtype) + theta = 42 + + ## These are the supported calls + # singleton, scalar + _ = fastrotate(imgs[0], theta) + # stack, scalar + _ = fastrotate(imgs, theta) + + ## These can also remain under test, but are not advertised. + # stack, single element array + _ = fastrotate(imgs, np.array(theta)) + # singleton, single element array + _ = fastrotate(imgs[0], np.array(theta)) From 9adc54949162b9caf782738466bb38d49ef68241 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 11 Dec 2025 14:30:34 -0500 Subject: [PATCH 20/21] tox doesn't like ## --- tests/test_image.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_image.py b/tests/test_image.py index ce60a946fe..3d981b6d09 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -651,16 +651,16 @@ def test_sp_rotate_inputs(dtype): thetas = np.arange(6, dtype=dtype) theta = thetas[0] # scalar - ## These are the only supported calls admitted by the function doc. + # # These are the only supported calls admitted by the function doc. # singleton, scalar _ = sp_rotate(imgs[0], theta) # stack, scalar _ = sp_rotate(imgs, theta) - ## These happen to also work with the code, so were put under test. - ## We're not advertising them, as there really isn't a good use - ## case for this wrapper code outside of the internal wrapping - ## application. + # # These happen to also work with the code, so were put under test. + # # We're not advertising them, as there really isn't a good use + # # case for this wrapper code outside of the internal wrapping + # # application. # singleton, single element array _ = sp_rotate(imgs[0], thetas[0:1]) # stack, single element array @@ -680,13 +680,13 @@ def test_fastrotate_inputs(dtype): imgs = np.zeros((6, 8, 8), dtype=dtype) theta = 42 - ## These are the supported calls + # # These are the supported calls # singleton, scalar _ = fastrotate(imgs[0], theta) # stack, scalar _ = fastrotate(imgs, theta) - ## These can also remain under test, but are not advertised. + # # These can also remain under test, but are not advertised. # stack, single element array _ = fastrotate(imgs, np.array(theta)) # singleton, single element array From d351fea2a79b65e18534abd3f6d493b5a8556e0c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 12 Dec 2025 09:47:43 -0500 Subject: [PATCH 21/21] Add M+theta arg error and test --- src/aspire/image/rotation.py | 12 ++++++++++-- tests/test_image.py | 24 +++++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index cd8211b98a..706248110f 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -143,8 +143,11 @@ def fastrotate(images, theta, M=None): `https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/common/fastrotate/fastrotate.m` :param images: (n , px, px) array of image data - :param theta: rotation angle in radians - :param M: optional precomputed shearing table + :param theta: Rotation angle in radians. + Note when `M` is supplied, `theta` must be `None`. + :param M: Optional precomputed shearing table. + Provided by `M=compute_fastrotate_interp_tables(theta, px, px)`. + Note when `M` is supplied, `theta` must be `None`. :return: (n, px, px) array of rotated image data """ @@ -157,6 +160,11 @@ def fastrotate(images, theta, M=None): if M is None: M = compute_fastrotate_interp_tables(theta, px0, px1) + elif theta is not None: + raise RuntimeError( + "`theta` must be `None` when supplying `M`." + " M is precomputed for a specific `theta`." + ) Mx, My, Mrots = M # Cast interp tables to match precision of `images` diff --git a/tests/test_image.py b/tests/test_image.py index 3d981b6d09..009be52364 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -11,7 +11,7 @@ from pytest import raises from scipy.datasets import face -from aspire.image import Image, fastrotate, sp_rotate +from aspire.image import Image, compute_fastrotate_interp_tables, fastrotate, sp_rotate from aspire.utils import Rotation, gaussian_2d, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup @@ -691,3 +691,25 @@ def test_fastrotate_inputs(dtype): _ = fastrotate(imgs, np.array(theta)) # singleton, single element array _ = fastrotate(imgs[0], np.array(theta)) + + +def test_fastrotate_M_arg(dtype): + """ + Smoke test precomputed `M` input to `fastrotate`. + """ + + imgs = np.random.randn(6, 8, 8).astype(dtype) + theta = np.random.uniform(0, 2 * np.pi) + + # Precompute M + M = compute_fastrotate_interp_tables(theta, *imgs.shape[-2:]) + + # Call with theta None + im_rot_M = fastrotate(imgs, None, M=M) + # Compare to calling withou `M` + im_rot = fastrotate(imgs, theta) + np.testing.assert_allclose(im_rot_M, im_rot) + + # Call with theta, should raise + with raises(RuntimeError, match=r".*`theta` must be `None`.*"): + _ = fastrotate(imgs, theta, M=M)