Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,20 +203,19 @@ def __init__(self, data, pixel_size=None, dtype=None):
else:
self.dtype = np.dtype(dtype)

if not data.shape[-1] == data.shape[-2]:
raise ValueError("Only square ndarrays are supported.")

self._data = data.astype(self.dtype, copy=False)
self.ndim = self._data.ndim
self.shape = self._data.shape
self.stack_ndim = self._data.ndim - 2
self.stack_shape = self._data.shape[:-2]
self.n_images = np.prod(self.stack_shape)
self.resolution = self._data.shape[-1]
self.pixel_size = None
if pixel_size is not None:
self.pixel_size = float(pixel_size)

self._is_square = data.shape[-1] == data.shape[-2]
self.resolution = self._data.shape[-1] # XXXXX

# Numpy interop
# https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol
self.__array_interface__ = self._data.__array_interface__
Expand All @@ -234,6 +233,12 @@ def project(self, angles):
:return: Radon transform of the Image Stack.
:rtype: Ndarray (stack size, number of angles, image resolution)
"""

if not self._is_square:
raise NotImplementedError(
"`Image.project` is not currently implemented for non-square images."
)

# number of points to sample on radial line in polar grid
n_points = self.resolution
original_stack = self.stack_shape
Expand Down Expand Up @@ -309,19 +314,19 @@ def stack_reshape(self, *args):
)

def __add__(self, other):
if isinstance(other, Image):
if isinstance(other, self.__class__):
other = other._data

return self.__class__(self._data + other, pixel_size=self.pixel_size)

def __sub__(self, other):
if isinstance(other, Image):
if isinstance(other, self.__class__):
other = other._data

return self.__class__(self._data - other, pixel_size=self.pixel_size)

def __mul__(self, other):
if isinstance(other, Image):
if isinstance(other, self.__class__):
other = other._data

return self.__class__(self._data * other, pixel_size=self.pixel_size)
Expand Down Expand Up @@ -385,7 +390,7 @@ def __repr__(self):
px_msg = f" with pixel_size={self.pixel_size} angstroms."

msg = f"{self.n_images} {self.dtype} images arranged as a {self.stack_shape} stack"
msg += f" each of size {self.resolution}x{self.resolution}{px_msg}"
msg += f" each of size {self.shape[-2:]}{px_msg}"
return msg

def asnumpy(self):
Expand Down Expand Up @@ -442,6 +447,12 @@ def legacy_whiten(self, psd, delta):
and which to set to zero. By default all `sqrt(psd)` values
less than `delta` are zeroed out in the whitening filter.
"""

if not self._is_square:
raise NotImplementedError(
"`Image.legacy_whiten` is not currently implemented for non-square images."
)

n = self.n_images
L = self.resolution
L_half = L // 2
Expand Down Expand Up @@ -607,6 +618,7 @@ def filter(self, filter):
:param filter: An object of type `Filter`.
:return: A new filtered `Image` object.
"""

original_stack_shape = self.stack_shape

im = self.stack_reshape(-1)
Expand All @@ -619,11 +631,16 @@ def filter(self, filter):
# upcast both here for most accurate convolution.
filter_values = xp.asarray(
filter.evaluate_grid(
self.resolution, dtype=np.float64, pixel_size=self.pixel_size
self.shape[-2:], dtype=np.float64, pixel_size=self.pixel_size
),
dtype=np.float64,
)

# sanity check
assert (
filter_values.shape == im._data.shape[-2:]
), f"{filter_values.shape} != {im._data.shape[:-2]}"

# Convolve
_im = xp.asarray(im._data, dtype=np.float64)
im_f = fft.centered_fft2(_im)
Expand Down Expand Up @@ -777,8 +794,8 @@ def _load_raw(filepath, dtype=None):

return im, pixel_size

@staticmethod
def load(filepath, dtype=None):
@classmethod
def load(cls, filepath, dtype=None):
"""
Load raw data from supported files.

Expand All @@ -793,7 +810,7 @@ def load(filepath, dtype=None):
im, pixel_size = Image._load_raw(filepath, dtype=dtype)

# Return as Image instance
return Image(im, pixel_size=pixel_size)
return cls(im, pixel_size=pixel_size)

def _im_translate(self, shifts):
"""
Expand All @@ -809,6 +826,10 @@ def _im_translate(self, shifts):
Alternatively, it can be a row vector of length 2, in which case the same shifts is applied to each image.
:return: The images translated by the shifts, with periodic boundaries.
"""
if not self._is_square:
raise NotImplementedError(
"`Image._im_translate` is not currently implemented for non-square images."
)

if shifts.ndim == 1:
shifts = shifts[np.newaxis, :]
Expand Down Expand Up @@ -879,6 +900,10 @@ def backproject(self, rot_matrices, symmetry_group=None, zero_nyquist=True):

