diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index 8ec67c0adc..8f309b0166 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -5,6 +5,7 @@ g_sync, ) from .commonline_base import CLOrient3D +from .commonline_matrix import CLMatrixOrient3D from .commonline_sdp import CommonlineSDP from .commonline_lud import CommonlineLUD from .commonline_irls import CommonlineIRLS diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 2947195769..75b2add767 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -1,13 +1,12 @@ import logging import math -import os import numpy as np import scipy.sparse as sparse from aspire.image import Image from aspire.operators import PolarFT -from aspire.utils import Rotation, complex_type, fuzzy_mask, tqdm +from aspire.utils import Rotation, fuzzy_mask from aspire.utils.random import choice from .commonline_utils import _generate_shift_phase_and_filter @@ -101,31 +100,7 @@ def __init__( self.mask = mask self._pf = None - # Sanity limit to match potential clmatrix dtype of int16. - if self.n_img > (2**15 - 1): - raise NotImplementedError( - "Commonlines implementation limited to <2**15 images." - ) - - # Auto configure GPU - self.__gpu_module = None - try: - import cupy as cp - - if cp.cuda.runtime.getDeviceCount() >= 1: - gpu_id = cp.cuda.runtime.getDevice() - logger.info( - f"cupy and GPU {gpu_id} found by cuda runtime; enabling cupy." - ) - self.__gpu_module = self.__init_cupy_module() - else: - logger.info("GPU not found, defaulting to numpy.") - - except ModuleNotFoundError: - logger.info("cupy not found, defaulting to numpy.") - # Outputs - self.clmatrix = None self.rotations = None self.shifts = None @@ -189,25 +164,6 @@ def estimate_rotations(self): """ raise NotImplementedError("subclasses should implement this") - @property - def clmatrix(self): - """ - Returns Common Lines Matrix. - - Computes if `clmatrix` is None. - - :return: Common Lines Matrix - """ - if self._clmatrix is None: - self.build_clmatrix() - else: - logger.info("Using existing estimated `clmatrix`.") - return self._clmatrix - - @clmatrix.setter - def clmatrix(self, value): - self._clmatrix = value - @property def rotations(self): """ @@ -246,228 +202,6 @@ def shifts(self): def shifts(self, value): self._shifts = value - def build_clmatrix(self): - """ - Build common-lines matrix from Fourier stack of 2D images - - Wrapper for cpu/gpu dispatch. - """ - - logger.info("Begin building Common Lines Matrix") - - # host/gpu dispatch - if self.__gpu_module: - res = self.build_clmatrix_cu() - else: - res = self.build_clmatrix_host() - - # Unpack result - self._shifts_1d, self.clmatrix = res - - return self.clmatrix - - def build_clmatrix_host(self): - """ - Build common-lines matrix from Fourier stack of 2D images - """ - - n_img = self.n_img - n_check = self.n_check - - if self.n_theta % 2 == 1: - msg = "n_theta must be even" - logger.error(msg) - raise NotImplementedError(msg) - - n_theta_half = self.n_theta // 2 - - # need to do a copy to prevent modifying self.pf for other functions - pf = self.pf.copy() - - # Allocate local variables for return - # clmatrix represents the common lines matrix. - # Namely, clmatrix[i,j] contains the index in image i of - # the common line with image j. Note the common line index - # starts from 0 instead of 1 as Matlab version. -1 means - # there is no common line such as clmatrix[i,i]. - clmatrix = -np.ones((n_img, n_img), dtype=self.dtype) - # When cl_dist[i, j] is not -1, it stores the maximum value - # of correlation between image i and j for all possible 1D shifts. - # We will use cl_dist[i, j] = -1 (including j<=i) to - # represent that there is no need to check common line - # between i and j. Since it is symmetric, - # only above the diagonal entries are necessary. - cl_dist = -np.ones((n_img, n_img), dtype=self.dtype) - - # Allocate variables used for shift estimation - - # set maximum value of 1D shift (in pixels) to search - # between common-lines. - max_shift = self.max_shift - # Set resolution of shift estimation in pixels. Note that - # shift_step can be any positive real number. - shift_step = self.shift_step - # 1D shift between common-lines - shifts_1d = np.zeros((n_img, n_img)) - - # Prepare the shift phases to try and generate filter for common-line detection - r_max = pf.shape[2] - shifts, shift_phases, h = _generate_shift_phase_and_filter( - r_max, max_shift, shift_step, self.dtype - ) - - # Apply bandpass filter, normalize each ray of each image - # Note that only use half of each ray - pf = self._apply_filter_and_norm("ijk, k -> ijk", pf, r_max, h) - - # Setup a progress bar - _total_pairs_to_test = self.n_img * (self.n_check - 1) // 2 - pbar = tqdm(desc="Searching over common line pairs", total=_total_pairs_to_test) - - # Search for common lines between [i, j] pairs of images. - # Creating pf and building common lines are different to the Matlab version. - # The random selection is implemented. - for i in range(n_img - 1): - p1 = pf[i] - p1_real = np.real(p1) - p1_imag = np.imag(p1) - - # build the subset of j images if n_check < n_img - n_remaining = n_img - i - 1 - n_j = min(n_remaining, n_check) - subset_j = np.sort(choice(n_remaining, n_j, replace=False) + i + 1) - - for j in subset_j: - p2_flipped = np.conj(pf[j]) - - for shift in range(len(shifts)): - shift_phase = shift_phases[shift] - p2_shifted_flipped = (shift_phase * p2_flipped).T - # Compute correlations in the positive r direction - part1 = p1_real.dot(np.real(p2_shifted_flipped)) - # Compute correlations in the negative r direction - part2 = p1_imag.dot(np.imag(p2_shifted_flipped)) - - c1 = part1 - part2 - sidx = c1.argmax() - cl1, cl2 = np.unravel_index(sidx, c1.shape) - sval = c1[cl1, cl2] - - c2 = part1 + part2 - sidx = c2.argmax() - cl1_2, cl2_2 = np.unravel_index(sidx, c2.shape) - sval2 = c2[cl1_2, cl2_2] - - if sval2 > sval: - cl1 = cl1_2 - cl2 = cl2_2 + n_theta_half - sval = sval2 - sval = 2 * sval - if sval > cl_dist[i, j]: - clmatrix[i, j] = cl1 - clmatrix[j, i] = cl2 - cl_dist[i, j] = sval - shifts_1d[i, j] = shifts[shift] - pbar.update() - pbar.close() - - return shifts_1d, clmatrix - - def build_clmatrix_cu(self): - """ - Build common-lines matrix from Fourier stack of 2D images - """ - - import cupy as cp - - n_img = self.n_img - r = self.pf.shape[2] - - if self.n_theta % 2 == 1: - msg = "n_theta must be even" - logger.error(msg) - raise NotImplementedError(msg) - - # Copy to prevent modifying self.pf for other functions - # Simultaneously place on GPU - pf = cp.array(self.pf) - - # Allocate local variables for return - # clmatrix represents the common lines matrix. - # Namely, clmatrix[i,j] contains the index in image i of - # the common line with image j. Note the common line index - # starts from 0 instead of 1 as Matlab version. -1 means - # there is no common line such as clmatrix[i,i]. - clmatrix = -cp.ones((n_img, n_img), dtype=np.int16) - - # Allocate variables used for shift estimation - # - # Set maximum value of 1D shift (in pixels) to search - # between common-lines. - # Set resolution of shift estimation in pixels. Note that - # shift_step can be any positive real number. - # - # Prepare the shift phases to try and generate filter for common-line detection - # - # Note the CUDA implementation has been optimized to not - # compute or return diagnostic 1d shifts. - _, shift_phases, h = _generate_shift_phase_and_filter( - r, self.max_shift, self.shift_step, self.dtype - ) - # Transfer to device, dtypes must match kernel header. - shift_phases = cp.asarray(shift_phases, dtype=complex_type(self.dtype)) - - # Apply bandpass filter, normalize each ray of each image - # Note that this only uses half of each ray - pf = self._apply_filter_and_norm("ijk, k -> ijk", pf, r, h) - - # Tranpose `pf` for better (CUDA) memory access pattern, and cast as needed. - pf = cp.ascontiguousarray(pf.T, dtype=complex_type(self.dtype)) - - # Get kernel - if self.dtype == np.float64: - build_clmatrix_kernel = self.__gpu_module.get_function( - "build_clmatrix_kernel" - ) - elif self.dtype == np.float32: - build_clmatrix_kernel = self.__gpu_module.get_function( - "fbuild_clmatrix_kernel" - ) - else: - raise NotImplementedError( - "build_clmatrix_kernel only implemented for float32 and float64." - ) - - # Configure grid of blocks - blkszx = 32 - # Enough blocks to cover n_img-1 - nblkx = (self.n_img + blkszx - 2) // blkszx - blkszy = 32 - # Enough blocks to cover n_img - nblky = (self.n_img + blkszy - 1) // blkszy - - # Launch - logger.info("Launching `build_clmatrix_kernel`.") - build_clmatrix_kernel( - (nblkx, nblky), - (blkszx, blkszy), - ( - n_img, - pf.shape[1], - r, - pf, - clmatrix, - len(shift_phases), - shift_phases, - ), - ) - - # Copy result device arrays to host - clmatrix = clmatrix.get().astype(self.dtype, copy=False) - - # Note diagnostic 1d shifts are not computed in the CUDA implementation. - return None, clmatrix - def estimate_shifts(self): """ Estimate 2D shifts in images @@ -757,27 +491,3 @@ def _apply_filter_and_norm(self, subscripts, pf, r_max, h): pf /= np.linalg.norm(pf, axis=-1)[..., np.newaxis] return pf - - @staticmethod - def __init_cupy_module(): - """ - Private utility method to read in CUDA source and return as - compiled CuPy module. - """ - - import cupy as cp - - # Read in contents of file - fp = os.path.join(os.path.dirname(__file__), "commonline_base.cu") - with open(fp, "r") as fh: - module_code = fh.read() - - # CuPy compile the CUDA code - # Note these optimizations are to steer aggresive optimization - # for single precision code. Fast math will potentionally - # reduce accuracy in single precision. - return cp.RawModule( - code=module_code, - backend="nvcc", - options=("-O3", "--use_fast_math", "--extra-device-vectorization"), - ) diff --git a/src/aspire/abinitio/commonline_c2.py b/src/aspire/abinitio/commonline_c2.py index febd63951a..bae97a243f 100644 --- a/src/aspire/abinitio/commonline_c2.py +++ b/src/aspire/abinitio/commonline_c2.py @@ -3,7 +3,7 @@ import numpy as np from scipy.linalg import eigh -from aspire.abinitio import CLOrient3D, JSync +from aspire.abinitio import CLMatrixOrient3D, JSync from aspire.abinitio.sync_voting import _syncmatrix_ij_vote_3n from aspire.utils import J_conjugate, Rotation, all_pairs @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -class CLSymmetryC2(CLOrient3D): +class CLSymmetryC2(CLMatrixOrient3D): """ Define a class to estimate 3D orientations using common lines methods for molecules with C2 cyclic symmetry. @@ -240,16 +240,13 @@ def _estimate_relative_viewing_directions(self): vi is the third row of the i'th rotation matrix Ri. """ logger.info(f"Estimating relative viewing directions for {self.n_img} images.") - # Step 1: Detect the two pairs of mutual common-lines between each pair of images - self.build_clmatrix() - - # Step 2: Calculate relative rotations associated with both mutual common lines. + # Step 1: Calculate relative rotations associated with both mutual common lines. Rijs, Rijgs = self._estimate_all_Rijs_c2() - # Step 3: Inner J-synchronization + # Step 2: Inner J-synchronization Rijs, Rijgs = self._local_J_sync_c2(Rijs, Rijgs) - # Step 4: Global J-synchronization. + # Step 3: Global J-synchronization. logger.info("Performing global handedness synchronization.") vijs, Rijs, Rijgs = self._global_J_sync(Rijs, Rijgs) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 34d0ee4763..eafcea2ecc 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -3,7 +3,7 @@ import numpy as np from numpy.linalg import norm, svd -from aspire.abinitio import CLOrient3D, JSync +from aspire.abinitio import CLMatrixOrient3D, JSync from aspire.abinitio.sync_voting import _syncmatrix_ij_vote_3n from aspire.operators import PolarFT from aspire.utils import J_conjugate, Rotation, all_pairs, anorm, trange @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -class CLSymmetryC3C4(CLOrient3D): +class CLSymmetryC3C4(CLMatrixOrient3D): """ Define a class to estimate 3D orientations using common lines methods for molecules with C3 and C4 cyclic symmetry. @@ -49,6 +49,7 @@ def __init__( degree_res=1, seed=None, mask=True, + disable_gpu=False, **kwargs, ): """ @@ -66,6 +67,9 @@ def __init__( :param seed: Optional seed for RNG. :param mask: Option to mask `src.images` with a fuzzy mask (boolean). Default, `True`, applies a mask. + :param disable_gpu: Disables GPU acceleration; + forces CPU only code for this module. + Defaults to automatically using GPU when available. """ super().__init__( @@ -75,6 +79,7 @@ def __init__( max_shift=max_shift, shift_step=shift_step, mask=mask, + disable_gpu=disable_gpu, **kwargs, ) @@ -137,19 +142,16 @@ def _estimate_relative_viewing_directions(self): vi is the third row of the i'th rotation matrix Ri. """ logger.info(f"Estimating relative viewing directions for {self.n_img} images.") - # Step 1: Detect a single pair of common-lines between each pair of images - self.build_clmatrix() - - # Step 2: Detect self-common-lines in each image + # Step 1: Detect self-common-lines in each image sclmatrix = self._self_clmatrix_c3_c4() - # Step 3: Calculate self-relative-rotations + # Step 2: Calculate self-relative-rotations Riis = self._estimate_all_Riis_c3_c4(sclmatrix) - # Step 4: Calculate relative rotations + # Step 3: Calculate relative rotations Rijs = self._estimate_all_Rijs_c3_c4() - # Step 5: Inner J-synchronization + # Step 4: Inner J-synchronization vijs, viis = self._local_J_sync_c3_c4(Rijs, Riis) return vijs, viis diff --git a/src/aspire/abinitio/commonline_ev.py b/src/aspire/abinitio/commonline_ev.py deleted file mode 100644 index 6aafd507e6..0000000000 --- a/src/aspire/abinitio/commonline_ev.py +++ /dev/null @@ -1,30 +0,0 @@ -import logging - -from aspire.abinitio import CLOrient3D - -logger = logging.getLogger(__name__) - - -class CommLineEV(CLOrient3D): - """ - Class to estimate 3D orientations using Eigenvector method - :cite:`DBLP:journals/siamis/SingerS11` - """ - - def __init__(self, src): - """ - constructor of an object for estimating 3D orientations - """ - pass - - def estimate(self): - """ - perform estimation of orientations - """ - pass - - def output(self): - """ - Output the 3D orientations - """ - pass diff --git a/src/aspire/abinitio/commonline_gcar.py b/src/aspire/abinitio/commonline_gcar.py deleted file mode 100644 index 94794099cd..0000000000 --- a/src/aspire/abinitio/commonline_gcar.py +++ /dev/null @@ -1,33 +0,0 @@ -import logging - -from aspire.abinitio import CLOrient3D - -logger = logging.getLogger(__name__) - - -class CommLineGCAR(CLOrient3D): - """ - Define a derived class to estimate 3D orientations using Globally Consistent Angular Reconstitution described - as below: - R. Coifman, Y. Shkolnisky, F. J. Sigworth, and A. Singer, Reference Free Structure Determination through - Eigenvestors of Center of Mass Operators, Applied and Computational Harmonic Analysis, 28, 296-312 (2010). - - """ - - def __init__(self, src): - """ - constructor of an object for estimating 3D oreintations - """ - pass - - def estimate(self): - """ - perform estimation of orientations - """ - pass - - def output(self): - """ - Output the 3D orientations - """ - pass diff --git a/src/aspire/abinitio/commonline_irls.py b/src/aspire/abinitio/commonline_irls.py index f608adc679..39d77535d6 100644 --- a/src/aspire/abinitio/commonline_irls.py +++ b/src/aspire/abinitio/commonline_irls.py @@ -26,6 +26,7 @@ def __init__( alpha=None, max_rankZ=None, max_rankW=None, + disable_gpu=False, **kwargs, ): """ @@ -41,6 +42,9 @@ def __init__( If None, defaults to max(6, n_img // 4). :param max_rankW: Maximum rank used for projecting the W matrix (for adaptive projection). If None, defaults to max(6, n_img // 4). + :param disable_gpu: Disables GPU acceleration; + forces CPU only code for this module. + Defaults to automatically using GPU when available. """ self.num_itrs = num_itrs @@ -56,6 +60,7 @@ def __init__( max_rankZ=max_rankZ, max_rankW=max_rankW, alpha=alpha, + disable_gpu=disable_gpu, **kwargs, ) @@ -63,9 +68,6 @@ def estimate_rotations(self): """ Estimate rotation matrices using the common lines method with IRLS optimization. """ - logger.info("Computing the common lines matrix.") - self.build_clmatrix() - self.S = self._construct_S(self.clmatrix) weights = np.ones(2 * self.n_img, dtype=self.dtype) gram = np.eye(2 * self.n_img, dtype=self.dtype) diff --git a/src/aspire/abinitio/commonline_lud.py b/src/aspire/abinitio/commonline_lud.py index 759818d8df..f4f69181a6 100644 --- a/src/aspire/abinitio/commonline_lud.py +++ b/src/aspire/abinitio/commonline_lud.py @@ -37,6 +37,7 @@ def __init__( max_mu_itr=20, delta_mu_l=0.1, delta_mu_u=10, + disable_gpu=False, **kwargs, ): """ @@ -84,6 +85,9 @@ def __init__( Default is 0.1. :param delta_mu_u: Upper bound for relative drop ratio to trigger an increase in `mu`. Default is 10. + :param disable_gpu: Disables GPU acceleration; + forces CPU only code for this module. + Defaults to automatically using GPU when available. """ # Handle parameters specific to CommonlineLUD @@ -138,9 +142,6 @@ def estimate_rotations(self): """ Estimate rotation matrices using the common lines method with LUD optimization. """ - logger.info("Computing the common lines matrix.") - self.build_clmatrix() - self._cl_to_C(self.clmatrix) gram = self._compute_Gram() gram = self._restructure_Gram(gram) diff --git a/src/aspire/abinitio/commonline_base.cu b/src/aspire/abinitio/commonline_matrix.cu similarity index 100% rename from src/aspire/abinitio/commonline_base.cu rename to src/aspire/abinitio/commonline_matrix.cu diff --git a/src/aspire/abinitio/commonline_matrix.py b/src/aspire/abinitio/commonline_matrix.py new file mode 100644 index 0000000000..feeb71dea6 --- /dev/null +++ b/src/aspire/abinitio/commonline_matrix.py @@ -0,0 +1,318 @@ +import logging +import os + +import numpy as np + +from aspire.abinitio import CLOrient3D +from aspire.utils import complex_type, tqdm +from aspire.utils.random import choice + +from .commonline_utils import _generate_shift_phase_and_filter + +logger = logging.getLogger(__name__) + + +class CLMatrixOrient3D(CLOrient3D): + """ + An intermediate base class to serve commonline algorithms that use + a commonline matrix. + """ + + def __init__(self, src, disable_gpu=False, **kwargs): + """ + Initialize an object for estimating 3D orientations with a + commonline algorithm that uses a constructed commonlines matrix. + """ + super().__init__(src, **kwargs) + + # Sanity limit to match potential clmatrix dtype of int16. + if self.n_img > (2**15 - 1): + raise NotImplementedError( + "Commonlines implementation limited to <2**15 images." + ) + + # Auto configure GPU + self.__gpu_module = None + if not disable_gpu: + try: + import cupy as cp + + if cp.cuda.runtime.getDeviceCount() >= 1: + gpu_id = cp.cuda.runtime.getDevice() + logger.info( + f"cupy and GPU {gpu_id} found by cuda runtime; enabling cupy." + ) + self.__gpu_module = self.__init_cupy_module() + else: + logger.info("GPU not found, defaulting to numpy.") + + except ModuleNotFoundError: + logger.info("cupy not found, defaulting to numpy.") + + # Outputs + self.clmatrix = None + + @property + def clmatrix(self): + """ + Returns Common Lines Matrix. + + Computes if `clmatrix` is None. + + :return: Common Lines Matrix + """ + if self._clmatrix is None: + self.build_clmatrix() + else: + logger.info("Using existing estimated `clmatrix`.") + return self._clmatrix + + @clmatrix.setter + def clmatrix(self, value): + self._clmatrix = value + + def build_clmatrix(self): + """ + Build common-lines matrix from Fourier stack of 2D images + + Wrapper for cpu/gpu dispatch. + """ + + logger.info("Begin building Common Lines Matrix") + + # host/gpu dispatch + if self.__gpu_module: + res = self.build_clmatrix_cu() + else: + res = self.build_clmatrix_host() + + # Unpack result + self._shifts_1d, self.clmatrix = res + + return res + + def build_clmatrix_host(self): + """ + Build common-lines matrix from Fourier stack of 2D images + """ + + n_img = self.n_img + n_check = self.n_check + + if self.n_theta % 2 == 1: + msg = "n_theta must be even" + logger.error(msg) + raise NotImplementedError(msg) + + n_theta_half = self.n_theta // 2 + + # need to do a copy to prevent modifying self.pf for other functions + pf = self.pf.copy() + + # Allocate local variables for return + # clmatrix represents the common lines matrix. + # Namely, clmatrix[i,j] contains the index in image i of + # the common line with image j. Note the common line index + # starts from 0 instead of 1 as Matlab version. -1 means + # there is no common line such as clmatrix[i,i]. + clmatrix = -np.ones((n_img, n_img), dtype=self.dtype) + # When cl_dist[i, j] is not -1, it stores the maximum value + # of correlation between image i and j for all possible 1D shifts. + # We will use cl_dist[i, j] = -1 (including j<=i) to + # represent that there is no need to check common line + # between i and j. Since it is symmetric, + # only above the diagonal entries are necessary. + cl_dist = -np.ones((n_img, n_img), dtype=self.dtype) + + # Allocate variables used for shift estimation + + # set maximum value of 1D shift (in pixels) to search + # between common-lines. + max_shift = self.max_shift + # Set resolution of shift estimation in pixels. Note that + # shift_step can be any positive real number. + shift_step = self.shift_step + # 1D shift between common-lines + shifts_1d = np.zeros((n_img, n_img)) + + # Prepare the shift phases to try and generate filter for common-line detection + r_max = pf.shape[2] + shifts, shift_phases, h = _generate_shift_phase_and_filter( + r_max, max_shift, shift_step, self.dtype + ) + + # Apply bandpass filter, normalize each ray of each image + # Note that only use half of each ray + pf = self._apply_filter_and_norm("ijk, k -> ijk", pf, r_max, h) + + # Setup a progress bar + _total_pairs_to_test = self.n_img * (self.n_check - 1) // 2 + pbar = tqdm(desc="Searching over common line pairs", total=_total_pairs_to_test) + + # Search for common lines between [i, j] pairs of images. + # Creating pf and building common lines are different to the Matlab version. + # The random selection is implemented. + for i in range(n_img - 1): + p1 = pf[i] + p1_real = np.real(p1) + p1_imag = np.imag(p1) + + # build the subset of j images if n_check < n_img + n_remaining = n_img - i - 1 + n_j = min(n_remaining, n_check) + subset_j = np.sort(choice(n_remaining, n_j, replace=False) + i + 1) + + for j in subset_j: + p2_flipped = np.conj(pf[j]) + + for shift in range(len(shifts)): + shift_phase = shift_phases[shift] + p2_shifted_flipped = (shift_phase * p2_flipped).T + # Compute correlations in the positive r direction + part1 = p1_real.dot(np.real(p2_shifted_flipped)) + # Compute correlations in the negative r direction + part2 = p1_imag.dot(np.imag(p2_shifted_flipped)) + + c1 = part1 - part2 + sidx = c1.argmax() + cl1, cl2 = np.unravel_index(sidx, c1.shape) + sval = c1[cl1, cl2] + + c2 = part1 + part2 + sidx = c2.argmax() + cl1_2, cl2_2 = np.unravel_index(sidx, c2.shape) + sval2 = c2[cl1_2, cl2_2] + + if sval2 > sval: + cl1 = cl1_2 + cl2 = cl2_2 + n_theta_half + sval = sval2 + sval = 2 * sval + if sval > cl_dist[i, j]: + clmatrix[i, j] = cl1 + clmatrix[j, i] = cl2 + cl_dist[i, j] = sval + shifts_1d[i, j] = shifts[shift] + pbar.update() + pbar.close() + + return shifts_1d, clmatrix + + def build_clmatrix_cu(self): + """ + Build common-lines matrix from Fourier stack of 2D images + """ + + import cupy as cp + + n_img = self.n_img + r = self.pf.shape[2] + + if self.n_theta % 2 == 1: + msg = "n_theta must be even" + logger.error(msg) + raise NotImplementedError(msg) + + # Copy to prevent modifying self.pf for other functions + # Simultaneously place on GPU + pf = cp.array(self.pf) + + # Allocate local variables for return + # clmatrix represents the common lines matrix. + # Namely, clmatrix[i,j] contains the index in image i of + # the common line with image j. Note the common line index + # starts from 0 instead of 1 as Matlab version. -1 means + # there is no common line such as clmatrix[i,i]. + clmatrix = -cp.ones((n_img, n_img), dtype=np.int16) + + # Allocate variables used for shift estimation + # + # Set maximum value of 1D shift (in pixels) to search + # between common-lines. + # Set resolution of shift estimation in pixels. Note that + # shift_step can be any positive real number. + # + # Prepare the shift phases to try and generate filter for common-line detection + # + # Note the CUDA implementation has been optimized to not + # compute or return diagnostic 1d shifts. + _, shift_phases, h = _generate_shift_phase_and_filter( + r, self.max_shift, self.shift_step, self.dtype + ) + # Transfer to device, dtypes must match kernel header. + shift_phases = cp.asarray(shift_phases, dtype=complex_type(self.dtype)) + + # Apply bandpass filter, normalize each ray of each image + # Note that this only uses half of each ray + pf = self._apply_filter_and_norm("ijk, k -> ijk", pf, r, h) + + # Tranpose `pf` for better (CUDA) memory access pattern, and cast as needed. + pf = cp.ascontiguousarray(pf.T, dtype=complex_type(self.dtype)) + + # Get kernel + if self.dtype == np.float64: + build_clmatrix_kernel = self.__gpu_module.get_function( + "build_clmatrix_kernel" + ) + elif self.dtype == np.float32: + build_clmatrix_kernel = self.__gpu_module.get_function( + "fbuild_clmatrix_kernel" + ) + else: + raise NotImplementedError( + "build_clmatrix_kernel only implemented for float32 and float64." + ) + + # Configure grid of blocks + blkszx = 32 + # Enough blocks to cover n_img-1 + nblkx = (self.n_img + blkszx - 2) // blkszx + blkszy = 32 + # Enough blocks to cover n_img + nblky = (self.n_img + blkszy - 1) // blkszy + + # Launch + logger.info("Launching `build_clmatrix_kernel`.") + build_clmatrix_kernel( + (nblkx, nblky), + (blkszx, blkszy), + ( + n_img, + pf.shape[1], + r, + pf, + clmatrix, + len(shift_phases), + shift_phases, + ), + ) + + # Copy result device arrays to host + clmatrix = clmatrix.get().astype(self.dtype, copy=False) + + # Note diagnostic 1d shifts are not computed in the CUDA implementation. + return None, clmatrix + + @staticmethod + def __init_cupy_module(): + """ + Private utility method to read in CUDA source and return as + compiled CuPy module. + """ + + import cupy as cp + + # Read in contents of file + fp = os.path.join(os.path.dirname(__file__), "commonline_matrix.cu") + with open(fp, "r") as fh: + module_code = fh.read() + + # CuPy compile the CUDA code + # Note these optimizations are to steer aggresive optimization + # for single precision code. Fast math will potentionally + # reduce accuracy in single precision. + return cp.RawModule( + code=module_code, + backend="nvcc", + options=("-O3", "--use_fast_math", "--extra-device-vectorization"), + ) diff --git a/src/aspire/abinitio/commonline_sdp.py b/src/aspire/abinitio/commonline_sdp.py index 140fe27733..d2baba27c5 100644 --- a/src/aspire/abinitio/commonline_sdp.py +++ b/src/aspire/abinitio/commonline_sdp.py @@ -4,14 +4,14 @@ import numpy as np from scipy.sparse import csr_array -from aspire.abinitio import CLOrient3D +from aspire.abinitio import CLMatrixOrient3D from aspire.utils import nearest_rotations from aspire.utils.matlab_compat import stable_eigsh logger = logging.getLogger(__name__) -class CommonlineSDP(CLOrient3D): +class CommonlineSDP(CLMatrixOrient3D): """ Class to estimate 3D orientations using semi-definite programming. @@ -27,9 +27,6 @@ def estimate_rotations(self): """ Estimate rotation matrices using the common lines method with semi-definite programming. """ - logger.info("Computing the common lines matrix.") - self.build_clmatrix() - S = self._construct_S(self.clmatrix) A, b = self._sdp_prep() gram = self._compute_gram_SDP(S, A, b) diff --git a/src/aspire/abinitio/commonline_sync.py b/src/aspire/abinitio/commonline_sync.py index 6ed774958b..0a8aa3397d 100644 --- a/src/aspire/abinitio/commonline_sync.py +++ b/src/aspire/abinitio/commonline_sync.py @@ -2,7 +2,7 @@ import numpy as np -from aspire.abinitio import CLOrient3D +from aspire.abinitio import CLMatrixOrient3D from aspire.abinitio.sync_voting import _rotratio_eulerangle_vec, _vote_ij from aspire.utils import nearest_rotations from aspire.utils.matlab_compat import stable_eigsh @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -class CLSyncVoting(CLOrient3D): +class CLSyncVoting(CLMatrixOrient3D): """ Define a class to estimate 3D orientations using synchronization matrix and voting method. @@ -34,6 +34,7 @@ def __init__( hist_bin_width=3, full_width=6, mask=True, + disable_gpu=False, **kwargs, ): """ @@ -52,6 +53,9 @@ def __init__( `hist_bin_width`s required to find at least one valid image index. :param mask: Option to mask `src.images` with a fuzzy mask (boolean). Default, `True`, applies a mask. + :param disable_gpu: Disables GPU acceleration; + forces CPU only code for this module. + Defaults to automatically using GPU when available. """ super().__init__( src, @@ -62,6 +66,7 @@ def __init__( hist_bin_width=hist_bin_width, full_width=full_width, mask=mask, + disable_gpu=disable_gpu, **kwargs, ) self.syncmatrix = None diff --git a/src/aspire/abinitio/commonline_sync3n.py b/src/aspire/abinitio/commonline_sync3n.py index 7be4b37c2c..f6c6a4690e 100644 --- a/src/aspire/abinitio/commonline_sync3n.py +++ b/src/aspire/abinitio/commonline_sync3n.py @@ -6,7 +6,7 @@ from numpy.linalg import norm from scipy.optimize import curve_fit -from aspire.abinitio import CLOrient3D +from aspire.abinitio import CLMatrixOrient3D from aspire.abinitio.sync_voting import _syncmatrix_ij_vote_3n from aspire.utils import J_conjugate, all_pairs, nearest_rotations, random, tqdm, trange from aspire.utils.matlab_compat import stable_eigsh @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class CLSync3N(CLOrient3D): +class CLSync3N(CLMatrixOrient3D): """ Define a class to estimate 3D orientations using common lines Sync3N methods (2017). @@ -102,6 +102,7 @@ def __init__( hist_bin_width=hist_bin_width, full_width=full_width, mask=mask, + disable_gpu=disable_gpu, **kwargs, ) @@ -154,10 +155,6 @@ def estimate_rotations(self): """ logger.info(f"Estimating relative viewing directions for {self.n_img} images.") - - # Detect a single pair of common-lines between each pair of images - self.build_clmatrix() - # Initial estimate of viewing directions # Calculate relative rotations Rijs0 = self._estimate_all_Rijs(self.clmatrix) diff --git a/tests/test_commonline_matrix.py b/tests/test_commonline_matrix.py new file mode 100644 index 0000000000..fd6a70e35f --- /dev/null +++ b/tests/test_commonline_matrix.py @@ -0,0 +1,89 @@ +import numpy as np +import pytest + +from aspire.abinitio import ( + CLMatrixOrient3D, + CLSymmetryC2, + CLSymmetryC3C4, + CLSync3N, + CLSyncVoting, + CommonlineIRLS, + CommonlineLUD, + CommonlineSDP, +) +from aspire.downloader import emdb_2660 +from aspire.source import Simulation + +SUBCLASSES = [ + CLSymmetryC2, + CLSymmetryC3C4, + CLSync3N, + CLSyncVoting, + CommonlineIRLS, + CommonlineLUD, + CommonlineSDP, +] + + +DTYPES = [ + np.float32, + pytest.param(np.float64, marks=pytest.mark.expensive), +] + + +@pytest.fixture(params=SUBCLASSES, ids=lambda x: f"subclass={x}", scope="module") +def subclass(request): + return request.param + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def src(dtype): + src = Simulation( + n=10, + vols=emdb_2660().astype(dtype).downsample(32), + offsets=0, + amplitudes=1, + seed=0, + ).cache() + + return src + + +def test_class_structure(subclass): + assert issubclass(subclass, CLMatrixOrient3D) + + +def test_clmatrix_lazy_eval(subclass, src, caplog): + """ + Test lazy evaluation of commonlines matrix and associated log message. + """ + cl_kwargs = dict(src=src) + if subclass == CLSymmetryC3C4: + cl_kwargs["symmetry"] = "C3" + + caplog.clear() + msg = "Using existing estimated `clmatrix`." + + # Initialize commonlines class + clmat_algo = subclass(**cl_kwargs) + + # clmatrix should be none at this point + assert clmat_algo._clmatrix is None + assert msg not in caplog.text + + # Request clmatrix + _ = clmat_algo.clmatrix + + # clmatrix should be populated + assert clmat_algo._clmatrix is not None + assert msg not in caplog.text + + # 2nd request should access cached matrix and log message + # that we are using the stored matrix + _ = clmat_algo.clmatrix + assert msg in caplog.text