diff --git a/.bumpversion.cfg b/.bumpversion.cfg index dc9339ea5b..9d6eb705dd 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.14.1 +current_version = 0.14.2 commit = True tag = True diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index 4eae25111b..2bf5a462dc 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -156,7 +156,7 @@ jobs: shell: bash -el {0} strategy: matrix: - os: [ubuntu-latest, ubuntu-22.04, macOS-latest, macOS-13] + os: [ubuntu-latest, ubuntu-22.04, macOS-latest, macOS-14] backend: [default, openblas] python-version: ['3.9'] include: diff --git a/README.md b/README.md index 569fa69c62..a6f6ee4a80 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5657281.svg)](https://doi.org/10.5281/zenodo.5657281) [![Downloads](https://static.pepy.tech/badge/aspire/month)](https://pepy.tech/project/aspire) -# ASPIRE - Algorithms for Single Particle Reconstruction - v0.14.1 +# ASPIRE - Algorithms for Single Particle Reconstruction - v0.14.2 The ASPIRE-Python project supersedes [Matlab ASPIRE](https://github.com/PrincetonUniversity/aspire). @@ -20,7 +20,7 @@ For more information about the project, algorithms, and related publications ple Please cite using the following DOI. This DOI represents all versions, and will always resolve to the latest one. ``` -ComputationalCryoEM/ASPIRE-Python: v0.14.1 https://doi.org/10.5281/zenodo.5657281 +ComputationalCryoEM/ASPIRE-Python: v0.14.2 https://doi.org/10.5281/zenodo.5657281 ``` diff --git a/docs/source/conf.py b/docs/source/conf.py index fd203b75de..5faee47e40 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -86,7 +86,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -release = version = "0.14.1" +release = version = "0.14.2" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/source/index.rst b/docs/source/index.rst index a7ce8b3783..6154ac327b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,4 +1,4 @@ -Aspire v0.14.1 +Aspire v0.14.2 ============== Algorithms for Single Particle Reconstruction diff --git a/environment-accelerate.yml b/environment-accelerate.yml index 5248c3a953..5738909d7d 100644 --- a/environment-accelerate.yml +++ b/environment-accelerate.yml @@ -7,8 +7,8 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.24.1 - - scipy=1.10.1 + - numpy=1.25.0 + - scipy=1.13.1 - scikit-learn - scikit-image - libblas=*=*accelerate diff --git a/environment-default.yml b/environment-default.yml index f827ec55bc..2b1d542fd6 100644 --- a/environment-default.yml +++ b/environment-default.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.23.5 - - scipy=1.9.3 + - numpy=1.25.0 + - scipy=1.13.1 - scikit-learn - scikit-image diff --git a/environment-intel.yml b/environment-intel.yml index 840dd92f76..3295ff934a 100644 --- a/environment-intel.yml +++ b/environment-intel.yml @@ -7,8 +7,8 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.23.5 - - scipy=1.9.3 + - numpy=1.25.0 + - scipy=1.13.1 - scikit-learn - scikit-image - mkl_fft diff --git a/environment-openblas.yml b/environment-openblas.yml index 088035f88d..04f80266be 100644 --- a/environment-openblas.yml +++ b/environment-openblas.yml @@ -7,8 +7,8 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.23.5 - - scipy=1.9.3 + - numpy=1.25.0 + - scipy=1.13.1 - scikit-learn - scikit-image - libblas=*=*openblas diff --git a/environment-win64.yml b/environment-win64.yml index 34ca5d9fa6..ea9ed840f6 100644 --- a/environment-win64.yml +++ b/environment-win64.yml @@ -7,8 +7,8 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.23.5 - - scipy=1.9.3 + - numpy=1.25.0 + - scipy=1.13.1 - scikit-learn - scikit-image - mkl=2024.1.* # possible regression impacts eig solver in later versions up to 2025.0 diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py new file mode 100644 index 0000000000..e52f6bc2a8 --- /dev/null +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -0,0 +1,80 @@ +""" +Simulated Stack to RELION Reconstruction +======================================== + +This experiment shows how to: + +1. build a synthetic dataset with ASPIRE, +2. write the stack via ``ImageSource.save`` so RELION can consume it, and +3. call :code:`relion_reconstruct` on the saved STAR file. +""" + +# %% +# Imports +# ------- + +import logging +from pathlib import Path + +import numpy as np + +from aspire.downloader import emdb_2660 +from aspire.noise import WhiteNoiseAdder +from aspire.operators import RadialCTFFilter +from aspire.source import Simulation + +logger = logging.getLogger(__name__) + + +# %% +# Configuration +# ------------- +# We set a few parameters to initialize the Simulation. +# You can safely alter ``n_particles`` (or change the defocus values, etc.) when +# trying this interactively; the defaults here are chosen for demonstrative purposes. + +output_dir = Path("relion_save_demo") +output_dir.mkdir(exist_ok=True) + +n_particles = 512 +snr = 0.5 +defocus = np.linspace( + 15000, 25000, 7 +) # defocus values for the radial CTF filters (angstroms) +star_file = f"sim_n{n_particles}.star" +star_path = output_dir / star_file + +# %% +# Volume and Filters +# ------------------ +# Start from the EMDB-2660 ribosome map and build a small set of radial CTF filters +# that RELION will recover as optics groups. + +vol = emdb_2660() +ctf_filters = [RadialCTFFilter(defocus=d) for d in defocus] + + +# %% +# Simulate, Add Noise, Save +# ------------------------- +# Initialize the Simulation: +# mix the CTFs across the stack, add white noise at a target SNR, +# and write the particles and metadata to a RELION-compatible STAR/MRC stack. + +sim = Simulation( + n=n_particles, + vols=vol, + unique_filters=ctf_filters, + noise_adder=WhiteNoiseAdder.from_snr(snr), +) +sim.save(star_path, overwrite=True) + + +# %% +# Running ``relion_reconstruct`` +# ------------------------------ +# ``relion_reconstruct`` is an external RELION command, so we just show the call. +# Run this, from the output directory, in a RELION-enabled shell after generating +# the STAR file above. + +logger.info(f"relion_reconstruct --i {star_file} " f"--o 'relion_recon.mrc' --ctf") diff --git a/gallery/tutorials/tutorials/class_averaging.py b/gallery/tutorials/tutorials/class_averaging.py index 7108a97846..063e10c1e0 100644 --- a/gallery/tutorials/tutorials/class_averaging.py +++ b/gallery/tutorials/tutorials/class_averaging.py @@ -226,6 +226,7 @@ est_shifts = avgs.averager.shifts est_dot_products = avgs.averager.dot_products +# These are dictionaries mapping each class to arrays of attributes. print(f"Estimated Rotations: {est_rotations}") print(f"Estimated Shifts: {est_shifts}") print(f"Estimated Dot Products: {est_dot_products}") @@ -241,7 +242,12 @@ original_img_nbr = noisy_src.images[original_img_nbr_idx].asnumpy()[0] # Rotate using estimated rotations. -angle = est_rotations[0, nbr] * 180 / np.pi +# First retrieve all angles for the `review_class` (original_img_0_idx), +# then lookup the specific neighbor `nbr` +assert ( + original_img_0_idx == review_class +), "DebugClassAvgSource should retain original source image ordering" +angle = est_rotations[original_img_0_idx][nbr] * 180 / np.pi if reflections[nbr]: print("Reflection reported.") original_img_nbr = np.flipud(original_img_nbr) diff --git a/pyproject.toml b/pyproject.toml index 8a7696dcb9..198a3bfb06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "aspire" -version = "0.14.1" +version = "0.14.2" description = "Algorithms for Single Particle Reconstruction" readme = "README.md" # Optional requires-python = ">=3.9" @@ -37,7 +37,7 @@ dependencies = [ "joblib", "matplotlib >= 3.2.0", "mrcfile", - "numpy>=1.21.5", + "numpy>=1.25.0", "packaging", "pooch>=1.7.0", "pillow", diff --git a/src/aspire/__init__.py b/src/aspire/__init__.py index 3b77b9cb67..d7c9359a0e 100644 --- a/src/aspire/__init__.py +++ b/src/aspire/__init__.py @@ -15,7 +15,7 @@ from aspire.exceptions import handle_exception # version in maj.min.bld format -__version__ = "0.14.1" +__version__ = "0.14.2" # Setup `confuse` config diff --git a/src/aspire/abinitio/J_sync.py b/src/aspire/abinitio/J_sync.py new file mode 100644 index 0000000000..b8458d9807 --- /dev/null +++ b/src/aspire/abinitio/J_sync.py @@ -0,0 +1,247 @@ +import logging + +import numpy as np +from numpy.linalg import norm + +from aspire.utils import J_conjugate, all_pairs, all_triplets, tqdm +from aspire.utils.random import randn + +logger = logging.getLogger(__name__) + + +class JSync: + """ + Class for handling J-synchronization methods. + """ + + def __init__( + self, + n, + epsilon=1e-2, + max_iters=1000, + seed=None, + ): + """ + Initialize JSync object for estimating global handedness synchronization for a + set of relative rotations, Rij = Ri @ Rj.T, where i <= j = 0, 1, ..., n. + + :param n: Number of images/rotations. + :param epsilon: Tolerance for the power method. + :param max_iters: Maximum iterations for the power method. + :param seed: Optional seed for power method initial random vector. + """ + self.n_img = n + self.epsilon = epsilon + self.max_iters = max_iters + self.seed = seed + + def global_J_sync(self, vijs): + """ + Global J-synchronization of all third row outer products. Given 3x3 matrices vijs, each + of which might contain a spurious J (ie. vij = J*vi*vj^T*J instead of vij = vi*vj^T), + we return vijs that all have either a spurious J or not. + + :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds an estimate for the corresponding + outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate + might have a spurious J independently of other estimates. + + :return: vijs, all of which have a spurious J or not. + """ + + # Determine relative handedness of vijs. + sign_ij_J = self.power_method(vijs) + + # Synchronize vijs + vijs_sync = vijs.copy() + for i, sign in enumerate(sign_ij_J): + if sign == -1: + vijs_sync[i] = J_conjugate(vijs[i]) + + return vijs_sync + + def power_method(self, vijs): + """ + Calculate the leading eigenvector of the J-synchronization matrix + using the power method. + + As the J-synchronization matrix is of size (n-choose-2)x(n-choose-2), we + use the power method to compute the eigenvalues and eigenvectors, + while constructing the matrix on-the-fly. + + :param vijs: (n-choose-2)x3x3 array of estimates of relative orientation matrices. + + :return: An array of length n-choose-2 consisting of 1 or -1, where the sign of the + i'th entry indicates whether the i'th relative orientation matrix will be J-conjugated. + """ + + # Set power method tolerance and maximum iterations. + epsilon = self.epsilon + max_iters = self.max_iters + + # Initialize candidate eigenvectors + n_vijs = vijs.shape[0] + vec = randn(n_vijs, seed=self.seed) + vec = vec / norm(vec) + residual = 1 + itr = 0 + + # Power method iterations + logger.info( + "Initiating power method to estimate J-synchronization matrix eigenvector." + ) + while itr < max_iters and residual > epsilon: + itr += 1 + # Note, this appears to need double precision for accuracy in the following division. + vec_new = self._signs_times_v(vijs, vec).astype(np.float64, copy=False) + vec_new = vec_new / norm(vec_new) + residual = norm(vec_new - vec) + vec = vec_new + logger.info( + f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" + ) + + # We need only the signs of the eigenvector + J_sync = np.sign(vec, dtype=vijs.dtype) + + return J_sync + + def sync_viis(self, vijs, viis): + """ + Given a set of synchronized pairwise outer products vijs, J-synchronize the set of + outer products viis. + + :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds an estimate for the corresponding + outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate + might have a spurious J independently of other estimates. + + :param viis: An n_imgx3x3 array where the i'th slice holds an estimate for the outer product vi*vi^T + between the third row of matrix Ri and itself. Each estimate might have a spurious J independently + of other estimates. + + :return: J-synchronized viis. + """ + + # Synchronize viis + # We use the fact that if v_ii and v_ij are of the same handedness, then v_ii @ v_ij = v_ij. + # If they are opposite handed then Jv_iiJ @ v_ij = v_ij. We compare each v_ii against all + # previously synchronized v_ij to get a consensus on the handedness of v_ii. + _, pairs_to_linear = all_pairs(self.n_img, return_map=True) + for i in range(self.n_img): + vii = viis[i] + vii_J = J_conjugate(vii) + J_consensus = 0 + for j in range(self.n_img): + if j < i: + idx = pairs_to_linear[j, i] + vji = vijs[idx] + + err1 = norm(vji @ vii - vji) + err2 = norm(vji @ vii_J - vji) + + elif j > i: + idx = pairs_to_linear[i, j] + vij = vijs[idx] + + err1 = norm(vii @ vij - vij) + err2 = norm(vii_J @ vij - vij) + + else: + continue + + # Accumulate J consensus + if err1 < err2: + J_consensus -= 1 + else: + J_consensus += 1 + + if J_consensus > 0: + viis[i] = vii_J + return viis + + def _signs_times_v(self, vijs, vec): + """ + Multiplication of the J-synchronization matrix by a candidate eigenvector. + + The J-synchronization matrix is a matrix representation of the handedness graph, Gamma, whose set of + nodes consists of the estimates vijs and whose set of edges consists of the undirected edges between + all triplets of estimates vij, vjk, and vik, where i i", pf_i, r_max, h) pf_j = self._apply_filter_and_norm("i, i -> i", pf_j, r_max, h) # apply the shifts to images pf_i_flipped = np.conj(pf_i) - pf_i_stack = np.einsum("i, ji -> ij", pf_i, shift_phases) - pf_i_flipped_stack = np.einsum("i, ji -> ij", pf_i_flipped, shift_phases) + pf_i_stack = pf_i[:, None] * shift_phases.T + pf_i_flipped_stack = pf_i_flipped[:, None] * shift_phases.T - c1 = 2 * np.real(np.dot(np.conj(pf_i_stack.T), pf_j)) - c2 = 2 * np.real(np.dot(np.conj(pf_i_flipped_stack.T), pf_j)) + c1 = 2 * np.dot(pf_i_stack.T.conj(), pf_j).real + c2 = 2 * np.dot(pf_i_flipped_stack.T.conj(), pf_j).real # find the indices for the maximum values # and apply corresponding shifts @@ -623,17 +627,25 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): shift_b[shift_eq_idx] = dx # Compute the coefficients of the current equation - coefs = np.array( - [ - np.cos(shift_alpha), - np.sin(shift_alpha), - -np.cos(shift_beta), - -np.sin(shift_beta), - ] - ) - shift_eq[shift_eq_idx] = ( - [-1, -1, 0, 0] * coefs if is_pf_j_flipped else coefs - ) + if not is_pf_j_flipped: + shift_eq[shift_eq_idx] = np.array( + [ + np.sin(shift_alpha), + np.cos(shift_alpha), + -np.sin(shift_beta), + -np.cos(shift_beta), + ] + ) + else: + shift_beta = shift_beta - np.pi + shift_eq[shift_eq_idx] = np.array( + [ + -np.sin(shift_alpha), + -np.cos(shift_alpha), + -np.sin(shift_beta), + -np.cos(shift_beta), + ] + ) # create sparse matrix object only containing non-zero elements shift_equations = sparse.csr_matrix( @@ -644,7 +656,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): return shift_equations, shift_b - def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=4000): + def _estimate_num_shift_equations(self, n_img): """ Estimate total number of shift equations in images @@ -652,29 +664,24 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 number of images and preselected memory factor. :param n_img: The total number of input images - :param equations_factor: The factor to rescale the number of shift equations - (=1 in default) - :param max_memory: If there are N images and N_check selected to check - for common lines, then the exact system of equations solved for the shifts - is of size 2N x N(N_check-1)/2 (2N unknowns and N(N_check-1)/2 equations). - This may be too big if N is large. The algorithm will use `equations_factor` - times the total number of equations if the resulting total number of - memory requirements is less than `max_memory` (in megabytes); otherwise it - will reduce the number of equations to fit in `max_memory`. :return: Estimated number of shift equations """ # Number of equations that will be used to estimation the shifts n_equations_total = int(np.ceil(n_img * (self.n_check - 1) / 2)) + # Estimated memory requirements for the full system of equation. # This ignores the sparsity of the system, since backslash seems to # ignore it. - memory_total = equations_factor * ( + memory_total = self.offsets_equations_factor * ( n_equations_total * 2 * n_img * self.dtype.itemsize ) - if memory_total < (max_memory * 10**6): - n_equations = int(np.ceil(equations_factor * n_equations_total)) + + if memory_total < (self.offsets_max_memory * 10**6): + n_equations = int( + np.ceil(self.offsets_equations_factor * n_equations_total) + ) else: - subsampling_factor = (max_memory * 10**6) / memory_total + subsampling_factor = (self.offsets_max_memory * 10**6) / memory_total subsampling_factor = min(1.0, subsampling_factor) n_equations = int(np.ceil(n_equations_total * subsampling_factor)) @@ -691,46 +698,13 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 return n_equations - def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): - """ - Prepare the shift phases and generate filter for common-line detection - - The shift phases are pre-defined in a range of max_shift that can be - applied to maximize the common line calculation. The common-line filter - is also applied to the radial direction for easier detection. - - :param r_max: Maximum index for common line detection - :param max_shift: Maximum value of 1D shift (in pixels) to search - :param shift_step: Resolution of shift estimation in pixels - :return: shift phases matrix and common lines filter - """ - - # Number of shifts to try - n_shifts = int(np.ceil(2 * max_shift / shift_step + 1)) - - # only half of ray, excluding the DC component. - rk = np.arange(1, r_max + 1, dtype=self.dtype) - - # Generate all shift phases - shifts = -max_shift + shift_step * np.arange(n_shifts, dtype=self.dtype) - shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1))) - # Set filter for common-line detection - h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2)) - - return shifts, shift_phases, h - def _generate_index_pairs(self, n_equations): """ Generate two index lists for [i, j] pairs of images """ - idx_i = [] - idx_j = [] - for i in range(self.n_img - 1): - tmp_j = range(i + 1, self.n_img) - idx_i.extend([i] * len(tmp_j)) - idx_j.extend(tmp_j) - idx_i = np.array(idx_i, dtype="int") - idx_j = np.array(idx_j, dtype="int") + + # Generate the i,j tuples of indices representing the upper triangle above the diagonal. + idx_i, idx_j = np.triu_indices(self.n_img, k=1) # Select random pairs based on the size of n_equations rp = choice(np.arange(len(idx_j)), size=n_equations, replace=False) diff --git a/src/aspire/abinitio/commonline_c2.py b/src/aspire/abinitio/commonline_c2.py index 6bf6cf99ed..febd63951a 100644 --- a/src/aspire/abinitio/commonline_c2.py +++ b/src/aspire/abinitio/commonline_c2.py @@ -3,13 +3,20 @@ import numpy as np from scipy.linalg import eigh -from aspire.abinitio import CLSymmetryC3C4 +from aspire.abinitio import CLOrient3D, JSync +from aspire.abinitio.sync_voting import _syncmatrix_ij_vote_3n from aspire.utils import J_conjugate, Rotation, all_pairs +from .commonline_utils import ( + _complete_third_row_to_rot, + _estimate_third_rows, + _generate_shift_phase_and_filter, +) + logger = logging.getLogger(__name__) -class CLSymmetryC2(CLSymmetryC3C4): +class CLSymmetryC2(CLOrient3D): """ Define a class to estimate 3D orientations using common lines methods for molecules with C2 cyclic symmetry. @@ -40,10 +47,10 @@ def __init__( shift_step=1, epsilon=1e-3, max_iters=1000, - degree_res=1, min_dist_cls=25, seed=None, mask=True, + **kwargs, ): """ Initialize object for estimating 3D orientations for molecules with C2 symmetry. @@ -54,40 +61,30 @@ def __init__( :param max_shift: Maximum range for shifts as a proportion of resolution. Default = 0.15. :param shift_step: Resolution of shift estimation in pixels. Default = 1 pixel. :param epsilon: Tolerance for the power method. - :param max_iter: Maximum iterations for the power method. - :param degree_res: Degree resolution for estimating in-plane rotations. + :param max_iters: Maximum iterations for the power method. :param min_dist_cls: Minimum distance between mutual common-lines. Default = 25 degrees. :param seed: Optional seed for RNG. :param mask: Option to mask `src.images` with a fuzzy mask (boolean). Default, `True`, applies a mask. """ + super().__init__( src, - symmetry="C2", n_rad=n_rad, n_theta=n_theta, max_shift=max_shift, shift_step=shift_step, - epsilon=epsilon, - max_iters=max_iters, - degree_res=degree_res, - seed=seed, mask=mask, + **kwargs, ) self.min_dist_cls = min_dist_cls self.epsilon = epsilon self.max_iters = max_iters - self.degree_res = degree_res self.seed = seed self.order = 2 - def _check_symmetry(self, symmetry): - symmetry = symmetry.upper() - if symmetry != "C2": - raise NotImplementedError( - f"Only C2 symmetry supported. {symmetry} was supplied." - ) + self.J_sync = JSync(src.n, self.epsilon, self.max_iters, self.seed) def build_clmatrix(self): """ @@ -114,8 +111,8 @@ def build_clmatrix(self): # Prepare the shift phases and generate filter for common-line detection. r_max = pf.shape[2] - shifts, shift_phases, h = self._generate_shift_phase_and_filter( - r_max, self.max_shift, self.shift_step + shifts, shift_phases, h = _generate_shift_phase_and_filter( + r_max, self.max_shift, self.shift_step, self.dtype ) n_shifts = len(shifts) @@ -224,7 +221,7 @@ def estimate_rotations(self): viis = np.vstack((np.eye(3, dtype=self.dtype),) * self.n_img).reshape( self.n_img, 3, 3 ) - vis = self._estimate_third_rows(vijs, viis) + vis = _estimate_third_rows(vijs, viis) logger.info("Estimating in-plane rotations and rotations matrices.") Ris = self._estimate_inplane_rotations(vis, Rijs, Rijgs) @@ -247,7 +244,7 @@ def _estimate_relative_viewing_directions(self): self.build_clmatrix() # Step 2: Calculate relative rotations associated with both mutual common lines. - Rijs, Rijgs = self._estimate_all_Rijs_c2(self.clmatrix) + Rijs, Rijgs = self._estimate_all_Rijs_c2() # Step 3: Inner J-synchronization Rijs, Rijgs = self._local_J_sync_c2(Rijs, Rijgs) @@ -276,7 +273,7 @@ def _global_J_sync(self, Rijs, Rijgs): vijs = (Rijs + Rijgs) / 2 # Determine relative handedness of vijs. - sign_ij_J = self._J_sync_power_method(vijs) + sign_ij_J = self.J_sync.power_method(vijs) # Synchronize relative rotations for i, sign in enumerate(sign_ij_J): @@ -301,7 +298,7 @@ def _estimate_inplane_rotations(self, vis, Rijs, Rijgs): H = np.zeros((self.n_img, self.n_img), dtype=complex) # Step 1: Construct all rotation matrices Ris_tilde whose third rows are equal to # the corresponding third rows vis. - Ris_tilde = self._complete_third_row_to_rot(vis) + Ris_tilde = _complete_third_row_to_rot(vis) pairs = all_pairs(self.n_img) for idx, (i, j) in enumerate(pairs): @@ -344,23 +341,36 @@ def _estimate_inplane_rotations(self, vis, Rijs, Rijgs): # Secondary Methods for computing outer product # ################################################# - def _estimate_all_Rijs_c2(self, clmatrix): + def _estimate_all_Rijs_c2(self): """ Estimate the two sets of relative rotations, Rijs and Rijgs, between pairs of images using the voting method. - :param clmatrix: 2 x n_img x n_img array holding two sets of mutual common-lines - between pairs of images. :return: Relative rotations, Rijs and Rijgs. """ k_list = np.arange(self.n_img) - n_theta = self.n_theta pairs = all_pairs(self.n_img) Rijs = np.zeros((len(pairs), 3, 3), dtype=self.dtype) Rijgs = np.zeros((len(pairs), 3, 3), dtype=self.dtype) for idx, (i, j) in enumerate(pairs): - Rijs[idx] = self._syncmatrix_ij_vote_3n(clmatrix[0], i, j, k_list, n_theta) - Rijgs[idx] = self._syncmatrix_ij_vote_3n(clmatrix[1], i, j, k_list, n_theta) + Rijs[idx] = _syncmatrix_ij_vote_3n( + self.clmatrix[0], + i, + j, + k_list, + self.n_theta, + self.hist_bin_width, + self.full_width, + ) + Rijgs[idx] = _syncmatrix_ij_vote_3n( + self.clmatrix[1], + i, + j, + k_list, + self.n_theta, + self.hist_bin_width, + self.full_width, + ) return Rijs, Rijgs diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 0170d8b88f..34d0ee4763 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -1,26 +1,23 @@ import logging import numpy as np -from numpy.linalg import eigh, norm, svd +from numpy.linalg import norm, svd -from aspire.abinitio import CLOrient3D, SyncVotingMixin +from aspire.abinitio import CLOrient3D, JSync +from aspire.abinitio.sync_voting import _syncmatrix_ij_vote_3n from aspire.operators import PolarFT -from aspire.utils import ( - J_conjugate, - Rotation, - all_pairs, - all_triplets, - anorm, - cyclic_rotations, - tqdm, - trange, +from aspire.utils import J_conjugate, Rotation, all_pairs, anorm, trange + +from .commonline_utils import ( + _estimate_inplane_rotations, + _estimate_third_rows, + _generate_shift_phase_and_filter, ) -from aspire.utils.random import randn logger = logging.getLogger(__name__) -class CLSymmetryC3C4(CLOrient3D, SyncVotingMixin): +class CLSymmetryC3C4(CLOrient3D): """ Define a class to estimate 3D orientations using common lines methods for molecules with C3 and C4 cyclic symmetry. @@ -52,6 +49,7 @@ def __init__( degree_res=1, seed=None, mask=True, + **kwargs, ): """ Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry. @@ -77,6 +75,7 @@ def __init__( max_shift=max_shift, shift_step=shift_step, mask=mask, + **kwargs, ) self._check_symmetry(symmetry) @@ -85,6 +84,8 @@ def __init__( self.degree_res = degree_res self.seed = seed + self.J_sync = JSync(src.n, self.epsilon, self.max_iters, self.seed) + def _check_symmetry(self, symmetry): if symmetry is None: raise NotImplementedError( @@ -110,10 +111,17 @@ def estimate_rotations(self): vijs, viis = self._global_J_sync(vijs, viis) logger.info("Estimating third rows of rotation matrices.") - vis = self._estimate_third_rows(vijs, viis) + vis = _estimate_third_rows(vijs, viis) logger.info("Estimating in-plane rotations and rotations matrices.") - Ris = self._estimate_inplane_rotations(vis) + Ris = _estimate_inplane_rotations( + vis, + self.pf, + self.max_shift, + self.shift_step, + self.order, + self.degree_res, + ) self.rotations = Ris @@ -139,7 +147,7 @@ def _estimate_relative_viewing_directions(self): Riis = self._estimate_all_Riis_c3_c4(sclmatrix) # Step 4: Calculate relative rotations - Rijs = self._estimate_all_Rijs_c3_c4(self.clmatrix) + Rijs = self._estimate_all_Rijs_c3_c4() # Step 5: Inner J-synchronization vijs, viis = self._local_J_sync_c3_c4(Rijs, Riis) @@ -162,235 +170,14 @@ def _global_J_sync(self, vijs, viis): :return: vijs, viis all of which have a spurious J or not. """ - n_img = self.n_img # Determine relative handedness of vijs. - sign_ij_J = self._J_sync_power_method(vijs) - - # Synchronize vijs - for i, sign in enumerate(sign_ij_J): - if sign == -1: - vijs[i] = J_conjugate(vijs[i]) - - # Synchronize viis - # We use the fact that if v_ii and v_ij are of the same handedness, then v_ii @ v_ij = v_ij. - # If they are opposite handed then Jv_iiJ @ v_ij = v_ij. We compare each v_ii against all - # previously synchronized v_ij to get a consensus on the handedness of v_ii. - _, pairs_to_linear = all_pairs(n_img, return_map=True) - for i in range(n_img): - vii = viis[i] - vii_J = J_conjugate(vii) - J_consensus = 0 - for j in range(n_img): - if j < i: - idx = pairs_to_linear[j, i] - vji = vijs[idx] - - err1 = norm(vji @ vii - vji) - err2 = norm(vji @ vii_J - vji) - - elif j > i: - idx = pairs_to_linear[i, j] - vij = vijs[idx] - - err1 = norm(vii @ vij - vij) - err2 = norm(vii_J @ vij - vij) - - else: - continue - - # Accumulate J consensus - if err1 < err2: - J_consensus -= 1 - else: - J_consensus += 1 - - if J_consensus > 0: - viis[i] = vii_J - return vijs, viis - - def _estimate_third_rows(self, vijs, viis): - """ - Find the third row of each rotation matrix given a collection of matrices - representing the outer products of the third rows from each rotation matrix. - - :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds the third rows - outer product of the rotation matrices Ri and Rj. - - :param viis: An n_imgx3x3 array where the i'th 3x3 slice holds the outer product of - the third row of Ri with itself. - - :param order: The underlying molecular symmetry. - - :return: vis, An n_imgx3 matrix whose i'th row is the third row of the rotation matrix Ri. - """ - - n_img = self.n_img - - # Build matrix V whose (i,j)-th block of size 3x3 holds the outer product vij - V = np.zeros((n_img, n_img, 3, 3), dtype=vijs.dtype) - - # All pairs (i,j) where i epsilon: - itr += 1 - # Note, this appears to need double precision for accuracy in the following division. - vec_new = self._signs_times_v(vijs, vec).astype(np.float64, copy=False) - vec_new = vec_new / norm(vec_new) - residual = norm(vec_new - vec) - vec = vec_new - logger.info( - f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" - ) - - # We need only the signs of the eigenvector - J_sync = np.sign(vec) - - return J_sync - - def _signs_times_v(self, vijs, vec): - """ - Multiplication of the J-synchronization matrix by a candidate eigenvector. - - The J-synchronization matrix is a matrix representation of the handedness graph, Gamma, whose set of - nodes consists of the estimates vijs and whose set of edges consists of the undirected edges between - all triplets of estimates vij, vjk, and vik, where i= 1e-5 - - # If the third row coincides with the z-axis we return the identity matrix. - rots[~mask] = np.eye(3, dtype=r3.dtype) - - # 'norm_12' is non-zero since r3 does not coincide with the z-axis. - norm_12 = np.sqrt(r3[mask, 0] ** 2 + r3[mask, 1] ** 2) - - # Populate 1st rows with vector orthogonal to row 3. - rots[mask, 0, 0] = r3[mask, 1] / norm_12 - rots[mask, 0, 1] = -r3[mask, 0] / norm_12 - - # Populate 2nd rows such that r3 = r1 x r2 - rots[mask, 1, 0] = r3[mask, 0] * r3[mask, 2] / norm_12 - rots[mask, 1, 1] = r3[mask, 1] * r3[mask, 2] / norm_12 - rots[mask, 1, 2] = -norm_12 - - if singleton: - rots = rots.reshape(3, 3) - - return rots - - @staticmethod - def cl_angles_to_ind(cl_angles, n_theta): - thetas = np.arctan2(cl_angles[:, 1], cl_angles[:, 0]) - - # Shift from [-pi,pi] to [0,2*pi). - thetas = np.mod(thetas, 2 * np.pi) - - # linear scale from [0,2*pi) to [0,n_theta). - ind = np.mod(np.round(thetas / (2 * np.pi) * n_theta), n_theta).astype(int) - - # Return scalar for single value. - if ind.size == 1: - ind = ind.flat[0] - - return ind - - @staticmethod - def g_sync(rots, order, rots_gt): - """ - Every estimated rotation might be a version of the ground truth rotation - rotated by g^{s_i}, where s_i = 0, 1, ..., order. This method synchronizes the - ground truth rotations so that only a single global rotation need be applied - to all estimates for error analysis. - - :param rots: Estimated rotation matrices - :param order: The cyclic order asssociated with the symmetry of the underlying molecule. - :param rots_gt: Ground truth rotation matrices. - - :return: g-synchronized ground truth rotations. - """ - assert len(rots) == len( - rots_gt - ), "Number of estimates not equal to number of references." - n_img = len(rots) - dtype = rots.dtype - - rots_symm = cyclic_rotations(order, dtype).matrices - - A_g = np.zeros((n_img, n_img), dtype=complex) - - pairs = all_pairs(n_img) - - for i, j in pairs: - Ri = rots[i] - Rj = rots[j] - Rij = Ri.T @ Rj - - Ri_gt = rots_gt[i] - Rj_gt = rots_gt[j] - - diffs = np.zeros(order) - for s, g_s in enumerate(rots_symm): - Rij_gt = Ri_gt.T @ g_s @ Rj_gt - diffs[s] = min([norm(Rij - Rij_gt), norm(Rij - J_conjugate(Rij_gt))]) - - idx = np.argmin(diffs) - - A_g[i, j] = np.exp(-1j * 2 * np.pi / order * idx) - - # A_g(k,l) is exp(-j(-theta_k+theta_l)) - # Diagonal elements correspond to exp(-i*0) so put 1. - # This is important only for verification purposes that spectrum is (K,0,0,0...,0). - A_g += np.conj(A_g).T + np.eye(n_img) - - _, eig_vecs = eigh(A_g) - leading_eig_vec = eig_vecs[:, -1] - - angles = np.exp(1j * 2 * np.pi / order * np.arange(order)) - rots_gt_sync = np.zeros((n_img, 3, 3), dtype=dtype) - - for i, rot_gt in enumerate(rots_gt): - # Since the closest ccw or cw rotation are just as good, - # we take the absolute value of the angle differences. - angle_dists = np.abs(np.angle(leading_eig_vec[i] / angles)) - power_g_Ri = np.argmin(angle_dists) - rots_gt_sync[i] = rots_symm[power_g_Ri] @ rot_gt - - return rots_gt_sync diff --git a/src/aspire/abinitio/commonline_cn.py b/src/aspire/abinitio/commonline_cn.py index a1440cadc3..9adddb37ac 100644 --- a/src/aspire/abinitio/commonline_cn.py +++ b/src/aspire/abinitio/commonline_cn.py @@ -3,7 +3,7 @@ import numpy as np from numpy.linalg import norm -from aspire.abinitio import CLSymmetryC3C4 +from aspire.abinitio import CLOrient3D, JSync from aspire.operators import PolarFT from aspire.utils import ( J_conjugate, @@ -17,10 +17,18 @@ ) from aspire.utils.random import Random, randn +from .commonline_utils import ( + _cl_angles_to_ind, + _complete_third_row_to_rot, + _estimate_inplane_rotations, + _estimate_third_rows, + _generate_shift_phase_and_filter, +) + logger = logging.getLogger(__name__) -class CLSymmetryCn(CLSymmetryC3C4): +class CLSymmetryCn(CLOrient3D): """ Define a class to estimate 3D orientations using common lines methods for molecules with Cn cyclic symmetry, with n>4. @@ -41,6 +49,7 @@ def __init__( equator_threshold=10, seed=None, mask=True, + **kwargs, ): """ Initialize object for estimating 3D orientations for molecules with Cn symmetry, n>4. @@ -64,21 +73,24 @@ def __init__( super().__init__( src, - symmetry=symmetry, n_rad=n_rad, n_theta=n_theta, max_shift=max_shift, shift_step=shift_step, - epsilon=epsilon, - max_iters=max_iters, - degree_res=degree_res, - seed=seed, mask=mask, + **kwargs, ) + self._check_symmetry(symmetry) + self.epsilon = epsilon + self.max_iters = max_iters + self.degree_res = degree_res + self.seed = seed self.n_points_sphere = n_points_sphere self.equator_threshold = equator_threshold + self.J_sync = JSync(src.n, self.epsilon, self.max_iters, self.seed) + def _check_symmetry(self, symmetry): if symmetry is None: raise NotImplementedError( @@ -100,7 +112,27 @@ def estimate_rotations(self): :return: Array of rotation matrices, size n_imgx3x3. """ - super().estimate_rotations() + vijs, viis = self._estimate_relative_viewing_directions() + + logger.info("Performing global handedness synchronization.") + vijs, viis = self._global_J_sync(vijs, viis) + + logger.info("Estimating third rows of rotation matrices.") + vis = _estimate_third_rows(vijs, viis) + + logger.info("Estimating in-plane rotations and rotations matrices.") + Ris = _estimate_inplane_rotations( + vis, + self.pf, + self.max_shift, + self.shift_step, + self.order, + self.degree_res, + ) + + self.rotations = Ris + + return self.rotations def _estimate_relative_viewing_directions(self): logger.info(f"Estimating relative viewing directions for {self.n_img} images.") @@ -121,8 +153,8 @@ def _estimate_relative_viewing_directions(self): # Generate shift phases. r_max = pf.shape[-1] - shifts, shift_phases, _ = self._generate_shift_phase_and_filter( - r_max, self.max_shift, self.shift_step + shifts, shift_phases, _ = _generate_shift_phase_and_filter( + r_max, self.max_shift, self.shift_step, self.dtype ) n_shifts = len(shifts) @@ -284,6 +316,31 @@ def _compute_cls_inds(self, Ris_tilde, R_theta_ijs): cij_inds[i, j, :, 1] = c2s return cij_inds + def _global_J_sync(self, vijs, viis): + """ + Global J-synchronization of all third row outer products. Given 3x3 matrices vijs and viis, each + of which might contain a spurious J (ie. vij = J*vi*vj^T*J instead of vij = vi*vj^T), + we return vijs and viis that all have either a spurious J or not. + + :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds an estimate for the corresponding + outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate + might have a spurious J independently of other estimates. + + :param viis: An n_imgx3x3 array where the i'th slice holds an estimate for the outer product vi*vi^T + between the third row of matrix Ri and itself. Each estimate might have a spurious J independently + of other estimates. + + :return: vijs, viis all of which have a spurious J or not. + """ + + # Determine relative handedness of vijs. + vijs = self.J_sync.global_J_sync(vijs) + + # Determine relative handedness of viis, given synchronized vijs. + viis = self.J_sync.sync_viis(vijs, viis) + + return vijs, viis + @staticmethod def relative_rots_to_cl_indices(relative_rots, n_theta): """ @@ -298,8 +355,8 @@ def relative_rots_to_cl_indices(relative_rots, n_theta): c1s = np.array((-relative_rots[:, 1, 2], relative_rots[:, 0, 2])).T c2s = np.array((relative_rots[:, 2, 1], -relative_rots[:, 2, 0])).T - c1s = CLSymmetryC3C4.cl_angles_to_ind(c1s, n_theta) - c2s = CLSymmetryC3C4.cl_angles_to_ind(c2s, n_theta) + c1s = _cl_angles_to_ind(c1s, n_theta) + c2s = _cl_angles_to_ind(c2s, n_theta) inds = np.where(c1s >= n_theta // 2) c1s[inds] -= n_theta // 2 @@ -331,7 +388,7 @@ def generate_candidate_rots(n, equator_threshold, order, degree_res, seed): while counter < n: third_row = randn(3) third_row /= anorm(third_row, axes=(-1,)) - Ri_tilde = CLSymmetryC3C4._complete_third_row_to_rot(third_row) + Ri_tilde = _complete_third_row_to_rot(third_row) # Exclude candidates that represent equator images. Equator candidates # induce collinear self-common-lines, which always have perfect correlation. diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 8acbade2ca..e1730bf35e 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -10,6 +10,8 @@ from aspire.utils.random import randn from aspire.volume import DnSymmetryGroup +from .commonline_utils import _generate_shift_phase_and_filter + logger = logging.getLogger(__name__) @@ -37,6 +39,7 @@ def __init__( epsilon=0.01, seed=None, mask=True, + **kwargs, ): """ Initialize object for estimating 3D orientations for molecules with D2 symmetry. @@ -65,6 +68,7 @@ def __init__( max_shift=max_shift, shift_step=shift_step, mask=mask, + **kwargs, ) self.grid_res = grid_res @@ -129,8 +133,8 @@ def _compute_shifted_pf(self): # Generate shift phases. r_max = pf.shape[-1] max_shift_1d = np.ceil(2 * np.sqrt(2) * self.max_shift) - shifts, shift_phases, _ = self._generate_shift_phase_and_filter( - r_max, max_shift_1d, self.shift_step + shifts, shift_phases, _ = _generate_shift_phase_and_filter( + r_max, max_shift_1d, self.shift_step, self.dtype ) self.n_shifts = len(shifts) diff --git a/src/aspire/abinitio/commonline_sync.py b/src/aspire/abinitio/commonline_sync.py index 222173f318..6ed774958b 100644 --- a/src/aspire/abinitio/commonline_sync.py +++ b/src/aspire/abinitio/commonline_sync.py @@ -2,14 +2,15 @@ import numpy as np -from aspire.abinitio import CLOrient3D, SyncVotingMixin +from aspire.abinitio import CLOrient3D +from aspire.abinitio.sync_voting import _rotratio_eulerangle_vec, _vote_ij from aspire.utils import nearest_rotations from aspire.utils.matlab_compat import stable_eigsh logger = logging.getLogger(__name__) -class CLSyncVoting(CLOrient3D, SyncVotingMixin): +class CLSyncVoting(CLOrient3D): """ Define a class to estimate 3D orientations using synchronization matrix and voting method. @@ -33,6 +34,7 @@ def __init__( hist_bin_width=3, full_width=6, mask=True, + **kwargs, ): """ Initialize an object for estimating 3D orientations using synchronization matrix @@ -60,6 +62,7 @@ def __init__( hist_bin_width=hist_bin_width, full_width=full_width, mask=mask, + **kwargs, ) self.syncmatrix = None @@ -199,9 +202,11 @@ def _syncmatrix_ij_vote(self, clmatrix, i, j, k_list, n_theta): :return: The (i,j) rotation block of the synchronization matrix """ - _, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list) + _, good_k = _vote_ij( + clmatrix, n_theta, i, j, k_list, self.hist_bin_width, self.full_width + ) - rots = self._rotratio_eulerangle_vec(clmatrix, i, j, good_k, n_theta) + rots = _rotratio_eulerangle_vec(clmatrix, i, j, good_k, n_theta) if rots is not None: rot_mean = np.mean(rots, 0) diff --git a/src/aspire/abinitio/commonline_sync3n.py b/src/aspire/abinitio/commonline_sync3n.py index 841f9dfceb..7be4b37c2c 100644 --- a/src/aspire/abinitio/commonline_sync3n.py +++ b/src/aspire/abinitio/commonline_sync3n.py @@ -6,22 +6,15 @@ from numpy.linalg import norm from scipy.optimize import curve_fit -from aspire.abinitio import CLOrient3D, SyncVotingMixin -from aspire.utils import ( - J_conjugate, - Rotation, - all_pairs, - nearest_rotations, - random, - tqdm, - trange, -) +from aspire.abinitio import CLOrient3D +from aspire.abinitio.sync_voting import _syncmatrix_ij_vote_3n +from aspire.utils import J_conjugate, all_pairs, nearest_rotations, random, tqdm, trange from aspire.utils.matlab_compat import stable_eigsh logger = logging.getLogger(__name__) -class CLSync3N(CLOrient3D, SyncVotingMixin): +class CLSync3N(CLOrient3D): """ Define a class to estimate 3D orientations using common lines Sync3N methods (2017). @@ -69,6 +62,7 @@ def __init__( J_weighting=False, hist_intervals=100, disable_gpu=False, + **kwargs, ): """ Initialize object for estimating 3D orientations. @@ -108,6 +102,7 @@ def __init__( hist_bin_width=hist_bin_width, full_width=full_width, mask=mask, + **kwargs, ) # Generate pair mappings @@ -963,44 +958,18 @@ def _estimate_all_Rijs_host(self, clmatrix): Rijs = np.zeros((len(self._pairs), 3, 3)) for idx, (i, j) in enumerate(tqdm(self._pairs, desc="Estimate Rijs")): - Rijs[idx] = self._syncmatrix_ij_vote_3n( - clmatrix, i, j, np.arange(n_img), n_theta + Rijs[idx] = _syncmatrix_ij_vote_3n( + clmatrix, + i, + j, + np.arange(n_img), + n_theta, + self.hist_bin_width, + self.full_width, ) return Rijs - def _syncmatrix_ij_vote_3n(self, clmatrix, i, j, k_list, n_theta): - """ - Compute the (i,j) rotation block of the synchronization matrix using voting method - - Given the common lines matrix `clmatrix`, a list of images specified in k_list - and the number of common lines n_theta, find the (i, j) rotation block Rij. - - :param clmatrix: The common lines matrix - :param i: The i image - :param j: The j image - :param k_list: The list of images for the third image for voting algorithm - :param n_theta: The number of points in the theta direction (common lines) - :return: The (i,j) rotation block of the synchronization matrix - """ - alphas, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list, sync=True) - - angles = np.zeros(3) - - if alphas is not None: - angles[0] = clmatrix[i, j] * 2 * np.pi / n_theta + np.pi / 2 - angles[1] = np.mean(alphas) - angles[2] = -np.pi / 2 - clmatrix[j, i] * 2 * np.pi / n_theta - rot = Rotation.from_euler(angles).matrices - - else: - # This is for the case that images i and j correspond to the same - # viewing direction and differ only by in-plane rotation. - # We set to zero as in the Matlab code. - rot = np.zeros((3, 3)) - - return rot - ####################################### # Secondary Methods for Global J Sync # ####################################### diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py new file mode 100644 index 0000000000..45352bf026 --- /dev/null +++ b/src/aspire/abinitio/commonline_utils.py @@ -0,0 +1,392 @@ +import logging + +import numpy as np +from numpy.linalg import eigh, norm + +from aspire.operators import PolarFT +from aspire.utils import J_conjugate, Rotation, all_pairs, anorm, cyclic_rotations, tqdm + +logger = logging.getLogger(__name__) + + +def _estimate_third_rows(vijs, viis): + """ + Find the third row of each rotation matrix given a collection of matrices + representing the outer products of the third rows from each rotation matrix. + + :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds the third rows + outer product of the rotation matrices Ri and Rj. + + :param viis: An n_imgx3x3 array where the i'th 3x3 slice holds the outer product of + the third row of Ri with itself. + + :param order: The underlying molecular symmetry. + + :return: vis, An n_imgx3 matrix whose i'th row is the third row of the rotation matrix Ri. + """ + + n_img = viis.shape[0] + + # Build matrix V whose (i,j)-th block of size 3x3 holds the outer product vij + V = np.zeros((n_img, n_img, 3, 3), dtype=vijs.dtype) + + # All pairs (i,j) where i= 1e-5 + + # If the third row coincides with the z-axis we return the identity matrix. + rots[~mask] = np.eye(3, dtype=r3.dtype) + + # 'norm_12' is non-zero since r3 does not coincide with the z-axis. + norm_12 = np.sqrt(r3[mask, 0] ** 2 + r3[mask, 1] ** 2) + + # Populate 1st rows with vector orthogonal to row 3. + rots[mask, 0, 0] = r3[mask, 1] / norm_12 + rots[mask, 0, 1] = -r3[mask, 0] / norm_12 + + # Populate 2nd rows such that r3 = r1 x r2 + rots[mask, 1, 0] = r3[mask, 0] * r3[mask, 2] / norm_12 + rots[mask, 1, 1] = r3[mask, 1] * r3[mask, 2] / norm_12 + rots[mask, 1, 2] = -norm_12 + + if singleton: + rots = rots.reshape(3, 3) + + return rots + + +def _cl_angles_to_ind(cl_angles, n_theta): + """ + Map 2D direction vectors to discretized angular indices. + + For each 2D vector [x, y] in `cl_angles`, compute its polar angle + and find the nearest of `n_theta` polar ray indices. + + :param cl_angles: Array of shape (n, 2) of [x, y] values corresponding + to the commonline induced by a pair of rotations. + :param n_theta: Resolution of polar rays. + + :return: int or array of length n of commonline indice in the range [0, n_theta - 1]. + """ + thetas = np.arctan2(cl_angles[:, 1], cl_angles[:, 0]) + + # Shift from [-pi,pi] to [0,2*pi). + thetas = np.mod(thetas, 2 * np.pi) + + # linear scale from [0,2*pi) to [0,n_theta). + ind = np.mod(np.round(thetas / (2 * np.pi) * n_theta), n_theta).astype(int) + + # Return scalar for single value. + if ind.size == 1: + ind = ind.flat[0] + + return ind + + +def g_sync(rots, order, rots_gt): + """ + Given ground truth rotations, synchronize estimated rotations over + symmetry group elements. + + Every estimated rotation might be a version of the ground truth rotation + rotated by g^{s_i}, where s_i = 0, 1, ..., order. This method synchronizes the + ground truth rotations so that only a single global rotation need be applied + to all estimates for error analysis. + + :param rots: Estimated rotation matrices + :param order: The cyclic order asssociated with the symmetry of the underlying molecule. + :param rots_gt: Ground truth rotation matrices. + + :return: g-synchronized ground truth rotations. + """ + assert len(rots) == len( + rots_gt + ), "Number of estimates not equal to number of references." + n_img = len(rots) + dtype = rots.dtype + + rots_symm = cyclic_rotations(order, dtype).matrices + + A_g = np.zeros((n_img, n_img), dtype=complex) + + pairs = all_pairs(n_img) + + for i, j in pairs: + Ri = rots[i] + Rj = rots[j] + Rij = Ri.T @ Rj + + Ri_gt = rots_gt[i] + Rj_gt = rots_gt[j] + + diffs = np.zeros(order) + for s, g_s in enumerate(rots_symm): + Rij_gt = Ri_gt.T @ g_s @ Rj_gt + diffs[s] = min([norm(Rij - Rij_gt), norm(Rij - J_conjugate(Rij_gt))]) + + idx = np.argmin(diffs) + + A_g[i, j] = np.exp(-1j * 2 * np.pi / order * idx) + + # A_g(k,l) is exp(-j(-theta_k+theta_l)) + # Diagonal elements correspond to exp(-i*0) so put 1. + # This is important only for verification purposes that spectrum is (K,0,0,0...,0). + A_g += np.conj(A_g).T + np.eye(n_img) + + _, eig_vecs = eigh(A_g) + leading_eig_vec = eig_vecs[:, -1] + + angles = np.exp(1j * 2 * np.pi / order * np.arange(order)) + rots_gt_sync = np.zeros((n_img, 3, 3), dtype=dtype) + + for i, rot_gt in enumerate(rots_gt): + # Since the closest ccw or cw rotation are just as good, + # we take the absolute value of the angle differences. + angle_dists = np.abs(np.angle(leading_eig_vec[i] / angles)) + power_g_Ri = np.argmin(angle_dists) + rots_gt_sync[i] = rots_symm[power_g_Ri] @ rot_gt + + return rots_gt_sync + + +def build_outer_products(n, dtype): + """ + Builds sets of outer products of 3rd rows of rotation matrices. + This is a helper function used in commonline testing. + + :param n: Number of 3rd rows to construct outer product from. + :param dtype: dtype of outputs + + :return: tuple of (vijs, viis, gt_vis), where vijs are the pairwise + outer products of gt_vis and viis are self outer products of gt_vis. + """ + # Build random third rows, ground truth vis (unit vectors) + gt_vis = np.zeros((n, 3), dtype=dtype) + for i in range(n): + np.random.seed(i) + v = np.random.randn(3) + gt_vis[i] = v / norm(v) + + # Find outer products viis and vijs for i 1e-12): - logger.warning( - f"Globally Consistent Angular Reconstruction (GCAR) exists" - f" numerical problem: abs(cos_phi2) > 1, with the" - f" difference of {np.abs(cos_phi2)-1}." - ) - cos_phi2 = np.clip(cos_phi2, -1, 1) - - # Store angles between i and j induced by each third image k. - phis = cos_phi2 - # Sore good indices of l in k_list of the image that creates that angle. - inds = k_list[good_idx] - - if phis.shape[0] == 0: - return None, [] - - # Parameters used to compute the smoothed angle histogram. - ntics = int(180 / self.hist_bin_width) - angles_grid = np.linspace(0, 180, ntics + 1, True) - - # Get angles between images i and j for computing the histogram - angles = np.arccos(phis[:]) * 180 / np.pi - - # Angles that are up to 10 degrees apart are considered - # similar. This sigma ensures that the width of the density - # estimation kernel is roughly 10 degrees. For 15 degrees, the - # value of the kernel is negligible. - sigma = getattr(self, "sigma", 3.0) # get from class if avail - - # Compute the histogram of the angles between images i and j - angles_distances = angles_grid[None, :] - angles[:, None] - angles_hist = np.sum(np.exp(-(angles_distances**2) / (2 * sigma**2)), axis=0) - - # We assume that at the location of the peak we get the true angle - # between images i and j. Find all third images k, that induce an - # angle between i and j that is at most 10 off the true angle. - # Even for debugging, don't put a value that is smaller than two - # tics, since the peak might move a little bit due to wrong k images - # that accidentally fall near the peak. - peak_idx = angles_hist.argmax() - - if self.full_width == -1: - # Adaptive width (MATLAB) - # Look for the estimations in the peak of the histogram - w_theta_needed = 0 - idx = [] - while sum(idx) == 0: - w_theta_needed += self.hist_bin_width # widen peak as needed - idx = np.abs(angles - angles_grid[peak_idx]) < w_theta_needed - if w_theta_needed > self.hist_bin_width: - logger.info( - f"Adaptive width {w_theta_needed} required for ({i},{j}), found {sum(idx)} indices." - ) - else: - # Fixed width - idx = np.abs(angles - angles_grid[peak_idx]) < self.full_width - - good_k = inds[idx] - alpha = np.arccos(phis[idx]) - - return alpha, good_k.astype("int") - - def _get_cos_phis(self, cl_diff1, cl_diff2, cl_diff3, n_theta, sync=False): - """ - Calculate cos values of rotation angles between i and j images - - Given C1, C2, and C3 are unit circles of image i, j, and k, compute - resulting cos values of rotation angles between i an j images when both - of them are intersecting with k. - - To ensure that the smallest singular value is big enough, controlled by - the determinant of the matrix, - C=[ 1 c1 c2 ; - c1 1 c3 ; - c2 c3 1 ], - we therefore use the condition below - 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 1.0e-5, - so the matrix is far from singular. - - :param cl_diff1: Difference of common line indices on C1 created by - its intersection with C3 and C2 - :param cl_diff2: Difference of common line indices on C2 created by - its intersection with C1 and C3 - :param cl_diff3: Difference of common line indices on C3 created by - its intersection with C2 and C1 - :param n_theta: The number of points in the theta direction (common lines) - :param sync: Perform 180 degree ambiguity synchronization. - :return: cos values of rotation angles between i and j images - and indices for good k - """ - - # Calculate the theta values from the differences of common line indices - # C1, C2, and C3 are unit circles of image i, j, and k - # theta1 is the angle on C1 created by its intersection with C3 and C2. - # theta2 is the angle on C2 created by its intersection with C1 and C3. - # theta3 is the angle on C3 created by its intersection with C2 and C1. - theta1 = cl_diff1 * 2 * np.pi / n_theta - theta2 = cl_diff2 * 2 * np.pi / n_theta - theta3 = cl_diff3 * 2 * np.pi / n_theta - - c1 = np.cos(theta1) - c2 = np.cos(theta2) - c3 = np.cos(theta3) - - # Each common-line corresponds to a point on the unit sphere. Denote the - # coordinates of these points by (Pix, Piy Piz), and put them in the matrix - # M=[ P1x P2x P3x ; - # P1y P2y P3y ; - # P1z P2z P3z ]. - # - # Then the matrix - # C=[ 1 c1 c2 ; - # c1 1 c3 ; - # c2 c3 1 ], - # where c1, c2, c3 are given above, is given by C = M.T @ M. - # For the points P1, P2, and P3 to form a triangle on the unit sphere, a - # necessary and sufficient condition is for C to be positive definite. This - # is equivalent to - # 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 0. - # However, this may result in a triangle that is too flat, that is, the - # angle between the projections is very close to zero. We therefore use the - # condition below - # 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 1.0e-5. - # This ensures that the smallest singular value (which is actually - # controlled by the determinant of C) is big enough, so the matrix is far - # from singular. This condition is equivalent to computing the singular - # values of C, followed by checking that the smallest one is big enough. - - cond = 1 + 2 * c1 * c2 * c3 - (np.square(c1) + np.square(c2) + np.square(c3)) - good_idx = np.nonzero(cond > 1e-5)[0] - - # Calculated cos values of angle between i and j images - if sync: - # MATLAB - cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( - np.sqrt(1 - c1[good_idx] ** 2) * np.sqrt(1 - c2[good_idx] ** 2) - ) + :return: The rotation matrix that takes image i to image j for good index of k. + """ - # Some synchronization must be applied when common line is - # out by 180 degrees. - # Here fix the angles between c_ij(c_ji) and c_ik(c_jk) to be smaller than pi/2, - # otherwise there will be an ambiguity between alpha and pi-alpha. - TOL_idx = 1e-12 + if i == j: + return [] - # Select only good_idx - theta1 = theta1[good_idx] - theta2 = theta2[good_idx] - theta3 = theta3[good_idx] + # Prepare the theta values from the differences of common line indices + # C1, C2, and C3 are unit circles of image i, j, and k + # cl_diff1 is for the angle on C1 created by its intersection with C3 and C2. + # cl_diff2 is for the angle on C2 created by its intersection with C1 and C3. + # cl_diff3 is for the angle on C3 created by its intersection with C2 and C1. + cl_diff1 = clmatrix[i, good_k] - clmatrix[i, j] # for theta1 + cl_diff2 = clmatrix[j, good_k] - clmatrix[j, i] # for theta2 + cl_diff3 = clmatrix[good_k, j] - clmatrix[good_k, i] # for theta3 - # Check sync conditions - ind1 = (theta1 > (np.pi + TOL_idx)) | ( - (theta1 < -TOL_idx) & (theta1 > -np.pi) - ) - ind2 = (theta2 > (np.pi + TOL_idx)) | ( - (theta2 < -TOL_idx) & (theta2 > -np.pi) - ) - align180 = (ind1 & ~ind2) | (~ind1 & ind2) - - # Apply sync - cos_phi2[align180] = -cos_phi2[align180] - else: - # Python - cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( - np.sin(theta1[good_idx]) * np.sin(theta2[good_idx]) + # Calculate the cos values of rotation angles between i an j images for good k images + c_alpha, good_idx = _get_cos_phis(cl_diff1, cl_diff2, cl_diff3, n_theta, sync=False) + + if len(c_alpha) == 0: + return None + alpha = np.arccos(c_alpha) + + # Convert the Euler angles with ZYZ conversion to rotation matrices + angles = np.zeros((alpha.shape[0], 3)) + angles[:, 0] = clmatrix[i, j] * 2 * np.pi / n_theta + np.pi / 2 + angles[:, 1] = alpha + angles[:, 2] = -np.pi / 2 - clmatrix[j, i] * 2 * np.pi / n_theta + r = Rotation.from_euler(angles).matrices + + return r[good_idx, :, :] + + +def _vote_ij( + clmatrix, n_theta, i, j, k_list, hist_bin_width, full_width, sigma=3.0, sync=False +): + """ + Apply the voting algorithm for images i and j. + + clmatrix is the common lines matrix, constructed using angular resolution, + n_theta. k_list are the images to be used for voting of the pair of images + (i ,j). + + :param clmatrix: The common lines matrix + :param n_theta: The number of points in the theta direction (common lines) + :param i: The i image + :param j: The j image + :param k_list: The list of images for the third image for voting algorithm + :param hist_bin_width: Bin width in smoothing histogram (degrees). + :param full_width: Selection width around smoothed histogram peak (degrees). + `adaptive` will attempt to automatically find the smallest number of + `hist_bin_width`s required to find at least one valid image index. + :param sigma: Voting contribution smoothing factor. Default is 3.0. + :param sync: Perform 180 degree ambiguity synchronization. + + :return: (alpha, good_k), angles and list of all third images + in the peak of the histogram corresponding to the pair of + images (i,j) + """ + + if i == j or clmatrix[i, j] == -1: + return None, [] + + # Some of the entries in clmatrix may be zero if we cleared + # them due to small correlation, or if for each image + # we compute intersections with only some of the other images. + # + # Note that as long as the diagonal of the common lines matrix is + # -1, the conditions (i != j) && (j != k) are not needed, since + # if i == j then clmatrix[i, k] == -1 and similarly for i == k or + # j == k. Thus, the previous voting code (from the JSB paper) is + # correct even though it seems that we should test also that + # (i != j) && (i != k) && (j != k), and only (i != j) && (i != k) + # as tested there. + cl_idx12 = clmatrix[i, j] + cl_idx21 = clmatrix[j, i] + k_list = k_list[ + (k_list != i) & (clmatrix[i, k_list] != -1) & (clmatrix[j, k_list] != -1) + ] + cl_idx13 = clmatrix[i, k_list] + cl_idx31 = clmatrix[k_list, i] + cl_idx23 = clmatrix[j, k_list] + cl_idx32 = clmatrix[k_list, j] + + # Prepare the theta values from the differences of common line indices + # C1, C2, and C3 are unit circles of image i, j, and k + # cl_diff1 is for the angle on C1 created by its intersection with C3 and C2. + # cl_diff2 is for the angle on C2 created by its intersection with C1 and C3. + # cl_diff3 is for the angle on C3 created by its intersection with C2 and C1. + cl_diff1 = cl_idx13 - cl_idx12 + cl_diff2 = cl_idx23 - cl_idx21 + cl_diff3 = cl_idx32 - cl_idx31 + + # Calculate the cos values of rotation angles between i an j images for good k images + cos_phi2, good_idx = _get_cos_phis(cl_diff1, cl_diff2, cl_diff3, n_theta, sync=sync) + + if np.any(np.abs(cos_phi2) - 1 > 1e-12): + logger.warning( + f"Globally Consistent Angular Reconstruction (GCAR) exists" + f" numerical problem: abs(cos_phi2) > 1, with the" + f" difference of {np.abs(cos_phi2)-1}." + ) + cos_phi2 = np.clip(cos_phi2, -1, 1) + + # Store angles between i and j induced by each third image k. + phis = cos_phi2 + # Sore good indices of l in k_list of the image that creates that angle. + inds = k_list[good_idx] + + if phis.shape[0] == 0: + return None, [] + + # Parameters used to compute the smoothed angle histogram. + ntics = int(180 / hist_bin_width) + angles_grid = np.linspace(0, 180, ntics + 1, True) + + # Get angles between images i and j for computing the histogram + angles = np.arccos(phis[:]) * 180 / np.pi + + # Angles that are up to 10 degrees apart are considered + # similar. `sigma` ensures that the width of the density + # estimation kernel is roughly 10 degrees. For 15 degrees, the + # value of the kernel is negligible. + + # Compute the histogram of the angles between images i and j + angles_distances = angles_grid[None, :] - angles[:, None] + angles_hist = np.sum(np.exp(-(angles_distances**2) / (2 * sigma**2)), axis=0) + + # We assume that at the location of the peak we get the true angle + # between images i and j. Find all third images k, that induce an + # angle between i and j that is at most 10 off the true angle. + # Even for debugging, don't put a value that is smaller than two + # tics, since the peak might move a little bit due to wrong k images + # that accidentally fall near the peak. + peak_idx = angles_hist.argmax() + + if full_width == -1: + # Adaptive width (MATLAB) + # Look for the estimations in the peak of the histogram + w_theta_needed = 0 + idx = [] + while sum(idx) == 0: + w_theta_needed += hist_bin_width # widen peak as needed + idx = np.abs(angles - angles_grid[peak_idx]) < w_theta_needed + if w_theta_needed > hist_bin_width: + logger.info( + f"Adaptive width {w_theta_needed} required for ({i},{j}), found {sum(idx)} indices." ) + else: + # Fixed width + idx = np.abs(angles - angles_grid[peak_idx]) < full_width + + good_k = inds[idx] + alpha = np.arccos(phis[idx]) + + return alpha, good_k.astype("int") + + +def _get_cos_phis(cl_diff1, cl_diff2, cl_diff3, n_theta, sync=False): + """ + Calculate cos values of rotation angles between i and j images + + Given C1, C2, and C3 are unit circles of image i, j, and k, compute + resulting cos values of rotation angles between i an j images when both + of them are intersecting with k. + + To ensure that the smallest singular value is big enough, controlled by + the determinant of the matrix, + C=[ 1 c1 c2 ; + c1 1 c3 ; + c2 c3 1 ], + we therefore use the condition below + 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 1.0e-5, + so the matrix is far from singular. + + :param cl_diff1: Difference of common line indices on C1 created by + its intersection with C3 and C2 + :param cl_diff2: Difference of common line indices on C2 created by + its intersection with C1 and C3 + :param cl_diff3: Difference of common line indices on C3 created by + its intersection with C2 and C1 + :param n_theta: The number of points in the theta direction (common lines) + :param sync: Perform 180 degree ambiguity synchronization. + + :return: cos values of rotation angles between i and j images + and indices for good k + """ + + # Calculate the theta values from the differences of common line indices + # C1, C2, and C3 are unit circles of image i, j, and k + # theta1 is the angle on C1 created by its intersection with C3 and C2. + # theta2 is the angle on C2 created by its intersection with C1 and C3. + # theta3 is the angle on C3 created by its intersection with C2 and C1. + theta1 = cl_diff1 * 2 * np.pi / n_theta + theta2 = cl_diff2 * 2 * np.pi / n_theta + theta3 = cl_diff3 * 2 * np.pi / n_theta + + c1 = np.cos(theta1) + c2 = np.cos(theta2) + c3 = np.cos(theta3) + + # Each common-line corresponds to a point on the unit sphere. Denote the + # coordinates of these points by (Pix, Piy Piz), and put them in the matrix + # M=[ P1x P2x P3x ; + # P1y P2y P3y ; + # P1z P2z P3z ]. + # + # Then the matrix + # C=[ 1 c1 c2 ; + # c1 1 c3 ; + # c2 c3 1 ], + # where c1, c2, c3 are given above, is given by C = M.T @ M. + # For the points P1, P2, and P3 to form a triangle on the unit sphere, a + # necessary and sufficient condition is for C to be positive definite. This + # is equivalent to + # 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 0. + # However, this may result in a triangle that is too flat, that is, the + # angle between the projections is very close to zero. We therefore use the + # condition below + # 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 1.0e-5. + # This ensures that the smallest singular value (which is actually + # controlled by the determinant of C) is big enough, so the matrix is far + # from singular. This condition is equivalent to computing the singular + # values of C, followed by checking that the smallest one is big enough. + + cond = 1 + 2 * c1 * c2 * c3 - (np.square(c1) + np.square(c2) + np.square(c3)) + good_idx = np.nonzero(cond > 1e-5)[0] + + # Calculated cos values of angle between i and j images + if sync: + # MATLAB + cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( + np.sqrt(1 - c1[good_idx] ** 2) * np.sqrt(1 - c2[good_idx] ** 2) + ) + + # Some synchronization must be applied when common line is + # out by 180 degrees. + # Here fix the angles between c_ij(c_ji) and c_ik(c_jk) to be smaller than pi/2, + # otherwise there will be an ambiguity between alpha and pi-alpha. + TOL_idx = 1e-12 + + # Select only good_idx + theta1 = theta1[good_idx] + theta2 = theta2[good_idx] + theta3 = theta3[good_idx] + + # Check sync conditions + ind1 = (theta1 > (np.pi + TOL_idx)) | ((theta1 < -TOL_idx) & (theta1 > -np.pi)) + ind2 = (theta2 > (np.pi + TOL_idx)) | ((theta2 < -TOL_idx) & (theta2 > -np.pi)) + align180 = (ind1 & ~ind2) | (~ind1 & ind2) + + # Apply sync + cos_phi2[align180] = -cos_phi2[align180] + else: + # Python + cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( + np.sin(theta1[good_idx]) * np.sin(theta2[good_idx]) + ) - return cos_phi2, good_idx + return cos_phi2, good_idx diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index d982e746e5..3bbaff05b2 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -84,7 +84,7 @@ def average( :param classes: class indices, refering to src. (src.n, n_nbor). :param reflections: Bool representing whether to reflect image in `classes`. - (n_clases, n_nbor) + (n_classes, n_nbor) :param coefs: Optional basis coefs (could avoid recomputing). (src.n, coef_count) :return: Stack of synthetic class average images as Image instance. @@ -152,6 +152,16 @@ def __init__( f"{self.__class__.__name__}'s composite_basis {self.composite_basis} must provide a `shift` method." ) + # Instantiate dicts to hold alignment results. + # Note dicts are used in place of arrays because: + # The entire set of `src.n` classes may not always need to be computed, + # and the order/batching where results are computed is potentially arbitrary. + # We may not know apriori how many nbors are in each class, + # and this may be variable with future methods. + self.rotations = dict() + self.shifts = dict() + self.dot_products = dict() + @abstractmethod def align(self, classes, reflections, basis_coefficients=None): """ @@ -192,9 +202,15 @@ def average( classes = np.atleast_2d(classes) reflections = np.atleast_2d(reflections) - self.rotations, self.shifts, self.dot_products = self.align( - classes, reflections, coefs - ) + rotations, shifts, dot_products = self.align(classes, reflections, coefs) + + # Assign batch results + src_indices = classes[:, 0] # First column of class table + for i, k in enumerate(src_indices): + self.rotations[k] = rotations[i] + if shifts is not None: + self.shifts[k] = shifts[i] + self.dot_products[k] = dot_products[i] n_classes, n_nbor = classes.shape @@ -212,29 +228,29 @@ def _innerloop(i): neighbors_imgs = Image(self._cls_images(classes[i])) # Do shifts - if self.shifts is not None: - neighbors_imgs = neighbors_imgs.shift(self.shifts[i]) + if shifts is not None: + neighbors_imgs = neighbors_imgs.shift(shifts[i]) neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) else: # Get the neighbors neighbors_ids = classes[i] neighbors_coefs = coefs[neighbors_ids] - if self.shifts is not None: + if shifts is not None: neighbors_coefs = self.composite_basis.shift( - neighbors_coefs, self.shifts[i] + neighbors_coefs, shifts[i] ) # Rotate in composite_basis neighbors_coefs = self.composite_basis.rotate( - neighbors_coefs, self.rotations[i], reflections[i] + neighbors_coefs, rotations[i], reflections[i] ) # Averaging in composite_basis return self.image_stacker(neighbors_coefs.asnumpy()) - desc = f"Stacking and evaluating class averages from {self.composite_basis.__class__.__name__} to Cartesian" - for start in trange(0, n_classes, self.batch_size, desc=desc): + desc = f"Stacking and evaluating batch of class averages from {self.composite_basis.__class__.__name__} to Cartesian" + for start in trange(0, n_classes, self.batch_size, desc=desc, leave=False): end = min(start + self.batch_size, n_classes) for i, cls in enumerate( trange(start, end, desc="Stacking batch", leave=False) @@ -362,7 +378,7 @@ def align(self, classes, reflections, basis_coefficients=None): # This is done primarily in case of a tie later, we would take unshifted. test_shifts = self._shift_search_grid(self.src.L, self.radius, roll_zero=True) - for k in trange(n_classes, desc="Rotationally aligning classes"): + for k in trange(n_classes, desc="Rotationally aligning classes", leave=False): # We want to locally cache the original images, # because we will mutate them with shifts in the next loop. # This avoids recomputing them before each shift @@ -564,7 +580,7 @@ def _innerloop(k): dtype=self.dtype, ) - for k in trange(n_classes, desc="Rotationally aligning classes"): + for k in trange(n_classes, desc="Rotationally aligning classes", leave=False): rotations[k], shifts[k], dot_products[k] = _innerloop(k) return rotations, shifts, dot_products @@ -580,9 +596,15 @@ def average( Otherwise is similar to `AligningAverager2D.average`. """ - self.rotations, self.shifts, self.dot_products = self.align( - classes, reflections, coefs - ) + rotations, shifts, dot_products = self.align(classes, reflections, coefs) + + # Assign batch results + src_indices = classes[:, 0] # First column of class table + for i, k in enumerate(src_indices): + self.rotations[k] = rotations[i] + if shifts is not None: + self.shifts[k] = shifts[i] + self.dot_products[k] = dot_products[i] n_classes, n_nbor = classes.shape @@ -601,19 +623,17 @@ def _innerloop(i): # Rotate in composite_basis neighbors_coefs = self.composite_basis.rotate( - neighbors_coefs, self.rotations[i], reflections[i] + neighbors_coefs, rotations[i], reflections[i] ) # Note shifts are after rotation for this approach! - if self.shifts is not None: - neighbors_coefs = self.composite_basis.shift( - neighbors_coefs, self.shifts[i] - ) + if shifts is not None: + neighbors_coefs = self.composite_basis.shift(neighbors_coefs, shifts[i]) # Averaging in composite_basis return self.image_stacker(neighbors_coefs.asnumpy()) - for i in trange(n_classes, desc="Stacking class averages"): + for i in trange(n_classes, desc="Stacking class averages", leave=False): b_avgs[i] = _innerloop(i) # Now we convert the averaged images from Basis to Cartesian. @@ -732,7 +752,7 @@ def _innerloop(k): return _rotations, _shifts, _dot_products - for k in trange(n_classes, desc="Rotationally aligning classes"): + for k in trange(n_classes, desc="Rotationally aligning classes", leave=False): rotations[k], shifts[k], dot_products[k] = _innerloop(k) return rotations, shifts, dot_products @@ -880,7 +900,7 @@ def align(self, classes, reflections, basis_coefficients=None): ) _images = xp.empty((n_nbor - 1, self.src.L, self.src.L), dtype=self.dtype) - for k in trange(n_classes, desc="Rotationally aligning classes"): + for k in trange(n_classes, desc="Rotationally aligning classes", leave=False): # We want to locally cache the original images, # because we will mutate them with shifts in the next loop. # This avoids recomputing them before each shift diff --git a/src/aspire/config_default.yaml b/src/aspire/config_default.yaml index 4f02923054..ee97c4b8ca 100644 --- a/src/aspire/config_default.yaml +++ b/src/aspire/config_default.yaml @@ -1,4 +1,4 @@ -version: 0.14.1 +version: 0.14.2 common: # numeric module to use - one of numpy/cupy numeric: numpy diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index 564549342d..646cdc9e2a 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -223,6 +223,7 @@ def _class_select(self): self._classify() # Perform class selection + logger.info("Performing class selection") _selection_indices = self.class_selector.select( self.class_indices, self.class_refl, diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 431dedf87a..442526fa54 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,3 +1,11 @@ +# isort: off +from .rotation import ( + compute_fastrotate_interp_tables, + fastrotate, + sp_rotate, +) + +# isort: on from .image import ( BasisImage, BispecImage, diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 8212033f72..071437a1ef 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -10,6 +10,7 @@ import aspire.sinogram import aspire.volume +from aspire.image import fastrotate, sp_rotate from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import ( @@ -158,6 +159,8 @@ class Image: ".tif": load_tiff, ".tiff": load_tiff, } + # Available image rotation functions + rotation_methods = {"fastrotate": fastrotate, "scipy": sp_rotate} def __init__(self, data, pixel_size=None, dtype=None): """ @@ -635,8 +638,81 @@ def filter(self, filter): original_stack_shape ) - def rotate(self): - raise NotImplementedError + def rotate(self, theta, method="scipy", mask=1, **kwargs): + """ + Return `Image` rotated by `theta` radians using `method`. + + Optionally applies `mask`. Note that some methods may + introduce edge artifacts, in which case users may consider + using a tighter mask (eg 0.9) or a combination of pad-crop. + + Any additional kwargs will be passed to `method`. + + :param theta: Scalar or array of length `n_images` + :param mask: Optional scalar or array mask matching `Image` shape. + Scalar will create a circular mask of prescribed radius `(0,1]`. + Array mask will be applied via elementwise multiplication. + `None` disables masking. + :param method: Optionally specify a rotation method. + Defaults to `scipy`. + :return: `Image` containing rotated image data. + """ + + original_stack_shape = self.stack_shape + im = self.stack_reshape(-1) + + # Resolve rotation method + if method not in self.rotation_methods: + raise NotImplementedError( + f"Requested `Image.rotation` method={method} not found." + f" Select from {self.rotation_methods.keys()}" + ) + # Assign the rotation method's function + # Any rotation method is expected to handle image data as a 2D array or 3D array (single stack axis). + rotation_function = self.rotation_methods[method] + + # Handle both scalar and arrays of rotation angles. + # `theta` arrays are checked to match length of images when stacks axis are flattened. + theta = np.array(theta).flatten() + if len(theta) == 1: + im = rotation_function(im._data, theta, **kwargs) + elif len(theta) == im.n_images: + rot_im = np.empty_like(im._data) + for i in range(im.n_images): + rot_im[i] = rotation_function(im._data[i], theta[i], **kwargs) + im = rot_im + else: + raise RuntimeError( + f"Length of `theta` {len(theta)} and `Image` data {im.n_images} inconsistent." + ) + + # Masking, scalar case + if mask is not None: + if np.size(mask) == 1: + # Confirm `mask` value is a sane radius + if not (0 < mask <= 1): + raise ValueError( + f"Mask radius must be scalar between (0,1]. Received {mask}" + ) + # Construct a boolean `mask` to apply in next code block as a 2D `mask` + mask = ( + grid_2d(im.shape[-1], normalized=True, dtype=np.float64)["r"] < mask + ) + mask = mask.astype(im.dtype) + + # Masking, 2D case + # Confirm `mask` size is consistent + if mask.shape == im.shape[-2:]: + im = im * mask[None, :, :] + else: + raise RuntimeError( + f"Shape of `mask` {mask.shape} inconsistent with `Image` data shape {im.shape[-2:]}" + ) + + # Restore original stack shape and metadata. + return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( + original_stack_shape + ) def save(self, mrcs_filepath, overwrite=None): """ diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py new file mode 100644 index 0000000000..706248110f --- /dev/null +++ b/src/aspire/image/rotation.py @@ -0,0 +1,251 @@ +import numpy as np +from scipy import ndimage + +from aspire.numeric import fft, xp +from aspire.utils import complex_type + + +def _pre_rotate(theta): + """ + Given `theta` radians return nearest rotation of pi/2 + required to place angle within [-pi/4,pi/4) and the residual + rotation in radians. + + :param theta: Rotation in radians + :returns: + - Residual angle in radians + - Number of pi/2 rotations + """ + + theta = np.mod(theta, 2 * np.pi) + + # 0 < pi/4 + rots = 0 + residual = theta + + if theta >= np.pi / 4 and theta < 3 * np.pi / 4: + rots = 1 + residual = theta - np.pi / 2 + elif theta >= 3 * np.pi / 4 and theta < 5 * np.pi / 4: + rots = 2 + residual = theta - np.pi + elif theta >= 5 * np.pi / 4 and theta < 7 * np.pi / 4: + rots = 3 + residual = theta - 3 * np.pi / 2 + elif theta >= 7 * np.pi / 4 and theta < 2 * np.pi: + rots = 0 + residual = theta - 2 * np.pi + + return residual, rots + + +def _shift_center(n): + """ + Given `n` pixels return center pixel and shift amount, 0 or 1/2. + + :param n: Number of pixels + :returns: + - center pixel + - shift amount + """ + if n % 2 == 0: + c = n // 2 # center + s = 1 / 2 # shift + else: + c = n // 2 + s = 0 + + return c, s + + +def compute_fastrotate_interp_tables(theta, nx, ny): + """ + Retuns iterpolation tables as tuple M = (Mx, My, rots). + + :param theta: angle in radians + :param nx: Number pixels first axis + :param ny: Number pixels second axis + """ + theta, mult90 = _pre_rotate(theta) + + # Reverse rotation, Yaroslavsky rotated CW + theta = -theta + + cy, sy = _shift_center(ny) + cx, sx = _shift_center(nx) + + # Floating point epsilon + eps = np.finfo(np.float64).eps + + # Precompute Y interpolation tables + My = np.zeros((nx, ny), dtype=np.complex128) + r = np.arange(cy + 1, dtype=int) + u = (1 - np.cos(theta)) / np.sin(theta + eps) + alpha1 = 2 * np.pi * 1j * r / ny + + linds = np.arange(ny - 1, cy, -1, dtype=int) + rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) + + Ux = u * (np.arange(nx) - cx + sx + 2) + My[:, r] = np.exp(alpha1[None, :] * Ux[:, None]) + My[:, linds] = My[:, rinds].conj() + + # Precompute X interpolation tables + Mx = np.zeros((ny, nx), dtype=np.complex128) + r = np.arange(cx + 1, dtype=int) + u = -np.sin(theta) + alpha2 = 2 * np.pi * 1j * r / nx + + linds = np.arange(nx - 1, cx, -1, dtype=int) + rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) + + Uy = u * (np.arange(ny) - cy + sy + 2) + Mx[:, r] = np.exp(alpha2[None, :] * Uy[:, None]) + Mx[:, linds] = Mx[:, rinds].conj() + + # After building, transpose to (nx, ny). + Mx = Mx.T + + return Mx, My, mult90 + + +# The following helper utilities are written to work with +# `img` data of dimension 2 or more where the data is expected to be +# in the (-2,-1) dimensions with any other dims as stack axes. +def _rot90(img): + """Rotate image array by 90 degrees.""" + # stack broadcast of flipud(img.T) + return xp.flip(xp.swapaxes(img, -1, -2), axis=-2) + + +def _rot180(img): + """Rotate image array by 180 degrees.""" + # stack broadcast of flipud(fliplr) + return xp.flip(img, axis=(-1, -2)) + + +def _rot270(img): + """Rotate image array by 270 degrees.""" + # stack broadcast of fliplr(img.T) + return xp.flip(xp.swapaxes(img, -1, -2), axis=-1) + + +def fastrotate(images, theta, M=None): + """ + Rotate `images` array by `theta` radians ccw using shearing algorithm. + + Note that this algorithm may have artifacts near the rotation boundary + and will have artifacts outside the rotation boundary. + Users can avoid these by zero padding the input image then + cropping the rotated image and/or masking. + + For reference and notes: + `https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/common/fastrotate/fastrotate.m` + + :param images: (n , px, px) array of image data + :param theta: Rotation angle in radians. + Note when `M` is supplied, `theta` must be `None`. + :param M: Optional precomputed shearing table. + Provided by `M=compute_fastrotate_interp_tables(theta, px, px)`. + Note when `M` is supplied, `theta` must be `None`. + :return: (n, px, px) array of rotated image data + """ + + # Make a stack of 1 + if images.ndim == 2: + images = images[None, :, :] + + n, px0, px1 = images.shape + assert px0 == px1, "Currently only implemented for square images." + + if M is None: + M = compute_fastrotate_interp_tables(theta, px0, px1) + elif theta is not None: + raise RuntimeError( + "`theta` must be `None` when supplying `M`." + " M is precomputed for a specific `theta`." + ) + Mx, My, Mrots = M + + # Cast interp tables to match precision of `images` + Mx = xp.asarray(Mx, complex_type(images.dtype)) + My = xp.asarray(My, complex_type(images.dtype)) + + # Determine if `images` data was provided on host (np.darray) + _host = isinstance(images, np.ndarray) + + # Copy image array to device if needed + images = xp.asarray(images) + + # Pre rotate by multiples of 90 (pi/2) + if Mrots == 1: + images = _rot90(images) + elif Mrots == 2: + images = _rot180(images) + elif Mrots == 3: + images = _rot270(images) + + # Shear 1 + img_k = fft.fft(images, axis=-1) + img_k = img_k * My + images = fft.ifft(img_k, axis=-1).real + + # Shear 2 + img_k = fft.fft(images, axis=-2) + img_k = img_k * Mx + images = fft.ifft(img_k, axis=-2).real + + # Shear 3 + img_k = fft.fft(images, axis=-1) + img_k = img_k * My + images = fft.ifft(img_k, axis=-1).real + + # Return to host if input was provided on host + if _host: + images = xp.asnumpy(images) + + return images + + +def sp_rotate(img, theta, **kwargs): + """ + Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. + + Converts `theta` from radian to degrees. + Defines stack/image axes and reshape behavior. + Image data is expected to be in last two axes in all cases. + + Additional kwargs will be passed through. + See scipy.ndimage.rotate + + :param img: Array of image data shape (L,L) or (...,L, L) + :param theta: Rotation in ccw radians. + :return: Array representing rotated `img`. + """ + + # Store original shape + original_shape = img.shape + # Image data is expected to be in last two axis in all cases + # Flatten, converts all inputs to consistent 3D shape (single stack axis). + img = img.reshape(-1, *img.shape[-2:]) + + # Scipy accepts a single scalar theta in degrees. + # Handle array of thetas and scalar case by expanding to flat array of img.shape + # Flatten all inputs + theta = np.rad2deg(np.array(theta)).reshape(-1) + # Expand single scalar input + if np.size(theta) == 1: + theta = np.full(img.shape[0], theta, img.dtype) + # Check we have an array matching `img`, both should be len(n) + if theta.shape[0] != img.shape[0]: + raise RuntimeError("Inconsistent `theta` and `img` shapes.") + + # Create result array and rotate images via loop + result = np.empty_like(img) + for i in range(img.shape[0]): + result[i] = ndimage.rotate( + img[i], theta[i], reshape=False, axes=(-2, -1), **kwargs + ) + + # Restore original shape + return result.reshape(*original_shape) diff --git a/src/aspire/numeric/complex_pca/complex_pca.py b/src/aspire/numeric/complex_pca/complex_pca.py index e38820c702..7c2f20fc03 100644 --- a/src/aspire/numeric/complex_pca/complex_pca.py +++ b/src/aspire/numeric/complex_pca/complex_pca.py @@ -15,6 +15,7 @@ import scipy.sparse as sp from sklearn.decomposition import PCA from sklearn.utils._array_api import get_namespace +from sklearn.utils.validation import check_is_fitted from .validation import check_array @@ -78,3 +79,26 @@ def _fit(self, X): raise ValueError( "Unrecognized svd_solver='{0}'" "".format(self._fit_svd_solver) ) + + def inverse_transform(self, X): + """Transform data back to its original space.""" + + xp, _ = get_namespace(X, self.components_, self.explained_variance_) + + check_is_fitted(self) + + X = check_array( + X, + dtype=[np.complex128, np.complex64, np.float64, np.float32], + ensure_2d=True, + copy=self.copy, + allow_complex=True, + ) + + if self.whiten: + scaled_components = ( + xp.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_ + ) + return X @ scaled_components + self.mean_ + else: + return X @ self.components_ + self.mean_ diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index de57297700..bd50da727f 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -135,6 +135,7 @@ def _transform(self, x): resolution = x.shape[-1] + # Note, `freqs` is negated from legacy MATLAB. # nufft call should return `pf` as array type (np or cp) of `x` pf = nufft(x, self.freqs) / resolution**2 @@ -193,7 +194,7 @@ def shift(self, pfx, shifts): # Broadcast and accumulate phase shifts freqs = xp.tile(xp.asarray(self.freqs), (n, 1, 1)) - phase_shifts = xp.exp(-1j * xp.sum(freqs * -shifts[:, :, None], axis=1)) + phase_shifts = xp.exp(-1j * xp.sum(-freqs * shifts[:, :, None], axis=1)) # Reshape flat frequency grid back to (..., ntheta//2, self.nrad) phase_shifts = phase_shifts.reshape(n, self.ntheta // 2, self.nrad) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index f1dff74404..19d7ee311c 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -3,7 +3,7 @@ import logging import os.path from abc import ABC, abstractmethod -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections.abc import Iterable import mrcfile @@ -483,15 +483,17 @@ def offsets(self, values): @property def amplitudes(self): - return np.atleast_1d( - self.get_metadata( - "_rlnAmplitude", default_value=np.array(1.0, dtype=self.dtype) - ) + values = self.get_metadata( + "_aspireAmplitude", + default_value=np.array(1.0, dtype=np.float64), ) + return np.atleast_1d(np.asarray(values, dtype=np.float64)) @amplitudes.setter def amplitudes(self, values): - return self.set_metadata("_rlnAmplitude", np.array(values, dtype=self.dtype)) + # Keep amplitudes float64 so downstream filters/metadata retain precision. + values = np.asarray(values, dtype=np.float64) + self.set_metadata("_aspireAmplitude", values) @property def angles(self): @@ -1289,6 +1291,87 @@ def _populate_local_metadata(self): """ return [] + @staticmethod + def _prepare_relion_optics_blocks(metadata): + """ + Split metadata into RELION>=3.1 style `data_optics` and `data_particles` blocks. + + The optics block has one row per optics group with: + `_rlnOpticsGroup`, `_rlnOpticsGroupName`, and optics metadata columns. + The particle block keeps the remaining columns and includes a per-particle + `_rlnOpticsGroup` that references the optics block. + """ + # Columns that belong in RELION's optics table. + all_optics_fields = [ + "_rlnImagePixelSize", + "_rlnMicrographPixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", + ] + + # Some optics group fields might not always be present, but are necessary + # for reading the file in Relion. We ensure these fields exist and populate + # with a dummy value if not. + n_rows = len(metadata["_rlnImageName"]) + + missing_fields = [] + + def _ensure_column(field, value): + if field not in metadata: + missing_fields.append(field) + logger.warning( + f"Optics field {field} not found, populating with default value {value}" + ) + metadata[field] = np.full(n_rows, value) + + _ensure_column("_rlnSphericalAberration", 0) + _ensure_column("_rlnVoltage", 0) + _ensure_column("_rlnAmplitudeContrast", 0) + + # Restrict to the optics columns that are actually present on this source. + optics_value_fields = [ + field for field in all_optics_fields if field in metadata + ] + + # Map each unique optics tuple to a 1-based group ID in order encountered. + group_lookup = OrderedDict() + optics_groups = np.empty(n_rows, dtype=int) + + for idx in range(n_rows): + signature = tuple(metadata[field][idx] for field in optics_value_fields) + if signature not in group_lookup: + group_lookup[signature] = len(group_lookup) + 1 + optics_groups[idx] = group_lookup[signature] + + metadata["_rlnOpticsGroup"] = optics_groups + + # Build the optics block rows and assign group names. + optics_block = defaultdict(list) + + for signature, group_id in group_lookup.items(): + optics_block["_rlnOpticsGroup"].append(group_id) + optics_block["_rlnOpticsGroupName"].append(f"opticsGroup{group_id}") + for field, value in zip(optics_value_fields, signature): + optics_block[field].append(value) + + # Tag dummy optics if we had to synthesize any fields + if missing_fields: + optics_block["_aspireNoCTF"] = ["." for _ in range(len(group_lookup))] + + # Everything not lifted into the optics block stays with the particle metadata. + particle_block = OrderedDict() + if "_rlnOpticsGroup" in metadata: + particle_block["_rlnOpticsGroup"] = metadata["_rlnOpticsGroup"] + for key, value in metadata.items(): + if key in optics_value_fields or key == "_rlnOpticsGroup": + continue + particle_block[key] = value + + return optics_block, particle_block + def save_metadata(self, starfile_filepath, batch_size=512, save_mode=None): """ Save updated metadata to a STAR file @@ -1324,12 +1407,27 @@ def save_metadata(self, starfile_filepath, batch_size=512, save_mode=None): for x in np.char.split(metadata["_rlnImageName"].astype(np.str_), sep="@") ] + # Populate _rlnImageSize, _rlnImageDimensionality columns, required for optics_block below + if "_rlnImageSize" not in metadata: + metadata["_rlnImageSize"] = np.full(self.n, self.L, dtype=int) + + if "_rlnImageDimensionality" not in metadata: + metadata["_rlnImageDimensionality"] = np.full(self.n, 2, dtype=int) + + # Separate metadata into optics and particle blocks + optics_block, particle_block = self._prepare_relion_optics_blocks(metadata) + # initialize the star file object and save it odict = OrderedDict() - # since our StarFile only has one block, the convention is to save it with the header "data_", i.e. its name is blank - # if we had a block called "XYZ" it would be saved as "XYZ" - # thus we index the metadata block with "" - odict[""] = metadata + + # StarFile uses the `odict` keys to label the starfile block headers "data_(key)". Following RELION>=3.1 + # convention we label the blocks "data_optics" and "data_particles". + if optics_block is None: + odict["particles"] = particle_block + else: + odict["optics"] = optics_block + odict["particles"] = particle_block + out_star = StarFile(blocks=odict) out_star.write(starfile_filepath) return filename_indices @@ -1400,6 +1498,8 @@ def save_images( # for large arrays. stats.update_header(mrc) + # Add pixel size to header + mrc.voxel_size = self.pixel_size else: # save all images into multiple mrc files in batch size for i_start in np.arange(0, self.n, batch_size): @@ -1413,6 +1513,7 @@ def save_images( f"Saving ImageSource[{i_start}-{i_end-1}] to {mrcs_filepath}" ) im = self.images[i_start:i_end] + im.pixel_size = self.pixel_size im.save(mrcs_filepath, overwrite=overwrite) def estimate_signal_mean_energy( diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index c3620c9bdb..963e7ae427 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -125,6 +125,9 @@ def __init__( for key in offset_keys: del self._metadata[key] + # Detect ASPIRE-generated dummy variables + no_ctf_flag = "_aspireNoCTF" in metadata + # CTF estimation parameters coming from Relion CTF_params = [ "_rlnVoltage", @@ -162,6 +165,14 @@ def __init__( # self.unique_filters of the filter that should be applied self.filter_indices = filter_indices + # If we detect ASPIRE added dummy variables, log and initialize identity filter + elif no_ctf_flag: + logger.info( + "Detected ASPIRE-generated dummy optics; initializing identity filters." + ) + self.unique_filters = [IdentityFilter()] + self.filter_indices = np.zeros(self.n, dtype=int) + # We have provided some, but not all the required params elif any(param in metadata for param in CTF_params): logger.warning( diff --git a/src/aspire/utils/relion_interop.py b/src/aspire/utils/relion_interop.py index 996d321366..c807c7d863 100644 --- a/src/aspire/utils/relion_interop.py +++ b/src/aspire/utils/relion_interop.py @@ -20,7 +20,9 @@ "_rlnDetectorPixelSize": float, "_rlnCtfFigureOfMerit": float, "_rlnMagnification": float, + "_rlnImageDimensionality": int, "_rlnImagePixelSize": float, + "_rlnImageSize": int, "_rlnAmplitudeContrast": float, "_rlnImageName": str, "_rlnOriginalName": str, diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covar.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covar.npy deleted file mode 100644 index e55122814c..0000000000 Binary files a/tests/saved_test_data/clean70SRibosome_cov2d_covar.npy and /dev/null differ diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covar.npz b/tests/saved_test_data/clean70SRibosome_cov2d_covar.npz new file mode 100644 index 0000000000..48d3054f43 Binary files /dev/null and b/tests/saved_test_data/clean70SRibosome_cov2d_covar.npz differ diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npy deleted file mode 100644 index c178c240bd..0000000000 Binary files a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npy and /dev/null differ diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npz b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npz new file mode 100644 index 0000000000..15371a69d8 Binary files /dev/null and b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npz differ diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npy deleted file mode 100644 index ad957a08e8..0000000000 Binary files a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npy and /dev/null differ diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npz b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npz new file mode 100644 index 0000000000..117784ffaa Binary files /dev/null and b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npz differ diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py index 94ea9c4316..1fdd75d9ff 100644 --- a/tests/test_FLEbasis2D.py +++ b/tests/test_FLEbasis2D.py @@ -21,10 +21,6 @@ def show_fle_params(basis): return f"{basis.nres}-{basis.epsilon}" -def gpu_ci_skip(): - pytest.skip("1e-7 precision for FLEBasis2D") - - fle_params = [ (32, 1e-4), (32, 1e-7), @@ -80,8 +76,6 @@ class TestFLEBasis2D(UniversalBasisMixin): # check closeness guarantees for fast vs dense matrix method def testFastVDense_T(self, basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() dense_b = basis._create_dense_matrix() @@ -97,8 +91,6 @@ def testFastVDense_T(self, basis): assert relerr(result_dense.T, result_fast) < (self.test_eps * basis.epsilon) def testFastVDense(self, basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() dense_b = basis._create_dense_matrix() @@ -120,8 +112,6 @@ def testFastVDense(self, basis): raises=RuntimeError, ) def testEvaluateExpand(self, basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() # compare result of evaluate() vs more accurate expand() # get sample coefficients @@ -135,8 +125,6 @@ def testEvaluateExpand(self, basis): @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) def testMatchFBEvaluate(basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() # ensure that the basis functions are identical when in match_fb mode fb_basis = FBBasis2D(basis.nres, dtype=np.float64) @@ -170,8 +158,6 @@ def testMatchFBDenseEvaluate(basis): @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) def testMatchFBEvaluate_t(basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() # ensure that coefficients are the same when evaluating images fb_basis = FBBasis2D(basis.nres, dtype=np.float64) diff --git a/tests/test_array_image_source.py b/tests/test_array_image_source.py index 8dc2a28fb7..a2c3ee4f0c 100644 --- a/tests/test_array_image_source.py +++ b/tests/test_array_image_source.py @@ -323,10 +323,10 @@ def test_dtype_passthrough(dtype): # Check dtypes np.testing.assert_equal(src.dtype, dtype) np.testing.assert_equal(src.images[:].dtype, dtype) - np.testing.assert_equal(src.amplitudes.dtype, dtype) - # offsets are always stored as doubles + # offsets and amplitudes are always stored as doubles np.testing.assert_equal(src.offsets.dtype, np.float64) + np.testing.assert_equal(src.amplitudes.dtype, np.float64) def test_stack_1d_only(): diff --git a/tests/test_commonline_sync3n.py b/tests/test_commonline_sync3n.py index 8305aa85e9..3381a51869 100644 --- a/tests/test_commonline_sync3n.py +++ b/tests/test_commonline_sync3n.py @@ -93,6 +93,7 @@ def test_build_clmatrix(source_orientation_objs): assert within_5 / angle_diffs.size > tol +@pytest.mark.xfail(reason="Issue #1340") def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs @@ -115,6 +116,7 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): np.testing.assert_allclose(mean_dist, 0) +@pytest.mark.xfail(reason="Issue #1340") def test_estimate_shifts_with_est_rots(source_orientation_objs): src, orient_est = source_orientation_objs # Estimate shifts using estimated rotations. diff --git a/tests/test_commonline_utils.py b/tests/test_commonline_utils.py new file mode 100644 index 0000000000..64d015d1dd --- /dev/null +++ b/tests/test_commonline_utils.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest + +from aspire.abinitio import JSync +from aspire.abinitio.commonline_utils import ( + _complete_third_row_to_rot, + _estimate_third_rows, + build_outer_products, +) +from aspire.utils import J_conjugate, Rotation, randn, utest_tolerance + +DTYPES = [np.float32, np.float64] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +def test_estimate_third_rows(dtype): + """ + Test we accurately estimate a set of 3rd rows of rotation matrices + given the 3rd row outer products vijs = vi @ vj.T and viis = vi @ vi.T. + """ + n_img = 20 + + # `build_outer_products` generates a set of ground truth 3rd rows + # of rotation matrices, then forms the outer products vijs = vi @ vj.T + # and viis = vi @ vi.T. + vijs, viis, gt_vis = build_outer_products(n_img, dtype) + + # Estimate third rows from outer products. + # Due to factorization of V, these might be negated third rows. + vis = _estimate_third_rows(vijs, viis) + + # Check if all-close up to difference of sign + ground_truth = np.sign(gt_vis[0, 0]) * gt_vis + estimate = np.sign(vis[0, 0]) * vis + np.testing.assert_allclose(ground_truth, estimate, rtol=1e-05, atol=1e-08) + + # Check dtype passthrough + assert vis.dtype == dtype + + +def test_complete_third_row(dtype): + """ + Test that `complete_third_row_to_rot` produces a proper rotations + given a set of 3rd rows. + """ + # Build random third rows. + r3 = randn(10, 3, seed=123).astype(dtype) + r3 /= np.linalg.norm(r3, axis=1)[..., np.newaxis] + + # Set first row to be identical with z-axis. + r3[0] = np.array([0, 0, 1], dtype=dtype) + + # Generate rotations. + R = _complete_third_row_to_rot(r3) + + # Check dtype passthrough + assert R.dtype == dtype + + # Assert that first rotation is the identity matrix. + np.testing.assert_allclose(R[0], np.eye(3, dtype=dtype)) + + # Assert that each rotation is orthogonal with determinant 1. + assert np.allclose( + R @ R.transpose((0, 2, 1)), np.eye(3, dtype=dtype), atol=utest_tolerance(dtype) + ) + assert np.allclose(np.linalg.det(R), 1) + + +def test_J_sync(dtype): + """ + Test that the J_sync `power_method` returns a set of signs indicating + the set of relative rotations that need to be J-conjugated to attain + global handedness consistency, and that `global_J_sync` returns the + ground truth rotations up to a spurious J-conjugation. + """ + n = 25 + rots = Rotation.generate_random_rotations(n, dtype=dtype).matrices + + # Generate ground truth and randomly J-conjugate relative rotations, + # keeping track of the signs associated with J-conjugated rotations. + n_choose_2 = (n * (n - 1)) // 2 + signs = np.random.randint(0, 2, n_choose_2) * 2 - 1 + Rijs_gt = np.zeros((n_choose_2, 3, 3), dtype=dtype) + Rijs_conjugated = np.zeros((n_choose_2, 3, 3), dtype=dtype) + ij = 0 + for i in range(n - 1): + Ri = rots[i] + for j in range(i + 1, n): + Rj = rots[j] + Rijs_gt[ij] = Rij = Ri.T @ Rj + if signs[ij] == -1: + Rij = J_conjugate(Rij) + Rijs_conjugated[ij] = Rij + ij += 1 + + # Initialize JSync instance with default params. + J_sync = JSync(n) + + # Perform power method and check that signs are correct up to + # multilication by -1. Also check dtype pass-through. + signs_est = J_sync.power_method(Rijs_conjugated) + np.testing.assert_allclose(signs[0] * signs, signs_est[0] * signs_est) + assert signs_est.dtype == dtype + + # Perform global J sync and check that rotations are correct up to + # a spurious J conjugation. Also check dtype pass-through. + Rijs_sync = J_sync.global_J_sync(Rijs_conjugated) + + # If the first is off by a J, J-conjugate the whole set. + if np.allclose(Rijs_gt[0], J_conjugate(Rijs_sync[0])): + Rijs_sync = J_conjugate(Rijs_sync) + + np.testing.assert_allclose(Rijs_sync, Rijs_gt) + assert Rijs_sync.dtype == dtype diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index d2f93d5f65..3b0f29f1ee 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -526,7 +526,7 @@ def testSave(self): # load saved particle stack saved_star = StarFile(star_path) # we want to read the saved mrcs file from the STAR file - image_name_column = saved_star.get_block_by_index(0)["_rlnImageName"] + image_name_column = saved_star.get_block_by_index(1)["_rlnImageName"] # we're reading a string of the form 0000X@mrcs_path.mrcs _particle, mrcs_path = image_name_column[0].split("@") saved_mrcs_stack = mrcfile.open(os.path.join(self.data_folder, mrcs_path)).data @@ -535,15 +535,31 @@ def testSave(self): self.assertTrue(np.array_equal(imgs.asnumpy()[i], saved_mrcs_stack[i])) # assert that the star file has the correct metadata self.assertEqual( - list(saved_star[""].keys()), + list(saved_star["particles"].keys()), [ - "_rlnImagePixelSize", + "_rlnOpticsGroup", "_rlnSymmetryGroup", "_rlnImageName", "_rlnCoordinateX", "_rlnCoordinateY", ], ) + + self.assertEqual( + list(saved_star["optics"].keys()), + [ + "_rlnOpticsGroup", + "_rlnOpticsGroupName", + "_rlnImagePixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", + "_aspireNoCTF", + ], + ) + # assert that all the correct coordinates were saved for i in range(10): self.assertEqual( diff --git a/tests/test_covar2d.py b/tests/test_covar2d.py index b1bf41e231..66e88c4ff8 100644 --- a/tests/test_covar2d.py +++ b/tests/test_covar2d.py @@ -128,15 +128,14 @@ def test_get_mean(cov2d_fixture): def test_get_covar(cov2d_fixture): results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"), - allow_pickle=True, + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npz"), ) cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] covar_coef = cov2d._get_covar(coef_clean.asnumpy()) - for im, mat in enumerate(results.tolist()): + for im, mat in enumerate(results.values()): np.testing.assert_allclose(mat, covar_coef[im], rtol=1e-05) @@ -210,13 +209,12 @@ def test_shrinkage(cov2d_fixture, shrinker): cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"), - allow_pickle=True, + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npz"), ) covar_coef = cov2d.get_covar(coef_clean, covar_est_opt={"shrinker": shrinker}) - for im, mat in enumerate(results.tolist()): + for im, mat in enumerate(results.values()): np.testing.assert_allclose( mat, covar_coef[im], atol=utest_tolerance(cov2d.dtype) ) @@ -288,12 +286,11 @@ def test_get_covar_ctf(cov2d_fixture, ctf_enabled): sim, cov2d, _, coef, h_ctf_fb, h_idx = cov2d_fixture results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf.npy"), - allow_pickle=True, + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf.npz"), ) covar_coef_ctf = cov2d.get_covar(coef, h_ctf_fb, h_idx, noise_var=NOISE_VAR) - for im, mat in enumerate(results.tolist()): + for im, mat in enumerate(results.values()): # These tolerances were adjusted slightly (1e-8 to 3e-8) to accomodate MATLAB CTF repro changes np.testing.assert_allclose(mat, covar_coef_ctf[im], rtol=3e-05, atol=3e-08) @@ -306,8 +303,7 @@ def test_get_covar_ctf_shrink(cov2d_fixture, ctf_enabled): pytest.skip(reason="Reference file n/a.") results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf_shrink.npy"), - allow_pickle=True, + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf_shrink.npz"), ) covar_opt = { @@ -328,5 +324,5 @@ def test_get_covar_ctf_shrink(cov2d_fixture, ctf_enabled): covar_est_opt=covar_opt, ) - for im, mat in enumerate(results.tolist()): + for im, mat in enumerate(results.values()): np.testing.assert_allclose(mat, covar_coef_ctf_shrink[im]) diff --git a/tests/test_image.py b/tests/test_image.py index 70f5906a64..009be52364 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -11,8 +11,8 @@ from pytest import raises from scipy.datasets import face -from aspire.image import Image -from aspire.utils import Rotation, powerset, utest_tolerance +from aspire.image import Image, compute_fastrotate_interp_tables, fastrotate, sp_rotate +from aspire.utils import Rotation, gaussian_2d, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup from .test_utils import matplotlib_dry_run @@ -564,3 +564,152 @@ def test_save_load_pixel_size(get_images, dtype): np.testing.assert_almost_equal( im2.pixel_size, im.pixel_size, err_msg="Image pixel_size incorrect save-load" ) + + +@pytest.fixture( + params=Image.rotation_methods, ids=lambda x: f"method={x}", scope="module" +) +def rotation_method(request): + return request.param + + +def test_image_rotate(dtype, rotation_method): + """ + Compare image rotations against rotated gaussian blobs. + """ + + L = 129 # Test image size in pixels + num_test_angles = 42 + # Create mask, used to zero edge artifacts + mask = grid_2d(L, normalized=True)["r"] < 0.9 + + def _gen_image(angle, L, n=1, K=10): + """ + Generate `n` `L-by-L` image arrays, + each constructed by a sequence of `K` gaussian blobs, + and reference images with the blob centers rotated by `angle`. + + Return tuple of unrotated and rotated image arrays (n-by-L-by-L). + + :param angle: rotation angle + :param L: size (L-by-L) in pixels + :param K: Number of blobs + :return: + - Array of unrotated data (float64) + - Array of rotated data (float64) + """ + + im = np.zeros((n, L, L), dtype=np.float64) + rotated_im = np.zeros_like(im) + + centers = np.random.randint(-L // 4, L // 4, size=(n, 10, 2)) + sigmas = np.full((n, K, 2), L / 10, dtype=np.float64) + + # Rotate the gaussian specifications + R = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) + rotated_centers = centers @ R + + # Construct each image independently + for i in range(n): + for center, sigma in zip(centers[i], sigmas[i]): + im[i] = im[i] + gaussian_2d(L, center, sigma, dtype=np.float64) + + for center, sigma in zip(rotated_centers[i], sigmas[i]): + rotated_im[i] = rotated_im[i] + gaussian_2d( + L, center, sigma, dtype=np.float64 + ) + + return im, rotated_im + + # Test over a variety of angles `theta` + for theta in np.linspace(0, 2 * np.pi, num_test_angles): + # Generate images and reference (`theta`) rotated images + im, ref = _gen_image(theta, L, n=3) + im = Image(im.astype(dtype, copy=False)) + + # Rotate using `Image`'s `rotation_method` + im_rot = im.rotate(theta, method=rotation_method) + + # Mask off boundary artifacts + masked_diff = (im_rot - ref) * mask + + # Compute L1 error of masked diff, per image + L1_error = np.mean(np.abs(masked_diff), axis=(-1, -2)) + np.testing.assert_array_less( + L1_error, + 0.1, + err_msg=f"{L} pixels using {rotation_method} @ {theta} radians", + ) + + +def test_sp_rotate_inputs(dtype): + """ + Smoke test various input combinations to the scipy rotation wrapper. + """ + + imgs = np.zeros((6, 8, 8), dtype=dtype) + thetas = np.arange(6, dtype=dtype) + theta = thetas[0] # scalar + + # # These are the only supported calls admitted by the function doc. + # singleton, scalar + _ = sp_rotate(imgs[0], theta) + # stack, scalar + _ = sp_rotate(imgs, theta) + + # # These happen to also work with the code, so were put under test. + # # We're not advertising them, as there really isn't a good use + # # case for this wrapper code outside of the internal wrapping + # # application. + # singleton, single element array + _ = sp_rotate(imgs[0], thetas[0:1]) + # stack, single element array + _ = sp_rotate(imgs, thetas[0:1]) + # stack, stack + _ = sp_rotate(imgs, thetas) + # md-stack, md-stack + _ = sp_rotate(imgs.reshape(2, 3, 8, 8), thetas.reshape(2, 3, 1)) + _ = sp_rotate(imgs.reshape(2, 3, 8, 8), thetas.reshape(2, 3)) + + +def test_fastrotate_inputs(dtype): + """ + Smoke test various input combinations to `fastrotate`. + """ + + imgs = np.zeros((6, 8, 8), dtype=dtype) + theta = 42 + + # # These are the supported calls + # singleton, scalar + _ = fastrotate(imgs[0], theta) + # stack, scalar + _ = fastrotate(imgs, theta) + + # # These can also remain under test, but are not advertised. + # stack, single element array + _ = fastrotate(imgs, np.array(theta)) + # singleton, single element array + _ = fastrotate(imgs[0], np.array(theta)) + + +def test_fastrotate_M_arg(dtype): + """ + Smoke test precomputed `M` input to `fastrotate`. + """ + + imgs = np.random.randn(6, 8, 8).astype(dtype) + theta = np.random.uniform(0, 2 * np.pi) + + # Precompute M + M = compute_fastrotate_interp_tables(theta, *imgs.shape[-2:]) + + # Call with theta None + im_rot_M = fastrotate(imgs, None, M=M) + # Compare to calling withou `M` + im_rot = fastrotate(imgs, theta) + np.testing.assert_allclose(im_rot_M, im_rot) + + # Call with theta, should raise + with raises(RuntimeError, match=r".*`theta` must be `None`.*"): + _ = fastrotate(imgs, theta, M=M) diff --git a/tests/test_mean_estimator.py b/tests/test_mean_estimator.py index c7083d4ef1..b56a06f4ce 100644 --- a/tests/test_mean_estimator.py +++ b/tests/test_mean_estimator.py @@ -109,7 +109,7 @@ def test_estimate(sim, estimator, mask): np.testing.assert_array_equal(sim.pixel_size, estimate.pixel_size) -def test_adjoint(sim, basis, estimator): +def test_adjoint(sim): """ Test = for random volume `v` and random images `u`. @@ -128,7 +128,10 @@ def test_adjoint(sim, basis, estimator): lhs = np.dot(proj.asnumpy().flatten(), u.flatten()) rhs = np.dot(backproj.asnumpy().flatten(), v.flatten()) - np.testing.assert_allclose(lhs, rhs, rtol=1e-6) + rtol = 1e-07 # default rtol for assert_allclose + if sim.dtype == np.float32: + rtol = 1e-05 + np.testing.assert_allclose(lhs, rhs, rtol=rtol) def test_src_adjoint(sim, basis, estimator): diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index 9ab63459df..ed9c5a6904 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -1,10 +1,15 @@ import numpy as np import pytest -from numpy import pi, random -from numpy.linalg import det, norm -from aspire.abinitio import CLSymmetryC2, CLSymmetryC3C4, CLSymmetryCn +from aspire.abinitio import ( + CLSymmetryC2, + CLSymmetryC3C4, + CLSymmetryCn, + build_outer_products, + g_sync, +) from aspire.abinitio.commonline_cn import MeanOuterProductEstimator +from aspire.abinitio.commonline_utils import _cl_angles_to_ind from aspire.source import Simulation from aspire.utils import ( J_conjugate, @@ -12,8 +17,6 @@ all_pairs, cyclic_rotations, mean_aligned_angular_distance, - randn, - utest_tolerance, ) from aspire.volume import CnSymmetricVolume @@ -120,7 +123,7 @@ def test_estimate_rotations(n_img, L, order, dtype): rots_gt = src.rotations # g-synchronize ground truth rotations. - rots_gt_sync = cl_symm.g_sync(rots_est, order, rots_gt) + rots_gt_sync = g_sync(rots_est, order, rots_gt) # Register estimates to ground truth rotations and check that the # mean angular distance between them is less than 3 degrees. @@ -134,9 +137,7 @@ def test_relative_rotations(n_img, L, order, dtype): src, cl_symm = source_orientation_objs(n_img, L, order, dtype) # Estimate relative viewing directions. - cl_symm.build_clmatrix() - cl = cl_symm.clmatrix - Rijs = cl_symm._estimate_all_Rijs_c3_c4(cl) + Rijs = cl_symm._estimate_all_Rijs_c3_c4() # Each Rij belongs to the set {Ri.Tg_n^sRj, JRi.Tg_n^sRjJ}, # s = 1, 2, ..., order. We find the mean squared error over @@ -319,8 +320,8 @@ def test_self_commonlines(n_img, L, order, dtype): # Get angle difference between scl_gt and scl. scl_diff1 = scl_gt - scl scl_diff2 = scl_gt - np.flip(scl, 1) # Order of indices might be switched. - scl_diff1_angle = scl_diff1 * 2 * pi / n_theta - scl_diff2_angle = scl_diff2 * 2 * pi / n_theta + scl_diff1_angle = scl_diff1 * 2 * np.pi / n_theta + scl_diff2_angle = scl_diff2 * 2 * np.pi / n_theta # cosine is invariant to 2pi, and abs is invariant to +-pi due to J-conjugation. # We take the mean deviation wrt to the two lines in each image. @@ -332,7 +333,7 @@ def test_self_commonlines(n_img, L, order, dtype): min_mean_angle_diff = scl_idx.choose(scl_diff_angle_mean) # Assert scl detection rate is 100% for 5 degree angle tolerance - angle_tol_err = 5 * pi / 180 + angle_tol_err = 5 * np.pi / 180 detection_rate = np.count_nonzero(min_mean_angle_diff < angle_tol_err) / len(scl) assert np.allclose(detection_rate, 1.0) @@ -477,48 +478,6 @@ def test_global_J_sync(n_img, dtype): assert np.allclose(viis, viis_sync) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_estimate_third_rows(dtype): - L = 16 - n_img = 20 - order = 3 # test not dependent on order - _, orient_est = source_orientation_objs(n_img, L, order, dtype) - - # Build outer products vijs, viis, and get ground truth third rows. - vijs, viis, gt_vis = build_outer_products(n_img, dtype) - - # Estimate third rows from outer products. - # Due to factorization of V, these might be negated third rows. - vis = orient_est._estimate_third_rows(vijs, viis) - - # Check if all-close up to difference of sign - ground_truth = np.sign(gt_vis[0, 0]) * gt_vis - estimate = np.sign(vis[0, 0]) * vis - assert np.allclose(ground_truth, estimate) - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_complete_third_row(dtype): - # Build random third rows. - r3 = randn(10, 3, seed=123).astype(dtype) - r3 /= norm(r3, axis=1)[..., np.newaxis] - - # Set first row to be identical with z-axis. - r3[0] = np.array([0, 0, 1], dtype=dtype) - - # Generate rotations. - R = CLSymmetryC3C4._complete_third_row_to_rot(r3) - - # Assert that first rotation is the identity matrix. - assert np.allclose(R[0], np.eye(3, dtype=dtype)) - - # Assert that each rotation is orthogonal with determinant 1. - assert np.allclose( - R @ R.transpose((0, 2, 1)), np.eye(3, dtype=dtype), atol=utest_tolerance(dtype) - ) - assert np.allclose(det(R), 1) - - @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_dtype_pass_through(dtype): L = 16 @@ -554,31 +513,6 @@ def build_self_commonlines_matrix(n_theta, rots, order): return scl_gt -def build_outer_products(n_img, dtype): - # Build random third rows, ground truth vis (unit vectors) - gt_vis = np.zeros((n_img, 3), dtype=dtype) - for i in range(n_img): - random.seed(i) - v = random.randn(3) - gt_vis[i] = v / norm(v) - - # Find outer products viis and vijs for i tol +@pytest.mark.expensive def test_estimate_rotations(source_orientation_objs): src, orient_est = source_orientation_objs @@ -106,6 +132,7 @@ def test_estimate_rotations(source_orientation_objs): mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1) +@pytest.mark.expensive def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs @@ -119,15 +146,17 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): # Calculate the mean 2D distance between estimates and ground truth. error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() - # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets + # Assert that on average estimated shifts are close to src.offsets if src.offsets.all() != 0: - np.testing.assert_array_less(mean_dist, 0.5) + np.testing.assert_array_less(mean_dist, 2) else: np.testing.assert_allclose(mean_dist, 0) +@pytest.mark.expensive def test_estimate_shifts_with_est_rots(source_orientation_objs): src, orient_est = source_orientation_objs @@ -138,13 +167,14 @@ def test_estimate_shifts_with_est_rots(source_orientation_objs): error = src.offsets - est_shifts mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() - # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets + # Assert that on average estimated shifts are close to src.offsets if src.offsets.all() != 0: - np.testing.assert_array_less(mean_dist, 0.5) + np.testing.assert_array_less(mean_dist, 2) else: np.testing.assert_allclose(mean_dist, 0) +@pytest.mark.expensive def test_estimate_rotations_fuzzy_mask(): noisy_src = Simulation( n=35, @@ -226,3 +256,41 @@ def test_command_line(): ) # check that the command completed successfully assert result.exit_code == 0 + + +@pytest.mark.parametrize("cl_algo", CL_ALGOS) +def test_offset_param_passthrough(cl_algo): + """ + Systematically test that offset search configuration passes through all CL classes. + """ + + src = ArrayImageSource(np.random.randn(4, 4), pixel_size=1.23) + + test_args = { + "offsets_max_shift": 0.5, + "offsets_shift_step": 0.1, + "offsets_equations_factor": 1, + "offsets_max_memory": 200, + } + + # Handle special case classes + if cl_algo == CLSymmetryC3C4: + test_args["symmetry"] = "C3" + elif cl_algo == CLSymmetryCn: + test_args["symmetry"] = "C17" + + # Instantiate the CL class under test + orient_est = cl_algo(src, **test_args) + + # Loop over the args and assert they are correctly assigned + for arg, val in test_args.items(): + + # Handle special case arguments + if arg == "offsets_max_shift": + # convert from ratio to pixels + val = np.ceil(val * src.L) + elif arg == "symmetry": + # convert from string `symmetry` to int `order` + arg, val = "order", int(val[1:]) + + assert getattr(orient_est, arg) == val diff --git a/tests/test_relion_source.py b/tests/test_relion_source.py index 009ecd321d..64703ed491 100644 --- a/tests/test_relion_source.py +++ b/tests/test_relion_source.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from aspire.source import RelionSource, Simulation +from aspire.source import ImageSource, RelionSource, Simulation from aspire.utils import RelionStarFile from aspire.volume import SymmetryGroup @@ -61,6 +61,41 @@ def test_symmetry_group(caplog): assert str(src_override_sym.symmetry_group) == "C6" +def test_prepare_relion_optics_blocks_warns(caplog): + """ + Test we warn when optics group metadata is missing. + """ + # metadata dict with no CTF values + metadata = { + "_rlnImagePixelSize": np.array([1.234]), + "_rlnImageSize": np.array([32]), + "_rlnImageDimensionality": np.array([2]), + "_rlnImageName": np.array(["000001@stack.mrcs"]), + } + + caplog.clear() + with caplog.at_level(logging.WARNING): + optics_block, particle_block = ImageSource._prepare_relion_optics_blocks( + metadata.copy() + ) + + # We should get and optics block + assert optics_block is not None + + # Verify defaults were injected. + np.testing.assert_allclose(optics_block["_rlnImagePixelSize"], [1.234]) + np.testing.assert_array_equal(optics_block["_rlnImageSize"], [32]) + np.testing.assert_array_equal(optics_block["_rlnImageDimensionality"], [2]) + np.testing.assert_allclose(optics_block["_rlnVoltage"], [0]) + np.testing.assert_allclose(optics_block["_rlnSphericalAberration"], [0]) + np.testing.assert_allclose(optics_block["_rlnAmplitudeContrast"], [0]) + + # Caplog should contain the warnings about the three missing fields. + assert "Optics field _rlnSphericalAberration not found" in caplog.text + assert "Optics field _rlnVoltage not found" in caplog.text + assert "Optics field _rlnAmplitudeContrast not found" in caplog.text + + def test_pixel_size(caplog): """ Instantiate RelionSource from starfiles containing the following pixel size diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 5859aae934..58bea32128 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -3,13 +3,14 @@ import tempfile from unittest import TestCase +import mrcfile import numpy as np import pytest from aspire.noise import WhiteNoiseAdder from aspire.operators import RadialCTFFilter from aspire.source import RelionSource, Simulation, _LegacySimulation -from aspire.utils import utest_tolerance +from aspire.utils import RelionStarFile, utest_tolerance from aspire.volume import LegacyVolume, SymmetryGroup, Volume from .test_utils import matplotlib_dry_run @@ -627,6 +628,111 @@ def testSimulationSaveFile(self): ) +def test_simulation_save_optics_block(tmp_path): + res = 32 + + # Radial CTF Filters. Should make 3 distinct optics blocks + kv_min, kv_max, kv_ct = 200, 300, 3 + voltages = np.linspace(kv_min, kv_max, kv_ct) + ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + + # Generate and save Simulation + sim = Simulation( + n=9, L=res, C=1, unique_filters=ctf_filters, pixel_size=1.34 + ).cache() + starpath = tmp_path / "sim.star" + sim.save(starpath, overwrite=True) + + star = RelionStarFile(str(starpath)) + assert star.relion_version == "3.1" + assert star.blocks.keys() == {"optics", "particles"} + + optics = star["optics"] + expected_optics_fields = [ + "_rlnOpticsGroup", + "_rlnOpticsGroupName", + "_rlnImagePixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", + ] + + # Check all required fields are present + for field in expected_optics_fields: + assert field in optics + + # Optics group and group name should be 1-indexed + np.testing.assert_array_equal( + optics["_rlnOpticsGroup"], np.arange(1, kv_ct + 1, dtype=int) + ) + np.testing.assert_array_equal( + optics["_rlnOpticsGroupName"], + np.array([f"opticsGroup{i}" for i in range(1, kv_ct + 1)]), + ) + + # Check image size (res) and image dimensionality (2) + np.testing.assert_array_equal(optics["_rlnImageSize"], np.full(kv_ct, res)) + np.testing.assert_array_equal(optics["_rlnImageDimensionality"], np.full(kv_ct, 2)) + + # Due to Simulation random indexing, voltages will be unordered + np.testing.assert_allclose(np.sort(optics["_rlnVoltage"]), voltages) + + # Check that each row of the data_particles block has an associated optics group + particles = star["particles"] + assert "_rlnOpticsGroup" in particles + assert len(particles["_rlnOpticsGroup"]) == sim.n + np.testing.assert_array_equal( + np.sort(np.unique(particles["_rlnOpticsGroup"])), + np.arange(1, kv_ct + 1, dtype=int), + ) + + # Test phase_flip after save/load round trip to ensure correct optics group mapping + rln_src = RelionSource(starpath) + np.testing.assert_allclose( + sim.phase_flip().images[:], rln_src.phase_flip().images[:] + ) + + +def test_simulation_slice_save_roundtrip(tmp_path): + # Radial CTF Filters + kv_min, kv_max, kv_ct = 200, 300, 3 + voltages = np.linspace(kv_min, kv_max, kv_ct) + ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + + # Generate and save slice of Simulation + sim = Simulation(n=9, L=16, C=1, unique_filters=ctf_filters, pixel_size=1.34) + sliced_sim = sim[::2] + save_path = tmp_path / "sliced_sim.star" + sliced_sim.save(save_path, overwrite=True) + + # Load saved slice and compare to original + reloaded = RelionSource(save_path) + + # Check images + np.testing.assert_allclose( + reloaded.images[:].asnumpy(), + sliced_sim.images[:].asnumpy(), + ) + + # Check metadata related to optics block + metadata_fields = [ + "_rlnVoltage", + "_rlnDefocusU", + "_rlnDefocusV", + "_rlnDefocusAngle", + "_rlnSphericalAberration", + "_rlnAmplitudeContrast", + "_rlnImagePixelSize", + ] + for field in metadata_fields: + np.testing.assert_allclose( + reloaded.get_metadata(field), + sliced_sim.get_metadata(field), + ) + + def test_default_symmetry_group(): # Check that default is "C1". sim = Simulation() @@ -809,6 +915,48 @@ def test_save_overwrite(caplog): check_metadata(sim2, sim2_loaded_renamed) +def test_save_load_dummy_ctf_values(tmp_path, caplog): + """ + Test we populate optics group field with dummy values when none + are present. These values should be detected upon reloading the source. + """ + star_path = tmp_path / "no_ctf.star" + sim = Simulation(n=8, L=16) # no unique_filters, ie. no CTF info + sim.save(star_path, overwrite=True) + + # STAR file should contain our fallback tag + star = RelionStarFile(star_path) + optics_block = star.get_block_by_index(0) + assert "_aspireNoCTF" in optics_block + + # Tag should survive round-trip + caplog.clear() + reloaded = RelionSource(star_path) + assert "_aspireNoCTF" in reloaded._metadata + + # Check message is logged about detecting dummy variables + assert "Detected ASPIRE-generated dummy optics" in caplog.text + + +@pytest.mark.parametrize("batch_size", [1, 6]) +def test_simulation_save_sets_voxel_size(tmp_path, batch_size): + """ + Test we save with pixel_size appended to the mrcfile header. + """ + # Note, n=6 and batch_size=6 exercises save_mode=='single' branch. + sim = Simulation(n=6, L=24, pixel_size=1.37) + info = sim.save(tmp_path / "pixel_size.star", batch_size=batch_size, overwrite=True) + + for stack_name in info["mrcs"]: + stack_path = tmp_path / stack_name + with mrcfile.open(stack_path, permissive=True) as f: + vs = f.voxel_size + header_vals = np.array( + [float(vs.x), float(vs.y), float(vs.z)], dtype=np.float64 + ) + np.testing.assert_allclose(header_vals, sim.pixel_size) + + def check_metadata(sim_src, relion_src): """ Helper function to test if metadata fields in a Simulation match