Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/aspire/image/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# isort: off
from .rotation import (
compute_fastrotate_interp_tables,
fastrotate,
sp_rotate,
)

# isort: on
from .image import (
BasisImage,
BispecImage,
Expand Down
80 changes: 78 additions & 2 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import aspire.sinogram
import aspire.volume
from aspire.image import fastrotate, sp_rotate
from aspire.nufft import anufft, nufft
from aspire.numeric import fft, xp
from aspire.utils import (
Expand Down Expand Up @@ -158,6 +159,8 @@ class Image:
".tif": load_tiff,
".tiff": load_tiff,
}
# Available image rotation functions
rotation_methods = {"fastrotate": fastrotate, "scipy": sp_rotate}

def __init__(self, data, pixel_size=None, dtype=None):
"""
Expand Down Expand Up @@ -635,8 +638,81 @@ def filter(self, filter):
original_stack_shape
)

def rotate(self):
raise NotImplementedError
def rotate(self, theta, method="scipy", mask=1, **kwargs):
"""
Return `Image` rotated by `theta` radians using `method`.

Optionally applies `mask`. Note that some methods may
introduce edge artifacts, in which case users may consider
using a tighter mask (eg 0.9) or a combination of pad-crop.

Any additional kwargs will be passed to `method`.

:param theta: Scalar or array of length `n_images`
:param mask: Optional scalar or array mask matching `Image` shape.
Scalar will create a circular mask of prescribed radius `(0,1]`.
Array mask will be applied via elementwise multiplication.
`None` disables masking.
:param method: Optionally specify a rotation method.
Defaults to `scipy`.
:return: `Image` containing rotated image data.
"""

original_stack_shape = self.stack_shape
im = self.stack_reshape(-1)

# Resolve rotation method
if method not in self.rotation_methods:
raise NotImplementedError(
f"Requested `Image.rotation` method={method} not found."
f" Select from {self.rotation_methods.keys()}"
)
# Assign the rotation method's function
# Any rotation method is expected to handle image data as a 2D array or 3D array (single stack axis).
rotation_function = self.rotation_methods[method]

# Handle both scalar and arrays of rotation angles.
# `theta` arrays are checked to match length of images when stacks axis are flattened.
theta = np.array(theta).flatten()
if len(theta) == 1:
im = rotation_function(im._data, theta, **kwargs)
elif len(theta) == im.n_images:
rot_im = np.empty_like(im._data)
for i in range(im.n_images):
rot_im[i] = rotation_function(im._data[i], theta[i], **kwargs)
im = rot_im
else:
raise RuntimeError(
f"Length of `theta` {len(theta)} and `Image` data {im.n_images} inconsistent."
)

# Masking, scalar case
if mask is not None:
if np.size(mask) == 1:
# Confirm `mask` value is a sane radius
if not (0 < mask <= 1):
raise ValueError(
f"Mask radius must be scalar between (0,1]. Received {mask}"
)
# Construct a boolean `mask` to apply in next code block as a 2D `mask`
mask = (
grid_2d(im.shape[-1], normalized=True, dtype=np.float64)["r"] < mask
)
mask = mask.astype(im.dtype)

# Masking, 2D case
# Confirm `mask` size is consistent
if mask.shape == im.shape[-2:]:
im = im * mask[None, :, :]
else:
raise RuntimeError(
f"Shape of `mask` {mask.shape} inconsistent with `Image` data shape {im.shape[-2:]}"
)

# Restore original stack shape and metadata.
return self.__class__(im, pixel_size=self.pixel_size).stack_reshape(
original_stack_shape
)

def save(self, mrcs_filepath, overwrite=None):
"""
Expand Down
251 changes: 251 additions & 0 deletions src/aspire/image/rotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
import numpy as np
from scipy import ndimage

from aspire.numeric import fft, xp
from aspire.utils import complex_type


def _pre_rotate(theta):
"""
Given `theta` radians return nearest rotation of pi/2
required to place angle within [-pi/4,pi/4) and the residual
rotation in radians.

:param theta: Rotation in radians
:returns:
- Residual angle in radians
- Number of pi/2 rotations
"""

theta = np.mod(theta, 2 * np.pi)

# 0 < pi/4
rots = 0
residual = theta

if theta >= np.pi / 4 and theta < 3 * np.pi / 4:
rots = 1
residual = theta - np.pi / 2
elif theta >= 3 * np.pi / 4 and theta < 5 * np.pi / 4:
rots = 2
residual = theta - np.pi
elif theta >= 5 * np.pi / 4 and theta < 7 * np.pi / 4:
rots = 3
residual = theta - 3 * np.pi / 2
elif theta >= 7 * np.pi / 4 and theta < 2 * np.pi:
rots = 0
residual = theta - 2 * np.pi

return residual, rots


def _shift_center(n):
"""
Given `n` pixels return center pixel and shift amount, 0 or 1/2.

:param n: Number of pixels
:returns:
- center pixel
- shift amount
"""
if n % 2 == 0:
c = n // 2 # center
s = 1 / 2 # shift
else:
c = n // 2
s = 0

return c, s


def compute_fastrotate_interp_tables(theta, nx, ny):
"""
Retuns iterpolation tables as tuple M = (Mx, My, rots).

:param theta: angle in radians
:param nx: Number pixels first axis
:param ny: Number pixels second axis
"""
theta, mult90 = _pre_rotate(theta)

# Reverse rotation, Yaroslavsky rotated CW
theta = -theta

cy, sy = _shift_center(ny)
cx, sx = _shift_center(nx)

# Floating point epsilon
eps = np.finfo(np.float64).eps

# Precompute Y interpolation tables
My = np.zeros((nx, ny), dtype=np.complex128)
r = np.arange(cy + 1, dtype=int)
u = (1 - np.cos(theta)) / np.sin(theta + eps)
alpha1 = 2 * np.pi * 1j * r / ny

linds = np.arange(ny - 1, cy, -1, dtype=int)
rinds = np.arange(1, cy - 2 * sy + 1, dtype=int)

Ux = u * (np.arange(nx) - cx + sx + 2)
My[:, r] = np.exp(alpha1[None, :] * Ux[:, None])
My[:, linds] = My[:, rinds].conj()

# Precompute X interpolation tables
Mx = np.zeros((ny, nx), dtype=np.complex128)
r = np.arange(cx + 1, dtype=int)
u = -np.sin(theta)
alpha2 = 2 * np.pi * 1j * r / nx

linds = np.arange(nx - 1, cx, -1, dtype=int)
rinds = np.arange(1, cx - 2 * sx + 1, dtype=int)

Uy = u * (np.arange(ny) - cy + sy + 2)
Mx[:, r] = np.exp(alpha2[None, :] * Uy[:, None])
Mx[:, linds] = Mx[:, rinds].conj()

# After building, transpose to (nx, ny).
Mx = Mx.T

return Mx, My, mult90


# The following helper utilities are written to work with
# `img` data of dimension 2 or more where the data is expected to be
# in the (-2,-1) dimensions with any other dims as stack axes.
def _rot90(img):
"""Rotate image array by 90 degrees."""
# stack broadcast of flipud(img.T)
return xp.flip(xp.swapaxes(img, -1, -2), axis=-2)


def _rot180(img):
"""Rotate image array by 180 degrees."""
# stack broadcast of flipud(fliplr)
return xp.flip(img, axis=(-1, -2))


def _rot270(img):
"""Rotate image array by 270 degrees."""
# stack broadcast of fliplr(img.T)
return xp.flip(xp.swapaxes(img, -1, -2), axis=-1)


def fastrotate(images, theta, M=None):
"""
Rotate `images` array by `theta` radians ccw using shearing algorithm.

Note that this algorithm may have artifacts near the rotation boundary
and will have artifacts outside the rotation boundary.
Users can avoid these by zero padding the input image then
cropping the rotated image and/or masking.

For reference and notes:
`https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/common/fastrotate/fastrotate.m`

:param images: (n , px, px) array of image data
:param theta: Rotation angle in radians.
Note when `M` is supplied, `theta` must be `None`.
:param M: Optional precomputed shearing table.
Provided by `M=compute_fastrotate_interp_tables(theta, px, px)`.
Note when `M` is supplied, `theta` must be `None`.
:return: (n, px, px) array of rotated image data
"""

# Make a stack of 1
if images.ndim == 2:
images = images[None, :, :]

n, px0, px1 = images.shape
assert px0 == px1, "Currently only implemented for square images."

if M is None:
M = compute_fastrotate_interp_tables(theta, px0, px1)
elif theta is not None:
raise RuntimeError(
"`theta` must be `None` when supplying `M`."
" M is precomputed for a specific `theta`."
)
Mx, My, Mrots = M

# Cast interp tables to match precision of `images`
Mx = xp.asarray(Mx, complex_type(images.dtype))
My = xp.asarray(My, complex_type(images.dtype))

# Determine if `images` data was provided on host (np.darray)
_host = isinstance(images, np.ndarray)

# Copy image array to device if needed
images = xp.asarray(images)

# Pre rotate by multiples of 90 (pi/2)
if Mrots == 1:
images = _rot90(images)
elif Mrots == 2:
images = _rot180(images)
elif Mrots == 3:
images = _rot270(images)

# Shear 1
img_k = fft.fft(images, axis=-1)
img_k = img_k * My
images = fft.ifft(img_k, axis=-1).real

# Shear 2
img_k = fft.fft(images, axis=-2)
img_k = img_k * Mx
images = fft.ifft(img_k, axis=-2).real

# Shear 3
img_k = fft.fft(images, axis=-1)
img_k = img_k * My
images = fft.ifft(img_k, axis=-1).real

# Return to host if input was provided on host
if _host:
images = xp.asnumpy(images)

return images


def sp_rotate(img, theta, **kwargs):
"""
Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation.

Converts `theta` from radian to degrees.
Defines stack/image axes and reshape behavior.
Image data is expected to be in last two axes in all cases.

Additional kwargs will be passed through.
See scipy.ndimage.rotate

:param img: Array of image data shape (L,L) or (...,L, L)
:param theta: Rotation in ccw radians.
:return: Array representing rotated `img`.
"""

# Store original shape
original_shape = img.shape
# Image data is expected to be in last two axis in all cases
# Flatten, converts all inputs to consistent 3D shape (single stack axis).
img = img.reshape(-1, *img.shape[-2:])

# Scipy accepts a single scalar theta in degrees.
# Handle array of thetas and scalar case by expanding to flat array of img.shape
# Flatten all inputs
theta = np.rad2deg(np.array(theta)).reshape(-1)
# Expand single scalar input
if np.size(theta) == 1:
theta = np.full(img.shape[0], theta, img.dtype)
# Check we have an array matching `img`, both should be len(n)
if theta.shape[0] != img.shape[0]:
raise RuntimeError("Inconsistent `theta` and `img` shapes.")

# Create result array and rotate images via loop
result = np.empty_like(img)
for i in range(img.shape[0]):
result[i] = ndimage.rotate(
img[i], theta[i], reshape=False, axes=(-2, -1), **kwargs
)

# Restore original shape
return result.reshape(*original_shape)
Loading
Loading