:return: Volume instance corresonding to the backprojected images.
"""
if not self._is_square:
raise NotImplementedError(
"`Image.legacy_whiten` is not currently implemented for non-square images."
)

if self.stack_ndim > 1:
raise NotImplementedError(
Expand Down Expand Up @@ -996,6 +1021,10 @@ def frc(self, other, cutoff=None, method="fft", plot=False):
where `estimated_resolution` is in angstrom
and FRC is a Numpy array of correlations.
"""
if not self._is_square:
raise NotImplementedError(
"`Image.frc` is not currently implemented for non-square images."
)

if not isinstance(other, Image):
raise TypeError(
Expand Down
3 changes: 2 additions & 1 deletion src/aspire/operators/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs):

Passes arbritrary args and kwargs down to self.evaluate method.

:param L: Number of grid points (L by L).
:param L: Number of grid points.
L-by-L given a single `L`, or (L0, L1) if L is length 2.
:param dtype: dtype of grid, defaults np.float32.
:return: Filter values at omega's points.
"""
Expand Down
113 changes: 80 additions & 33 deletions src/aspire/source/micrograph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import warnings
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
Expand All @@ -10,7 +11,13 @@
from aspire.source import Simulation
from aspire.source.image import _ImageAccessor
from aspire.storage import StarFile
from aspire.utils import Random, check_pixel_size, grid_2d, rename_with_timestamp
from aspire.utils import (
Random,
check_pixel_size,
grid_2d,
rename_with_timestamp,
trange,
)
from aspire.volume import Volume

logger = logging.getLogger(__name__)
Expand All @@ -20,7 +27,12 @@ class MicrographSource(ABC):
def __init__(self, micrograph_count, micrograph_size, dtype, pixel_size=None):
""" """
self.micrograph_count = int(micrograph_count)
self.micrograph_size = int(micrograph_size)
# Expand single integer to 2-tuple
if isinstance(micrograph_size, int):
micrograph_size = (micrograph_size,) * 2
if len(micrograph_size) != 2:
raise ValueError("`micrograph_size` should be a integer or 2-tuple")
self.micrograph_size = tuple(micrograph_size)
self.dtype = np.dtype(dtype)
if pixel_size is not None:
pixel_size = float(pixel_size)
Expand All @@ -34,7 +46,7 @@ def __repr__(self):

:return: Returns a string description of instance.
"""
return f"{self.__class__.__name__} with {self.micrograph_count} {self.dtype.name} micrographs of size {self.micrograph_size}x{self.micrograph_size}"
return f"{self.__class__.__name__} with {self.micrograph_count} {self.dtype.name} micrographs of size {self.micrograph_size}"

def __len__(self):
"""
Expand Down Expand Up @@ -115,6 +127,42 @@ def _images(self, indices):
:return: An `Image` object representing the micrographs for `indices`.
"""

def phase_flip(self, filters):
"""
Perform phase flip on micrographs in the source object using CTF information.
If no CTFFilters exist this will emit a warning and otherwise no-op.
"""

logger.info("Perform phase flip on source object")
filters = list(filters) # unpack any generators

if len(filters) >= 1:
assert len(filters) == self.micrograph_count

logger.info("Phaseflipping")
phase_flipped_micrographs = np.empty(
(self.micrograph_count, *self.micrograph_size), dtype=self.dtype
)
for i in trange(self.micrograph_count, desc="Phaseflipping micrograph"):
# micrograph = self.images[i]
# f = filters[i].sign
# ... = micrograph.filter(f)
phase_flipped_micrographs[i] = self.images[i].filter(filters[i].sign)

return ArrayMicrographSource(
micrographs=phase_flipped_micrographs, pixel_size=self.pixel_size
)

else:
# No CTF filters found
logger.warning(
"No Filters found."
" `phase_flip` is a no-op without Filters."
" Confirm you have correctly populated CTFFilters."
)

return self


class ArrayMicrographSource(MicrographSource):
def __init__(self, micrographs, dtype=None, pixel_size=None):
Expand Down Expand Up @@ -142,14 +190,14 @@ def __init__(self, micrographs, dtype=None, pixel_size=None):
if micrographs.ndim == 2:
micrographs = micrographs[None, :, :]

if micrographs.ndim != 3 or (micrographs.shape[-2] != micrographs.shape[-1]):
if micrographs.ndim != 3:
raise NotImplementedError(
f"Incompatible `micrographs` shape {micrographs.shape}, expects (count, L, L)"
f"Incompatible `micrographs` shape {micrographs.shape}, expects 2D or 3D array."
)

super().__init__(
micrograph_count=micrographs.shape[0],
micrograph_size=micrographs.shape[-1],
micrograph_size=micrographs.shape[-2:],
dtype=dtype or micrographs.dtype,
pixel_size=pixel_size,
)
Expand Down Expand Up @@ -201,15 +249,15 @@ def __init__(self, micrographs_path, dtype=None, pixel_size=None):

# Load the first micrograph to infer shape/type
# Size will be checked during on-the-fly loading of subsequent micrographs.
micrograph0 = Image.load(self.micrograph_files[0])
if micrograph0.pixel_size is not None and micrograph0.pixel_size != pixel_size:
raise ValueError(
f"Mismatched pixel size. {micrograph0.pixel_size} angstroms defined in {self.micrograph_files[0]}, but provided {pixel_size} angstroms."
)
micrograph0, _pixel_size = Image._load_raw(self.micrograph_files[0])
# Compare with user provided pixel size
if pixel_size is not None and _pixel_size != pixel_size:
msg = f"Mismatched pixel size. {_pixel_size} angstroms defined in {self.micrograph_files[0]}, but provided {pixel_size} angstroms."
warnings.warn(msg, UserWarning, stacklevel=2)

super().__init__(
micrograph_count=len(self.micrograph_files),
micrograph_size=micrograph0.resolution,
micrograph_size=micrograph0.shape[-2:],
dtype=dtype or micrograph0.dtype,
pixel_size=pixel_size,
)
Expand Down Expand Up @@ -265,28 +313,27 @@ def _images(self, indices):
# Initialize empty result
n_micrographs = len(indices)
micrographs = np.empty(
(n_micrographs, self.micrograph_size, self.micrograph_size),
(n_micrographs, *self.micrograph_size),
dtype=self.dtype,
)
for i, ind in enumerate(indices):
# Load the micrograph image from file
micrograph = Image.load(self.micrograph_files[ind])
micrograph, _pixel_size = Image._load_raw(self.micrograph_files[ind])

# Assert size
if micrograph.resolution != self.micrograph_size:
if micrograph.shape != self.micrograph_size:
raise NotImplementedError(
f"Micrograph {ind} has inconsistent shape {micrograph.shape},"
f" expected {(self.micrograph_size, self.micrograph_size)}."
f" expected {self.micrograph_size}."
)

# Continually compare with initial pixel_size
if _pixel_size is not None and _pixel_size != self.pixel_size:
msg = f"Mismatched pixel size. {_pixel_size} angstroms defined in {self.micrograph_files[ind]}, but provided {self.pixel_size} angstroms."
warnings.warn(msg, UserWarning, stacklevel=2)

# Assign to array, implicitly performs casting to dtype
micrographs[i] = micrograph.asnumpy()
# Assert pixel_size
if (
micrograph.pixel_size is not None
and micrograph.pixel_size != self.pixel_size
):
raise ValueError(
f"Mismatched pixel size. {micrograph.pixel_size} angstroms defined in {self.micrograph_files[ind]}, but provided {self.pixel_size} angstroms."
)
micrographs[i] = micrograph

return Image(micrographs, pixel_size=self.pixel_size)

Expand Down Expand Up @@ -314,7 +361,7 @@ def __init__(

:param volume: `Volume` instance to be used in `Simulation`.
An `(L,L,L)` `Volume` will generate `(L,L)` particle images.
:param micrograph_size: Size of micrograph in pixels, defaults to 4096.
:param micrograph_size: Size of micrograph in pixels as integer or 2-tuple. Defaults to 4096.
:param micrograph_count: Number of micrographs to generate (integer). Defaults to 1.
:param particles_per_micrograph: The amount of particles generated for each micrograph. Defaults to 10.
:param particle_amplitudes: Optional, amplitudes to pass to `Simulation`.
Expand Down Expand Up @@ -366,7 +413,7 @@ def __init__(

self.noise_adder = noise_adder

if self.particle_box_size > micrograph_size:
if self.particle_box_size > max(self.micrograph_size):
raise ValueError(
"The micrograph size must be larger or equal to the `particle_box_size`."
)
Expand Down Expand Up @@ -428,7 +475,7 @@ def __init__(
else:
if (
boundary < (-self.particle_box_size // 2)
or boundary > self.micrograph_size // 2
or boundary > max(self.micrograph_size) // 2
):
raise ValueError("Illegal boundary value.")
self.boundary = boundary
Expand Down Expand Up @@ -518,8 +565,8 @@ def _set_mask(self):
"""
self._mask = np.full(
(
int(self.micrograph_size + 2 * self.pad),
int(self.micrograph_size + 2 * self.pad),
int(self.micrograph_size[0] + 2 * self.pad),
int(self.micrograph_size[1] + 2 * self.pad),
),
False,
dtype=bool,
Expand Down Expand Up @@ -560,7 +607,7 @@ def _clean_images(self, indices):
# Initialize empty micrograph
n_micrographs = len(indices)
clean_micrograph = np.zeros(
(n_micrographs, self.micrograph_size, self.micrograph_size),
(n_micrographs, *self.micrograph_size),
dtype=self.dtype,
)
# Pad the micrograph
Expand Down Expand Up @@ -592,8 +639,8 @@ def _clean_images(self, indices):
)
clean_micrograph = clean_micrograph[
:,
self.pad : self.micrograph_size + self.pad,
self.pad : self.micrograph_size + self.pad,
self.pad : self.micrograph_size[0] + self.pad,
self.pad : self.micrograph_size[1] + self.pad,
]
return Image(clean_micrograph, pixel_size=self.pixel_size)

Expand Down
Loading