From 8577e1dedab321c25ade6b1b7b3ff2d9f62a7d24 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 21 Oct 2025 09:06:24 -0400 Subject: [PATCH 01/91] Update FLE tests now that cufinufft improved --- tests/test_FLEbasis2D.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py index 94ea9c4316..1fdd75d9ff 100644 --- a/tests/test_FLEbasis2D.py +++ b/tests/test_FLEbasis2D.py @@ -21,10 +21,6 @@ def show_fle_params(basis): return f"{basis.nres}-{basis.epsilon}" -def gpu_ci_skip(): - pytest.skip("1e-7 precision for FLEBasis2D") - - fle_params = [ (32, 1e-4), (32, 1e-7), @@ -80,8 +76,6 @@ class TestFLEBasis2D(UniversalBasisMixin): # check closeness guarantees for fast vs dense matrix method def testFastVDense_T(self, basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() dense_b = basis._create_dense_matrix() @@ -97,8 +91,6 @@ def testFastVDense_T(self, basis): assert relerr(result_dense.T, result_fast) < (self.test_eps * basis.epsilon) def testFastVDense(self, basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() dense_b = basis._create_dense_matrix() @@ -120,8 +112,6 @@ def testFastVDense(self, basis): raises=RuntimeError, ) def testEvaluateExpand(self, basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() # compare result of evaluate() vs more accurate expand() # get sample coefficients @@ -135,8 +125,6 @@ def testEvaluateExpand(self, basis): @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) def testMatchFBEvaluate(basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() # ensure that the basis functions are identical when in match_fb mode fb_basis = FBBasis2D(basis.nres, dtype=np.float64) @@ -170,8 +158,6 @@ def testMatchFBDenseEvaluate(basis): @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) def testMatchFBEvaluate_t(basis): - if backend_available("cufinufft") and basis.epsilon == 1e-7: - gpu_ci_skip() # ensure that coefficients are the same when evaluating images fb_basis = FBBasis2D(basis.nres, dtype=np.float64) From 8df0fdecb28c7937264bb8ab18cd4d3c08a36673 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 21 Oct 2025 15:53:31 -0400 Subject: [PATCH 02/91] update test data files --- .../clean70SRibosome_cov2d_covar.npy | Bin 1431 -> 1009 bytes .../clean70SRibosome_cov2d_covarctf.npy | Bin 1709 -> 1185 bytes ...clean70SRibosome_cov2d_covarctf_shrink.npy | Bin 1903 -> 1513 bytes 3 files changed, 0 insertions(+), 0 deletions(-) diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covar.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covar.npy index e55122814c264f363639763f4ab77dcd71c46fc7..68a6cd0c7f48d89a188e74840295b511312913d4 100644 GIT binary patch literal 1009 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1Jdf4+)AW9}r@n&e9;>?&drF}}!6b)}iZ^70n8O%Mb zDJ7K!sUR&({uWa@J0P;mJ&ZL|{QUg9{sRG+@U}0RlGGVHMI(bLgB4^{20P3k7H^ie zDH%LHOlc?v*q|A}05ZV2K`%-r)3(smzCWPp3QMo-e!*q?kA&JN?GIa^@ic+z+c!|faUVCrGwCd!{brJt0tFXnD^P%>qOn*&o16ZgD0 z?_B%mFLqOa&cO5ovo|x)50ahE9>+grXIKJa?B{`AZ0`(g4hzWwzgr-b*0_F#Jy@Y&S$UOf5D3^RL~ZG~F`7fRYk-3nR1 zB8Fo>jIXhn?G4*9J^Sy5HrA>nU1Ck$5gk&RpA(pncNIU12Hokk>&Q_Xgnh7aU(b$5zkJ~X;-zY;(EpZIP5p_ zLeYt_@kw_mW=7nBv58SW5;V==WME<}$PIhIjD%v*I5#Ik+zv{n*&YpV_1M9tDKn54 z3WQ?eNGNWPMCHN``F@+CC^{7Qi^P32$VZ|cC`>>RgSV2fi^1DY&8C&kfc-yUm`{$2 z?~46-qQAQlN;E~YC1E#%cbr;=HU;l?b#-2=V%T9XY+f%Z$1?P7U;jgsVTdHzeVf2XppcJinJv4g?R}~VkFlFWJ 
z=&GxtKWXC|nmwq-i}7YlWqqO5v;(r;&_VZbiSA%heEh--9~ffThZr2@a752e=R-N2 z3I<2zbn>lCsZFTd|QhPH1CitmVWCCl@`pluW8+*LX`}6jKOgZCv@XzHl9Pqb26yT#rr4| zZ$!ky+q7O5qE%dMP!HiNz!yGzT&@n%!(Lp+XP0+j+ERAnlOkFz!WFkluRXL_Re}rU zn6`FX>o$B%bF@4_&x=(!Dbu+a)NpX?Ms+q_txWeZ1CLCXr{yK!lT5lgIlMdgV=b*+ zqYaPRiND}&T*CQE`WsU_pJ<7m?x$w~TC}NfN{Z?j)N^Ri4Nq2dT8bJOoXLrrG9tGS zoyS+&KD^G*P5?S(Q_}HlZNvQKI71?8?9Nx1!=g* zphFtOnSjoWp+p!O>G6=5_Y$s^sh~-~=L{}!=+cdjETCHgdKma}0JaTiAI9I0<8L@e zZ@ln@1oSfKho0XV;xRN&+r3xWZvbHwGdc0=oNO D2#+EL diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npy index c178c240bd11634d52086311be90a7d4765b1196..4658b5eb7aa5c073ca6a5ce5d173ae4d7e4675c0 100644 GIT binary patch literal 1185 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1Jdf4+)AW9}r@n&e9;>?&drF}}!6b)}iZ^70n8O%Mb zDJ7K!sUR&({uWa@J0P;mJ&ZL|{QUg9{sRG+@U}0RlGGVHMI(bLgB4^{20P3k7H^ie zDH%LHOlc;_2H2n(zyLD9*&!``la+|Kgp~oK@vep)Y3AZUw%E=$R)E07L5AY(-jGY_N zy*xp8kZ%T)Hxnd01e|oWEPXpzc7m{F-Z=~GekWf&`;0#pYo1-+0m8k9Hxe|TRQ2Hw w?<^PC>w>Y^w`zjs6Q`chn-LN|EHhWl!I7p&Qv=GO*uxfK4z<;!0`r0%01lO;!2kdN literal 1709 zcmbu6?Nb|76o)qp^VR4G)W zNGK}!g0|G6tzZ058H;|PB|78Gy?>bAleKq7{7`4OJ9GE$ne#lqv%3q)qXS1y1dP+h zL^PK#J6_a^L?;KM@kq2#^0=2Nrb}KfPvt|I;c{Np%Wh^gubyL_J@LKq$l1sSR=5&$ zien?A9LUvcf)Bf@05yV?u|J-48v~@7jyNVFaowQji8fu zO1ZpK4(E$>2%{ktFbpG$P-?Tf4`=vLu@jrB*v!y4gWDKxkJ$lx>omgu2y=XVRB?Bt zf~tQofje!(4$R;#hPz|7Z=1wD$z<~IUuxAiyN_c_cJ?H;#$3a#4|1=o?DsJ=kv(LG zs@UfG%0f4qedTtBmYd36-2ZQ-!SO(L)@?|o)EI?R;y0_wZm~K7`$SDt#EnAFcJ0_?9(2{G*uk)q zV^{X(?7M092*X1(dxPCj#lvn3jnbt?5!d!RR!&sp`GlBDb%>gOJ|kA-{gaD?2PaLLqw?9_4sUheRnP#;{k1v~DQmaSDkuv~jfSkS8dlgP~K0bon9O>P_U7 ztjRwuvARpvOnIp!7EN)ze!ebNV5TR-(D2e8&&SUrceO zug%|@d5(hCoFa|OFr+o_+(Pp%X@(g_IEq?RBF!iRNTaSQJYUnOrV$s# zRk38EOeAL*<2bF!_lSI-VO*1E{!2~}ImvLA;{#28NMwcKBTatnlOdH!3sZ!B!Z6KI z)z}PS=NM)+_USt8Gr~S+nB$n&*cXI-$?%oNzFvnd5cUnjw;bm+c7d=(hKm~eZXI@s RuqB4?Iey5he{(t8{x^h+QTG4< diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npy index ad957a08e842e5df374ca8febd60924a5d58f004..d6904791132c9731bc44d131a874efb63a92bcb2 100644 GIT binary patch literal 1513 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1Jdf4+)AW9}r@n&e9;>?&drF}}!6b)}iZ^70n8O%Mb zDJ7K!sUR&({uWa@J0P;mJ&ZL|{QUg9{sRG+@U}0RlGGVHMI(bLgB4^{20P3k7H^ie zDH%LHOlc?v*q|A}05qV(xdGypfg`5?oq_2GW^ZPoA0#`RJ*effQG3V>N$laqCmq5#xy`-t}zJTEz=Q6i-AV1z$yriZxAxLEE&ybcC0c2+n!L#R-SlT|x0HAjUBbpIALMis;gG;# znQrFv5yDYUdo^zX$BJqHADlIgv+@Mtq(F;Qo?>My;j~n?m7&yFd4|v~5Ru9bR(2B3 zN+p${JjcrOgf4*#QhAY;Q9`#=UMfR*nUz-vJpxyyGRDel1YIiQWhf0+CJ3fLuT)-V zrA6qI${S@UZ?ZB;ND0_dnPz2%09LwJ0ghEO;lABfznnO-j0!?dpkJc57`;u%OY}|& zbb!%8!jQmSi4HS5LbxZ<`z6o^j6Nic3KS$d#^^X Date: Wed, 22 Oct 2025 11:29:22 -0400 Subject: [PATCH 03/91] save data files with npz format. Refactor tests to handle npz dictionaries. 
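
A minimal sketch (illustrative file and key names, not from this patch) of the `.npz` convention the refactored tests rely on: saving the reference covariance blocks as named arrays in one archive removes the need for `allow_pickle=True` on load, and the loaded `NpzFile` can be iterated as a mapping with `results.values()`, exactly as the updated `test_covar2d.py` does below.

    import numpy as np

    # Write a sequence of covariance blocks as named arrays in one .npz archive
    # instead of a single pickled object array (.npy + allow_pickle=True).
    blocks = [np.eye(2), np.ones((3, 3))]  # stand-in data for this sketch
    np.savez("cov_blocks.npz", **{f"block_{k}": m for k, m in enumerate(blocks)})

    # NpzFile behaves like a dict of arrays; no pickling is involved on load.
    results = np.load("cov_blocks.npz")
    for im, mat in enumerate(results.values()):
        print(im, mat.shape)
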
--- .../clean70SRibosome_cov2d_covar.npy | Bin 1009 -> 0 bytes .../clean70SRibosome_cov2d_covar.npz | Bin 0 -> 4860 bytes .../clean70SRibosome_cov2d_covarctf.npy | Bin 1185 -> 0 bytes .../clean70SRibosome_cov2d_covarctf.npz | Bin 0 -> 4532 bytes ...clean70SRibosome_cov2d_covarctf_shrink.npy | Bin 1513 -> 0 bytes ...clean70SRibosome_cov2d_covarctf_shrink.npz | Bin 0 -> 4860 bytes tests/test_covar2d.py | 20 +++++++----------- 7 files changed, 8 insertions(+), 12 deletions(-) delete mode 100644 tests/saved_test_data/clean70SRibosome_cov2d_covar.npy create mode 100644 tests/saved_test_data/clean70SRibosome_cov2d_covar.npz delete mode 100644 tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npy create mode 100644 tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npz delete mode 100644 tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npy create mode 100644 tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npz diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covar.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covar.npy deleted file mode 100644 index 68a6cd0c7f48d89a188e74840295b511312913d4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1009 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1Jdf4+)AW9}r@n&e9;>?&drF}}!6b)}iZ^70n8O%Mb zDJ7K!sUR&({uWa@J0P;mJ&ZL|{QUg9{sRG+@U}0RlGGVHMI(bLgB4^{20P3k7H^ie zDH%LHOlc?v*q|A}05ZV2K`%-r)3(smzCWPp3QMo-e!*q?kA&JN?GIa^@ic+z+c!|faUVCrGwCd!{brJt0tFXnD^P%>qOn*&o16ZgD0 z?_B%mFLqOa&cO5ovo|x)50ahE9>+grXIKJa?B{`AZ0`(g4hzWwzgr-b*0_F#Jy@Y&S$UOf5D3^RL~ZG~F`7fRYk-3nR1 zB8Fo>jIXhn?G4*9J^Sy5HM0|ULhf=WgP0R{#}s7fe3*)P;LAd-=x zjG* zW`>!)%(lWUfeR(=qi%()UlGHxAI8^M%=U(DnV$W3!yC#x(!fL5r2abErzjs&TM#ij{kOhG-k}14_42)Z1WbG@J!cUuHt#iS9g~EFuKKiN~l#n zdR2wpa~8wHb8h~%#h5W?Iaoaxr=$GSW3IZr@~jdir!_zJ!RWget88j(KB6@;uzJoA z-a>;FBrtlQ3l!2kcfasq3r5Tt!dp-s(9#k{Q{*{A9R))QOG}G~ap&)?6-O%-v3mmE z0)janMh{L;Jh^kWA2TN4Ef|>NVe~+I!Yyr`>XQDSXn6&@C*Um*nB!se;Pixa;e{m3 zn1Ht+V2+2;1MP{)!X;B3jyyn%3GAMLw-jKGhtY$_6O2r{%&6UCa32`h|K)}jMj)yI z+#g2jEC+a_Y6Nu+PXHByFsepdpjL=Ra4|=`MkgkiMsU%Bt`XGw#%`lKGfX2mI}@+b zhXtk)oQ2Ref?9(Mfc^tvR6ho?!Zd=DE%6%l*kBsLNd;Xas1<|VMs9YPMsS0Tc#SXM z8o`Yrbd8|e7rTu|Ibb$|8yduGT*L|02&u%;HG*oB4q%9aFlwmQbHOxXO*yzVx&X}q zCpkl`DF;+)Vz-f#2c{8g%E7Ji99$#TlmjXqu-iD97iJ^Yl!IGi1RqQzH07`YlM(}i OAP}Aerhi#}5Dx$w3%m#b literal 0 HcmV?d00001 diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npy deleted file mode 100644 index 4658b5eb7aa5c073ca6a5ce5d173ae4d7e4675c0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1185 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1Jdf4+)AW9}r@n&e9;>?&drF}}!6b)}iZ^70n8O%Mb zDJ7K!sUR&({uWa@J0P;mJ&ZL|{QUg9{sRG+@U}0RlGGVHMI(bLgB4^{20P3k7H^ie zDH%LHOlc;_2H2n(zyLD9*&!``la+|Kgp~oK@vep)Y3AZUw%E=$R)E07L5AY(-jGY_N zy*xp8kZ%T)Hxnd01e|oWEPXpzc7m{F-Z=~GekWf&`;0#pYo1-+0m8k9Hxe|TRQ2Hw w?<^PC>w>Y^w`zjs6Q`chn-LN|EHhWl!I7p&Qv=GO*uxfK4z<;!0`r0%01lO;!2kdN diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npz b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf.npz new file mode 100644 index 0000000000000000000000000000000000000000..15371a69d814852c415625f1e21790874943a3cc GIT binary patch literal 4532 zcmWIWW@gc4fB;2?c}bV0{zCybg9t-nQBk~sfq`CLK_w%D0K);OLKr>SFVr_6l98c| zp;|p9wK%y*-AX~-Ce1`$M?pO;zo?`rF)u#9C?ypn?v|KSoC*{#&PXgs1@bjabQDZ9 zbrfn9$Oc?#>6@%Xyd|s*7>#!|>_{^g2eQR>zOj1!TW8k|6{}rFK=BVicD%ul0ogWvYb*m$JYOw1*9+70hA_`BfqEWBQ{;JL9R*_wJ-^;8*(#;snniDf z!7dP8wc)-6i-68f5T6xDhaHfyBF2M8!^eYr9Ca%&W6*f8dhqbC|5n@Nep%kmpJBNj 
zNPpi@Xc=F#%JL47p8=#T-pixKAa)O$3?C2LHYGjAj6qYF2PZ)DAdIFc4;twx7*Uu9 zopiM>~+Cd>{~Tj%wYA5A-n|w%lI&QpbG%}p5ghs><327 z7{Xf=Ezr^oMpNV&LmdS}3QMz@tLC6JiLiSB-eQ0`97Yd54|E;-kcSxq@D>5g;V^oj zJpf7&sC5W-55Q}Dn8RW8Q1Jk~Mu#~ZMh_Jaz-w`s!(sH`@Bkx|E;DMk?f@`pfiO1^ z!$1Qgh=tV23-CtO2Y(8o|XG@fz;|?S$9}E=JHbf?BxPZQRcU(+JMK z#A{r@4AThCHs~5btu#<|0>Y?%Y-E9H1SePGHTto_G=h@}x<*jz1G|mVY%q=Bh86J| zU%)kj8!_k_LA5G&8@IBwLh}YP}0o4epoY6IcYK|6Qh=MR`sQPolG-6FTxHSrJ z!8Bq`IiQjfyNws&8nLDv+!`kVT@FsGhFDV$ZjHe_FpXGK4sMMiyfBT>l*0;)SOx|` OAUp|7##e#T3=9Cg+YS5x literal 0 HcmV?d00001 diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npy b/tests/saved_test_data/clean70SRibosome_cov2d_covarctf_shrink.npy deleted file mode 100644 index d6904791132c9731bc44d131a874efb63a92bcb2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1513 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1Jdf4+)AW9}r@n&e9;>?&drF}}!6b)}iZ^70n8O%Mb zDJ7K!sUR&({uWa@J0P;mJ&ZL|{QUg9{sRG+@U}0RlGGVHMI(bLgB4^{20P3k7H^ie zDH%LHOlc?v*q|A}05qV(xdGypfg`5?oq_2GW^ZPoA0#`RJ*effQG3V>N$laqM0|ULhf=WgP0R{#}s7fe3*)P;LAd-=x zjG%1J4i^cqgFFgwd1)p0SRC@n8u&VgkWvG!P67j7I~( zz`$f=1;W#lzvVH@1XEZbEP$2?Fq)Dw!AM8Jh}LBSYF@|gIkVyHIrHJ{Ig8=!Im=<} zIYW304OWoA=%H1haBB^ic8yyafVtJd7Sno`AO?V2+2;L&+2HmIBQ2FnaKKf{{s= z8MS*1?gIn+zueHm2t+l2`@=|`#t`uuk8;3l1UEE@*SLrisu5C&qiY1!CZKEq!lMyx3Zx5jgDjaX9-ZjF<9VK!n-Ik+`O@WC`fQw}RIDKRhz0^vzu J`j_Pg@c_C7cZ~o5 literal 0 HcmV?d00001 diff --git a/tests/test_covar2d.py b/tests/test_covar2d.py index b1bf41e231..66e88c4ff8 100644 --- a/tests/test_covar2d.py +++ b/tests/test_covar2d.py @@ -128,15 +128,14 @@ def test_get_mean(cov2d_fixture): def test_get_covar(cov2d_fixture): results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"), - allow_pickle=True, + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npz"), ) cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] covar_coef = cov2d._get_covar(coef_clean.asnumpy()) - for im, mat in enumerate(results.tolist()): + for im, mat in enumerate(results.values()): np.testing.assert_allclose(mat, covar_coef[im], rtol=1e-05) @@ -210,13 +209,12 @@ def test_shrinkage(cov2d_fixture, shrinker): cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"), - allow_pickle=True, + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npz"), ) covar_coef = cov2d.get_covar(coef_clean, covar_est_opt={"shrinker": shrinker}) - for im, mat in enumerate(results.tolist()): + for im, mat in enumerate(results.values()): np.testing.assert_allclose( mat, covar_coef[im], atol=utest_tolerance(cov2d.dtype) ) @@ -288,12 +286,11 @@ def test_get_covar_ctf(cov2d_fixture, ctf_enabled): sim, cov2d, _, coef, h_ctf_fb, h_idx = cov2d_fixture results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf.npy"), - allow_pickle=True, + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf.npz"), ) covar_coef_ctf = cov2d.get_covar(coef, h_ctf_fb, h_idx, noise_var=NOISE_VAR) - for im, mat in enumerate(results.tolist()): + for im, mat in enumerate(results.values()): # These tolerances were adjusted slightly (1e-8 to 3e-8) to accomodate MATLAB CTF repro changes np.testing.assert_allclose(mat, covar_coef_ctf[im], rtol=3e-05, atol=3e-08) @@ -306,8 +303,7 @@ def test_get_covar_ctf_shrink(cov2d_fixture, ctf_enabled): pytest.skip(reason="Reference file n/a.") results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf_shrink.npy"), - allow_pickle=True, + os.path.join(DATA_DIR, 
"clean70SRibosome_cov2d_covarctf_shrink.npz"), ) covar_opt = { @@ -328,5 +324,5 @@ def test_get_covar_ctf_shrink(cov2d_fixture, ctf_enabled): covar_est_opt=covar_opt, ) - for im, mat in enumerate(results.tolist()): + for im, mat in enumerate(results.values()): np.testing.assert_allclose(mat, covar_coef_ctf_shrink[im]) From 5cc94eef096ff5b125580a1fe814a12b95576296 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 23 Oct 2025 11:38:56 -0400 Subject: [PATCH 04/91] Adjust adjoint test rtol. cleanup test fixtures --- tests/test_mean_estimator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_mean_estimator.py b/tests/test_mean_estimator.py index c7083d4ef1..b56a06f4ce 100644 --- a/tests/test_mean_estimator.py +++ b/tests/test_mean_estimator.py @@ -109,7 +109,7 @@ def test_estimate(sim, estimator, mask): np.testing.assert_array_equal(sim.pixel_size, estimate.pixel_size) -def test_adjoint(sim, basis, estimator): +def test_adjoint(sim): """ Test = for random volume `v` and random images `u`. @@ -128,7 +128,10 @@ def test_adjoint(sim, basis, estimator): lhs = np.dot(proj.asnumpy().flatten(), u.flatten()) rhs = np.dot(backproj.asnumpy().flatten(), v.flatten()) - np.testing.assert_allclose(lhs, rhs, rtol=1e-6) + rtol = 1e-07 # default rtol for assert_allclose + if sim.dtype == np.float32: + rtol = 1e-05 + np.testing.assert_allclose(lhs, rhs, rtol=rtol) def test_src_adjoint(sim, basis, estimator): From 1d67fbb0bde583c9e8752b9d4732b5cb3631bba4 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 22 Aug 2025 11:01:04 -0400 Subject: [PATCH 05/91] move estimate_third_rows to util file --- src/aspire/abinitio/__init__.py | 1 + src/aspire/abinitio/commonline_c2.py | 4 +- src/aspire/abinitio/commonline_c3_c4.py | 53 +----------------------- src/aspire/abinitio/commonline_utils.py | 55 +++++++++++++++++++++++++ src/aspire/utils/__init__.py | 1 + tests/test_orient_symmetric.py | 12 +++--- 6 files changed, 68 insertions(+), 58 deletions(-) create mode 100644 src/aspire/abinitio/commonline_utils.py diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index fe91b322ad..a7acb5fc6a 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -1,6 +1,7 @@ from .commonline_base import CLOrient3D # isort: off +from .commonline_utils import estimate_third_rows from .commonline_sdp import CommonlineSDP from .commonline_lud import CommonlineLUD from .commonline_irls import CommonlineIRLS diff --git a/src/aspire/abinitio/commonline_c2.py b/src/aspire/abinitio/commonline_c2.py index 6bf6cf99ed..538c71cc56 100644 --- a/src/aspire/abinitio/commonline_c2.py +++ b/src/aspire/abinitio/commonline_c2.py @@ -3,7 +3,7 @@ import numpy as np from scipy.linalg import eigh -from aspire.abinitio import CLSymmetryC3C4 +from aspire.abinitio import CLSymmetryC3C4, estimate_third_rows from aspire.utils import J_conjugate, Rotation, all_pairs logger = logging.getLogger(__name__) @@ -224,7 +224,7 @@ def estimate_rotations(self): viis = np.vstack((np.eye(3, dtype=self.dtype),) * self.n_img).reshape( self.n_img, 3, 3 ) - vis = self._estimate_third_rows(vijs, viis) + vis = estimate_third_rows(vijs, viis) logger.info("Estimating in-plane rotations and rotations matrices.") Ris = self._estimate_inplane_rotations(vis, Rijs, Rijgs) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 0170d8b88f..13fd9dc4d3 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py 
@@ -3,7 +3,7 @@ import numpy as np from numpy.linalg import eigh, norm, svd -from aspire.abinitio import CLOrient3D, SyncVotingMixin +from aspire.abinitio import CLOrient3D, SyncVotingMixin, estimate_third_rows from aspire.operators import PolarFT from aspire.utils import ( J_conjugate, @@ -110,7 +110,7 @@ def estimate_rotations(self): vijs, viis = self._global_J_sync(vijs, viis) logger.info("Estimating third rows of rotation matrices.") - vis = self._estimate_third_rows(vijs, viis) + vis = estimate_third_rows(vijs, viis) logger.info("Estimating in-plane rotations and rotations matrices.") Ris = self._estimate_inplane_rotations(vis) @@ -209,55 +209,6 @@ def _global_J_sync(self, vijs, viis): viis[i] = vii_J return vijs, viis - def _estimate_third_rows(self, vijs, viis): - """ - Find the third row of each rotation matrix given a collection of matrices - representing the outer products of the third rows from each rotation matrix. - - :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds the third rows - outer product of the rotation matrices Ri and Rj. - - :param viis: An n_imgx3x3 array where the i'th 3x3 slice holds the outer product of - the third row of Ri with itself. - - :param order: The underlying molecular symmetry. - - :return: vis, An n_imgx3 matrix whose i'th row is the third row of the rotation matrix Ri. - """ - - n_img = self.n_img - - # Build matrix V whose (i,j)-th block of size 3x3 holds the outer product vij - V = np.zeros((n_img, n_img, 3, 3), dtype=vijs.dtype) - - # All pairs (i,j) where i Date: Mon, 25 Aug 2025 09:13:43 -0400 Subject: [PATCH 06/91] migrate estimate_inplane_rotations, complete_third_row_to_rot to commonline utils. --- src/aspire/abinitio/__init__.py | 7 +- src/aspire/abinitio/commonline_c2.py | 8 +- src/aspire/abinitio/commonline_c3_c4.py | 212 +----------------------- src/aspire/abinitio/commonline_cn.py | 8 +- src/aspire/abinitio/commonline_utils.py | 204 ++++++++++++++++++++++- tests/test_orient_symmetric.py | 12 +- 6 files changed, 229 insertions(+), 222 deletions(-) diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index a7acb5fc6a..3685479732 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -1,7 +1,12 @@ from .commonline_base import CLOrient3D # isort: off -from .commonline_utils import estimate_third_rows +from .commonline_utils import ( + cl_angles_to_ind, + estimate_third_rows, + complete_third_row_to_rot, + estimate_inplane_rotations, +) from .commonline_sdp import CommonlineSDP from .commonline_lud import CommonlineLUD from .commonline_irls import CommonlineIRLS diff --git a/src/aspire/abinitio/commonline_c2.py b/src/aspire/abinitio/commonline_c2.py index 538c71cc56..52b63cc9e8 100644 --- a/src/aspire/abinitio/commonline_c2.py +++ b/src/aspire/abinitio/commonline_c2.py @@ -3,7 +3,11 @@ import numpy as np from scipy.linalg import eigh -from aspire.abinitio import CLSymmetryC3C4, estimate_third_rows +from aspire.abinitio import ( + CLSymmetryC3C4, + complete_third_row_to_rot, + estimate_third_rows, +) from aspire.utils import J_conjugate, Rotation, all_pairs logger = logging.getLogger(__name__) @@ -301,7 +305,7 @@ def _estimate_inplane_rotations(self, vis, Rijs, Rijgs): H = np.zeros((self.n_img, self.n_img), dtype=complex) # Step 1: Construct all rotation matrices Ris_tilde whose third rows are equal to # the corresponding third rows vis. 
- Ris_tilde = self._complete_third_row_to_rot(vis) + Ris_tilde = complete_third_row_to_rot(vis) pairs = all_pairs(self.n_img) for idx, (i, j) in enumerate(pairs): diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 13fd9dc4d3..bd9522e865 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -3,7 +3,12 @@ import numpy as np from numpy.linalg import eigh, norm, svd -from aspire.abinitio import CLOrient3D, SyncVotingMixin, estimate_third_rows +from aspire.abinitio import ( + CLOrient3D, + SyncVotingMixin, + estimate_inplane_rotations, + estimate_third_rows, +) from aspire.operators import PolarFT from aspire.utils import ( J_conjugate, @@ -12,7 +17,6 @@ all_triplets, anorm, cyclic_rotations, - tqdm, trange, ) from aspire.utils.random import randn @@ -113,7 +117,7 @@ def estimate_rotations(self): vis = estimate_third_rows(vijs, viis) logger.info("Estimating in-plane rotations and rotations matrices.") - Ris = self._estimate_inplane_rotations(vis) + Ris = estimate_inplane_rotations(self, vis) self.rotations = Ris @@ -209,140 +213,6 @@ def _global_J_sync(self, vijs, viis): viis[i] = vii_J return vijs, viis - def _estimate_inplane_rotations(self, vis): - """ - Estimate the rotation matrices for each image by constructing arbitrary rotation matrices - populated with the given third rows, vis, and then rotating by an appropriate in-plane rotation. - - :param vis: An n_imgx3 array where the i'th row holds the estimate for the third row of - the i'th rotation matrix. - - :return: Rotation matrices Ris and in-plane rotation matrices R_thetas, both size n_imgx3x3. - """ - pf = self.pf - n_img = self.n_img - n_theta = self.n_theta - max_shift_1d = self.max_shift - shift_step = self.shift_step - order = self.order - degree_res = self.degree_res - - # Step 1: Construct all rotation matrices Ri_tildes whose third rows are equal to - # the corresponding third rows vis. - Ri_tildes = self._complete_third_row_to_rot(vis) - - # Step 2: Construct all in-plane rotation matrices, R_theta_ijs. - max_angle = (360 // order) * order - theta_ijs = np.arange(0, max_angle, degree_res) * np.pi / 180 - R_theta_ijs = Rotation.about_axis("z", theta_ijs, dtype=self.dtype).matrices - - # Step 3: Compute the correlation over all shifts. - # Generate shifts. - r_max = pf.shape[-1] - shifts, shift_phases, _ = self._generate_shift_phase_and_filter( - r_max, max_shift_1d, shift_step - ) - n_shifts = len(shifts) - - # Q is the n_img x n_img Hermitian matrix defined by Q = q*q^H, - # where q = (exp(i*order*theta_0), ..., exp(i*order*theta_{n_img-1}))^H, - # and theta_i in [0, 2pi/order) is the in-plane rotation angle for the i'th image. - Q = np.zeros((n_img, n_img), dtype=complex) - - # Reconstruct the full polar Fourier for use in correlation. self.pf only consists of - # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). - pf = PolarFT.half_to_full(pf) - - # Normalize rays. - pf /= norm(pf, axis=-1)[..., np.newaxis] - - n_pairs = n_img * (n_img - 1) // 2 - with tqdm(total=n_pairs) as pbar: - idx = 0 - # Note: the ordering of i and j in these loops should not be changed as - # they correspond to the ordered tuples (i, j), for i= 1e-5 - - # If the third row coincides with the z-axis we return the identity matrix. - rots[~mask] = np.eye(3, dtype=r3.dtype) - - # 'norm_12' is non-zero since r3 does not coincide with the z-axis. 
- norm_12 = np.sqrt(r3[mask, 0] ** 2 + r3[mask, 1] ** 2) - - # Populate 1st rows with vector orthogonal to row 3. - rots[mask, 0, 0] = r3[mask, 1] / norm_12 - rots[mask, 0, 1] = -r3[mask, 0] / norm_12 - - # Populate 2nd rows such that r3 = r1 x r2 - rots[mask, 1, 0] = r3[mask, 0] * r3[mask, 2] / norm_12 - rots[mask, 1, 1] = r3[mask, 1] * r3[mask, 2] / norm_12 - rots[mask, 1, 2] = -norm_12 - - if singleton: - rots = rots.reshape(3, 3) - - return rots - - @staticmethod - def cl_angles_to_ind(cl_angles, n_theta): - thetas = np.arctan2(cl_angles[:, 1], cl_angles[:, 0]) - - # Shift from [-pi,pi] to [0,2*pi). - thetas = np.mod(thetas, 2 * np.pi) - - # linear scale from [0,2*pi) to [0,n_theta). - ind = np.mod(np.round(thetas / (2 * np.pi) * n_theta), n_theta).astype(int) - - # Return scalar for single value. - if ind.size == 1: - ind = ind.flat[0] - - return ind - @staticmethod def g_sync(rots, order, rots_gt): """ diff --git a/src/aspire/abinitio/commonline_cn.py b/src/aspire/abinitio/commonline_cn.py index a1440cadc3..e0d69cb086 100644 --- a/src/aspire/abinitio/commonline_cn.py +++ b/src/aspire/abinitio/commonline_cn.py @@ -3,7 +3,7 @@ import numpy as np from numpy.linalg import norm -from aspire.abinitio import CLSymmetryC3C4 +from aspire.abinitio import CLSymmetryC3C4, cl_angles_to_ind, complete_third_row_to_rot from aspire.operators import PolarFT from aspire.utils import ( J_conjugate, @@ -298,8 +298,8 @@ def relative_rots_to_cl_indices(relative_rots, n_theta): c1s = np.array((-relative_rots[:, 1, 2], relative_rots[:, 0, 2])).T c2s = np.array((relative_rots[:, 2, 1], -relative_rots[:, 2, 0])).T - c1s = CLSymmetryC3C4.cl_angles_to_ind(c1s, n_theta) - c2s = CLSymmetryC3C4.cl_angles_to_ind(c2s, n_theta) + c1s = cl_angles_to_ind(c1s, n_theta) + c2s = cl_angles_to_ind(c2s, n_theta) inds = np.where(c1s >= n_theta // 2) c1s[inds] -= n_theta // 2 @@ -331,7 +331,7 @@ def generate_candidate_rots(n, equator_threshold, order, degree_res, seed): while counter < n: third_row = randn(3) third_row /= anorm(third_row, axes=(-1,)) - Ri_tilde = CLSymmetryC3C4._complete_third_row_to_rot(third_row) + Ri_tilde = complete_third_row_to_rot(third_row) # Exclude candidates that represent equator images. Equator candidates # induce collinear self-common-lines, which always have perfect correlation. diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py index 4a0e88bd5b..bfd1e5ce5f 100644 --- a/src/aspire/abinitio/commonline_utils.py +++ b/src/aspire/abinitio/commonline_utils.py @@ -1,8 +1,12 @@ +import logging + import numpy as np -from numpy.linalg import eigh +from numpy.linalg import eigh, norm + +from aspire.operators import PolarFT +from aspire.utils import Rotation, all_pairs, anorm, tqdm -from aspire.utils.matrix import anorm -from aspire.utils.misc import all_pairs +logger = logging.getLogger(__name__) def estimate_third_rows(vijs, viis): @@ -53,3 +57,197 @@ def estimate_third_rows(vijs, viis): vis /= anorm(vis, axes=(-1,))[:, np.newaxis] return vis + + +def estimate_inplane_rotations(cl_class, vis): + """ + Estimate the rotation matrices for each image by constructing arbitrary rotation matrices + populated with the given third rows, vis, and then rotating by an appropriate in-plane rotation. + + :cl_class: A commonlines class instance. + :param vis: An n_imgx3 array where the i'th row holds the estimate for the third row of + the i'th rotation matrix. + + :return: Rotation matrices Ris and in-plane rotation matrices R_thetas, both size n_imgx3x3. 
+ """ + pf = cl_class.pf + n_img = cl_class.n_img + n_theta = cl_class.n_theta + max_shift_1d = cl_class.max_shift + shift_step = cl_class.shift_step + order = cl_class.order + degree_res = cl_class.degree_res + + # Step 1: Construct all rotation matrices Ri_tildes whose third rows are equal to + # the corresponding third rows vis. + Ri_tildes = complete_third_row_to_rot(vis) + + # Step 2: Construct all in-plane rotation matrices, R_theta_ijs. + max_angle = (360 // order) * order + theta_ijs = np.arange(0, max_angle, degree_res) * np.pi / 180 + R_theta_ijs = Rotation.about_axis("z", theta_ijs, dtype=cl_class.dtype).matrices + + # Step 3: Compute the correlation over all shifts. + # Generate shifts. + r_max = pf.shape[-1] + shifts, shift_phases, _ = cl_class._generate_shift_phase_and_filter( + r_max, max_shift_1d, shift_step + ) + n_shifts = len(shifts) + + # Q is the n_img x n_img Hermitian matrix defined by Q = q*q^H, + # where q = (exp(i*order*theta_0), ..., exp(i*order*theta_{n_img-1}))^H, + # and theta_i in [0, 2pi/order) is the in-plane rotation angle for the i'th image. + Q = np.zeros((n_img, n_img), dtype=complex) + + # Reconstruct the full polar Fourier for use in correlation. cl_class.pf only consists of + # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). + pf = PolarFT.half_to_full(pf) + + # Normalize rays. + pf /= norm(pf, axis=-1)[..., np.newaxis] + + n_pairs = n_img * (n_img - 1) // 2 + with tqdm(total=n_pairs) as pbar: + idx = 0 + # Note: the ordering of i and j in these loops should not be changed as + # they correspond to the ordered tuples (i, j), for i= 1e-5 + + # If the third row coincides with the z-axis we return the identity matrix. + rots[~mask] = np.eye(3, dtype=r3.dtype) + + # 'norm_12' is non-zero since r3 does not coincide with the z-axis. + norm_12 = np.sqrt(r3[mask, 0] ** 2 + r3[mask, 1] ** 2) + + # Populate 1st rows with vector orthogonal to row 3. + rots[mask, 0, 0] = r3[mask, 1] / norm_12 + rots[mask, 0, 1] = -r3[mask, 0] / norm_12 + + # Populate 2nd rows such that r3 = r1 x r2 + rots[mask, 1, 0] = r3[mask, 0] * r3[mask, 2] / norm_12 + rots[mask, 1, 1] = r3[mask, 1] * r3[mask, 2] / norm_12 + rots[mask, 1, 2] = -norm_12 + + if singleton: + rots = rots.reshape(3, 3) + + return rots + + +def cl_angles_to_ind(cl_angles, n_theta): + thetas = np.arctan2(cl_angles[:, 1], cl_angles[:, 0]) + + # Shift from [-pi,pi] to [0,2*pi). + thetas = np.mod(thetas, 2 * np.pi) + + # linear scale from [0,2*pi) to [0,n_theta). + ind = np.mod(np.round(thetas / (2 * np.pi) * n_theta), n_theta).astype(int) + + # Return scalar for single value. + if ind.size == 1: + ind = ind.flat[0] + + return ind diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index aa44fcfdfb..d7c4c4716f 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -7,6 +7,8 @@ CLSymmetryC2, CLSymmetryC3C4, CLSymmetryCn, + cl_angles_to_ind, + complete_third_row_to_rot, estimate_third_rows, ) from aspire.abinitio.commonline_cn import MeanOuterProductEstimator @@ -509,7 +511,7 @@ def test_complete_third_row(dtype): r3[0] = np.array([0, 0, 1], dtype=dtype) # Generate rotations. - R = CLSymmetryC3C4._complete_third_row_to_rot(r3) + R = complete_third_row_to_rot(r3) # Assert that first rotation is the identity matrix. 
assert np.allclose(R[0], np.eye(3, dtype=dtype)) @@ -640,10 +642,6 @@ def _gt_cl_c2(n_theta, rots_gt): U = Ri.T @ g @ Rj c1 = np.array([-U[1, 2], U[0, 2]]) c2 = np.array([U[2, 1], -U[2, 0]]) - clmatrix_gt[idx, i, j] = CLSymmetryC3C4.cl_angles_to_ind( - c1[np.newaxis, :], n_theta - ) - clmatrix_gt[idx, j, i] = CLSymmetryC3C4.cl_angles_to_ind( - c2[np.newaxis, :], n_theta - ) + clmatrix_gt[idx, i, j] = cl_angles_to_ind(c1[np.newaxis, :], n_theta) + clmatrix_gt[idx, j, i] = cl_angles_to_ind(c2[np.newaxis, :], n_theta) return clmatrix_gt From 5562e23c14c4f9d5d2d529344c21f4bb1aea38a0 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 25 Aug 2025 11:29:18 -0400 Subject: [PATCH 07/91] migrate syncmatrix_ij_vote_3n to SyncVotingMixin. Remove C3C4 implimentation. --- src/aspire/abinitio/commonline_c3_c4.py | 29 --------------------- src/aspire/abinitio/commonline_sync3n.py | 31 ----------------------- src/aspire/abinitio/sync_voting.py | 32 ++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 60 deletions(-) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index bd9522e865..af74dc6f6c 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -369,35 +369,6 @@ def _estimate_all_Rijs_c3_c4(self, clmatrix): return Rijs - def _syncmatrix_ij_vote_3n(self, clmatrix, i, j, k_list, n_theta): - """ - Compute the (i,j) rotation block of the synchronization matrix using voting method - - Given the common lines matrix `clmatrix`, a list of images specified in k_list - and the number of common lines n_theta, find the (i, j) rotation block Rij. - - :param clmatrix: The common lines matrix - :param i: The i image - :param j: The j image - :param k_list: The list of images for the third image for voting algorithm - :param n_theta: The number of points in the theta direction (common lines) - :return: The (i,j) rotation block of the synchronization matrix - """ - _, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list) - - rots = self._rotratio_eulerangle_vec(clmatrix, i, j, good_k, n_theta) - - if rots is not None: - rot_mean = np.mean(rots, 0) - - else: - # This is for the case that images i and j correspond to the same - # viewing direction and differ only by in-plane rotation. - # We set to zero as in the Matlab code. - rot_mean = np.zeros((3, 3)) - - return rot_mean - def _local_J_sync_c3_c4(self, Rijs, Riis): """ Estimate viis and vijs. In order to estimate vij = vi @ vj.T, it is necessary for Rii, Rjj, diff --git a/src/aspire/abinitio/commonline_sync3n.py b/src/aspire/abinitio/commonline_sync3n.py index 841f9dfceb..39f57ab978 100644 --- a/src/aspire/abinitio/commonline_sync3n.py +++ b/src/aspire/abinitio/commonline_sync3n.py @@ -969,37 +969,6 @@ def _estimate_all_Rijs_host(self, clmatrix): return Rijs - def _syncmatrix_ij_vote_3n(self, clmatrix, i, j, k_list, n_theta): - """ - Compute the (i,j) rotation block of the synchronization matrix using voting method - - Given the common lines matrix `clmatrix`, a list of images specified in k_list - and the number of common lines n_theta, find the (i, j) rotation block Rij. 
- - :param clmatrix: The common lines matrix - :param i: The i image - :param j: The j image - :param k_list: The list of images for the third image for voting algorithm - :param n_theta: The number of points in the theta direction (common lines) - :return: The (i,j) rotation block of the synchronization matrix - """ - alphas, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list, sync=True) - - angles = np.zeros(3) - - if alphas is not None: - angles[0] = clmatrix[i, j] * 2 * np.pi / n_theta + np.pi / 2 - angles[1] = np.mean(alphas) - angles[2] = -np.pi / 2 - clmatrix[j, i] * 2 * np.pi / n_theta - rot = Rotation.from_euler(angles).matrices - - else: - # This is for the case that images i and j correspond to the same - # viewing direction and differ only by in-plane rotation. - # We set to zero as in the Matlab code. - rot = np.zeros((3, 3)) - - return rot ####################################### # Secondary Methods for Global J Sync # diff --git a/src/aspire/abinitio/sync_voting.py b/src/aspire/abinitio/sync_voting.py index 651ba35c5e..b866f1e8ff 100644 --- a/src/aspire/abinitio/sync_voting.py +++ b/src/aspire/abinitio/sync_voting.py @@ -13,6 +13,38 @@ class SyncVotingMixin(object): which are shared by CLSynVoting and CLSymmetryC3C4 """ + def _syncmatrix_ij_vote_3n(self, clmatrix, i, j, k_list, n_theta): + """ + Compute the (i,j) rotation block of the synchronization matrix using voting method + + Given the common lines matrix `clmatrix`, a list of images specified in k_list + and the number of common lines n_theta, find the (i, j) rotation block Rij. + + :param clmatrix: The common lines matrix + :param i: The i image + :param j: The j image + :param k_list: The list of images for the third image for voting algorithm + :param n_theta: The number of points in the theta direction (common lines) + :return: The (i,j) rotation block of the synchronization matrix + """ + alphas, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list, sync=True) + + angles = np.zeros(3) + + if alphas is not None: + angles[0] = clmatrix[i, j] * 2 * np.pi / n_theta + np.pi / 2 + angles[1] = np.mean(alphas) + angles[2] = -np.pi / 2 - clmatrix[j, i] * 2 * np.pi / n_theta + rot = Rotation.from_euler(angles).matrices + + else: + # This is for the case that images i and j correspond to the same + # viewing direction and differ only by in-plane rotation. + # We set to zero as in the Matlab code. 
+ rot = np.zeros((3, 3)) + + return rot + def _rotratio_eulerangle_vec(self, clmatrix, i, j, good_k, n_theta): """ Compute the rotation that takes image i to image j From 34f82dab86c029a148952c5193bf404e0d2fe4df Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 26 Aug 2025 15:01:35 -0400 Subject: [PATCH 08/91] remove unused import --- src/aspire/abinitio/commonline_sync3n.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/aspire/abinitio/commonline_sync3n.py b/src/aspire/abinitio/commonline_sync3n.py index 39f57ab978..25bfd63015 100644 --- a/src/aspire/abinitio/commonline_sync3n.py +++ b/src/aspire/abinitio/commonline_sync3n.py @@ -7,15 +7,7 @@ from scipy.optimize import curve_fit from aspire.abinitio import CLOrient3D, SyncVotingMixin -from aspire.utils import ( - J_conjugate, - Rotation, - all_pairs, - nearest_rotations, - random, - tqdm, - trange, -) +from aspire.utils import J_conjugate, all_pairs, nearest_rotations, random, tqdm, trange from aspire.utils.matlab_compat import stable_eigsh logger = logging.getLogger(__name__) @@ -969,7 +961,6 @@ def _estimate_all_Rijs_host(self, clmatrix): return Rijs - ####################################### # Secondary Methods for Global J Sync # ####################################### From 56ecfc542be5af86e8d1756726b6b6fff9e819ea Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 11 Nov 2025 12:01:40 -0500 Subject: [PATCH 09/91] Nasty Nasty bug --- src/aspire/operators/polar_ft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index de57297700..ce3dd8ca9d 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -136,7 +136,7 @@ def _transform(self, x): resolution = x.shape[-1] # nufft call should return `pf` as array type (np or cp) of `x` - pf = nufft(x, self.freqs) / resolution**2 + pf = nufft(x, -self.freqs) / resolution**2 return pf.reshape(*stack_shape, self.ntheta // 2, self.nrad) From fd71a0ea7cbbe68cdfb78faeff5850ab3a45f914 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 11 Nov 2025 12:03:48 -0500 Subject: [PATCH 10/91] Checkpoint, shifts close to MATLAB [skip ci] --- src/aspire/abinitio/commonline_base.py | 123 ++++++++++++++++++++----- 1 file changed, 99 insertions(+), 24 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 041c78d736..d20055aa42 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -73,14 +73,15 @@ def __init__( if str(full_width).lower() == "adaptive": full_width = -1 self.full_width = int(full_width) - self.max_shift = math.ceil(max_shift * self.n_res) + self.max_shift = 15 # match MATLAB workflow for now math.ceil(max_shift * self.n_res) self.shift_step = shift_step self.offsets_max_shift = self.max_shift if offsets_max_shift is not None: - self.offsets_max_shift = math.ceil(offsets_max_shift * self.n_res) + self.offsets_max_shift = 15 # match MATLAB workflow math.ceil(offsets_max_shift * self.n_res) self.offsets_shift_step = offsets_shift_step or self.shift_step self.mask = mask self._pf = None + self._m_pf = None # Sanity limit to match potential clmatrix dtype of int16. 
if self.n_img > (2**15 - 1): @@ -135,6 +136,12 @@ def pf(self): self._prepare_pf() return self._pf + @property + def m_pf(self): + if self._m_pf is None: + self._prepare_pf() + return self._m_pf + def _prepare_pf(self): """ Prepare the polar Fourier transform used for correlations. @@ -156,6 +163,7 @@ def _prepare_pf(self): (self.n_res, self.n_res), self.n_rad, self.n_theta, dtype=self.dtype ) pf = pft.transform(imgs) + self._m_pf = pft.half_to_full(pf) # We remove the DC the component. pf has size (n_img) x (n_theta/2) x (n_rad-1), # with pf[:, :, 0] containing low frequency content and pf[:, :, -1] containing @@ -234,6 +242,13 @@ def build_clmatrix(self): Wrapper for cpu/gpu dispatch. """ + force_fn = "force_cl.npz" + if os.path.exists(force_fn): + logger.warning(f"FORCING Common Lines Matrix from {force_fn}") + res = np.load(force_fn) + self.clmatrix = res["clmatrix"] + return self.clmatrix + logger.info("Begin building Common Lines Matrix") # host/gpu dispatch @@ -245,6 +260,10 @@ def build_clmatrix(self): # Unpack result self._shifts_1d, self.clmatrix = res + # save result + logger.warning(f"Saving Common Lines to {force_fn}") + np.savez(force_fn, clmatrix=self.clmatrix) + return self.clmatrix def build_clmatrix_host(self): @@ -449,7 +468,7 @@ def build_clmatrix_cu(self): # Note diagnostic 1d shifts are not computed in the CUDA implementation. return None, clmatrix - def estimate_shifts(self, equations_factor=1, max_memory=4000): + def estimate_shifts(self, equations_factor=1, max_memory=10000): """ Estimate 2D shifts in images @@ -538,7 +557,19 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): # `estimate_shifts()` requires that rotations have already been estimated. rotations = Rotation(self.rotations) - pf = self.pf.copy() + _pf = self.pf.copy() + pf = self.m_pf + # # compare + # from scipy.io import loadmat + # matlab_pf = loadmat('matlab_pf_shift_dbg.mat')['pf'].T + # pf = matlab_pf + # breakpoint() + # %pf=[flipdim(pf(2:end,n_theta/2+1:end,:),1) ; pf(:,1:n_theta/2,:) ]; + # breakpoint() + pf = np.concatenate( + (np.flip(pf[:, n_theta_half:, 1:], axis=-1), pf[:, :n_theta_half, :]), + axis=-1, + ) # Estimate number of equations that will be used to calculate the shifts n_equations = self._estimate_num_shift_equations( @@ -560,10 +591,11 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): # The shift phases are pre-defined in a range of max_shift that can be # applied to maximize the common line calculation. The common-line filter # is also applied to the radial direction for easier detection. 
- r_max = pf.shape[2] - _, shift_phases, h = self._generate_shift_phase_and_filter( + r_max = (pf.shape[2] - 1) // 2 # pf is a different Matlab size + _, shift_phases, h = self._m_generate_shift_phase_and_filter( r_max, self.offsets_max_shift, self.offsets_shift_step ) + # breakpoint() d_theta = np.pi / n_theta_half @@ -578,6 +610,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): j = idx_j[shift_eq_idx] # get the common line indices based on the rotations from i and j images c_ij, c_ji = self._get_cl_indices(rotations, i, j, n_theta_half) + # breakpoint() # Extract the Fourier rays that correspond to the common line pf_i = pf[i, c_ij] @@ -593,16 +626,23 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): pf_j = pf[j, c_ji - n_theta_half] # perform bandpass filter, normalize each ray of each image, - pf_i = self._apply_filter_and_norm("i, i -> i", pf_i, r_max, h) - pf_j = self._apply_filter_and_norm("i, i -> i", pf_j, r_max, h) + pf_i = pf_i * h + pf_i[r_max - 1 : r_max + 1] = 0 + pf_i = pf_i / np.linalg.norm(pf_i) + pf_i = pf_i[:r_max] + + pf_j = pf_j * h + pf_j[r_max - 1 : r_max + 1] = 0 + pf_j = pf_j / np.linalg.norm(pf_j) + pf_j = pf_j[:r_max] # apply the shifts to images pf_i_flipped = np.conj(pf_i) - pf_i_stack = np.einsum("i, ji -> ij", pf_i, shift_phases) - pf_i_flipped_stack = np.einsum("i, ji -> ij", pf_i_flipped, shift_phases) + pf_i_stack = pf_i[:, None] * shift_phases + pf_i_flipped_stack = pf_i_flipped[:, None] * shift_phases - c1 = 2 * np.real(np.dot(np.conj(pf_i_stack.T), pf_j)) - c2 = 2 * np.real(np.dot(np.conj(pf_i_flipped_stack.T), pf_j)) + c1 = 2 * np.real(np.dot(pf_i_stack.T.conj(), pf_j)) + c2 = 2 * np.real(np.dot(pf_i_flipped_stack.T.conj(), pf_j)) # find the indices for the maximum values # and apply corresponding shifts @@ -623,17 +663,25 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): shift_b[shift_eq_idx] = dx # Compute the coefficients of the current equation - coefs = np.array( - [ - np.cos(shift_alpha), - np.sin(shift_alpha), - -np.cos(shift_beta), - -np.sin(shift_beta), - ] - ) - shift_eq[shift_eq_idx] = ( - [-1, -1, 0, 0] * coefs if is_pf_j_flipped else coefs - ) + if not is_pf_j_flipped: + shift_eq[shift_eq_idx] = np.array( + [ + np.sin(shift_alpha), + np.cos(shift_alpha), + -np.sin(shift_beta), + -np.cos(shift_beta), + ] + ) + else: + shift_beta = shift_beta - np.pi + shift_eq[shift_eq_idx] = np.array( + [ + -np.sin(shift_alpha), + -np.cos(shift_alpha), + -np.sin(shift_beta), + -np.cos(shift_beta), + ] + ) # create sparse matrix object only containing non-zero elements shift_equations = sparse.csr_matrix( @@ -665,6 +713,7 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 """ # Number of equations that will be used to estimation the shifts n_equations_total = int(np.ceil(n_img * (self.n_check - 1) / 2)) + # Estimated memory requirements for the full system of equation. # This ignores the sparsity of the system, since backslash seems to # ignore it. @@ -691,6 +740,31 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 return n_equations + def _m_generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): + """ + Port the code from MATLAB first, grumble grumble, inside thoughts bro. + """ + + # Number of shifts to try + n_shifts = int(np.ceil(2 * max_shift / shift_step + 1)) + + # only half of ray, excluding the DC component. 
+ rk = np.arange(-r_max, r_max + 1, dtype=self.dtype) + rk2 = rk[:r_max] + + shift_phases = np.zeros((r_max, n_shifts), dtype=np.complex128) + for shiftidx in range(n_shifts): + # zero based shiftidx + shift = -max_shift + shiftidx * shift_step + shift_phases[:, shiftidx] = np.exp( + -2 * np.pi * 1j * rk2 * shift / (2 * r_max + 1) + ) + + h = np.sqrt(np.abs(rk)) * np.exp(-(rk**2) / (2 * (r_max / 4) ** 2)) + + # breakpoint() # matchy matchy + return None, shift_phases, h + def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): """ Prepare the shift phases and generate filter for common-line detection @@ -733,7 +807,8 @@ def _generate_index_pairs(self, n_equations): idx_j = np.array(idx_j, dtype="int") # Select random pairs based on the size of n_equations - rp = choice(np.arange(len(idx_j)), size=n_equations, replace=False) + # rp = choice(np.arange(len(idx_j)), size=n_equations, replace=False) + rp = np.arange(n_equations, dtype=int) return idx_i[rp], idx_j[rp] From eee57a631e668b93b87be1e01661973eefd4459c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 12 Nov 2025 15:50:13 -0500 Subject: [PATCH 11/91] checkpoint before working backwards --- src/aspire/abinitio/commonline_base.py | 48 ++++++++++++++++---------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index d20055aa42..3b2789f90a 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -241,13 +241,21 @@ def build_clmatrix(self): Wrapper for cpu/gpu dispatch. """ - - force_fn = "force_cl.npz" - if os.path.exists(force_fn): - logger.warning(f"FORCING Common Lines Matrix from {force_fn}") - res = np.load(force_fn) - self.clmatrix = res["clmatrix"] - return self.clmatrix + # # load matlab clstack for comparison + # from scipy.io import loadmat + # m_cl_fn = 'shift_dbg_clstack.mat' + # clstack = loadmat(m_cl_fn)['clstack'].astype(int,order='C') - 1 # convert to 0 based + # logger.warning(f"FORCING MATLAB Common Lines Matrix from {m_cl_fn}") + # # breakpoint() + # self.clmatrix = clstack + # return self.clmatrix + + # force_fn = "force_cl.npz" + # if os.path.exists(force_fn): + # logger.warning(f"FORCING Common Lines Matrix from {force_fn}") + # res = np.load(force_fn) + # self.clmatrix = res["clmatrix"] + # return self.clmatrix logger.info("Begin building Common Lines Matrix") @@ -260,9 +268,9 @@ def build_clmatrix(self): # Unpack result self._shifts_1d, self.clmatrix = res - # save result - logger.warning(f"Saving Common Lines to {force_fn}") - np.savez(force_fn, clmatrix=self.clmatrix) + # # save result + # logger.warning(f"Saving Common Lines to {force_fn}") + # np.savez(force_fn, clmatrix=self.clmatrix) return self.clmatrix @@ -508,7 +516,7 @@ def estimate_shifts(self, equations_factor=1, max_memory=10000): show = True # Estimate shifts. 
- est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, show=show)[0] + est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, atol=1e-8, btol=1e-8, iter_lim=100, show=show)[0] self.shifts = est_shifts.reshape((self.n_img, 2)) return self.shifts @@ -559,13 +567,15 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): _pf = self.pf.copy() pf = self.m_pf - # # compare + # # # compare # from scipy.io import loadmat - # matlab_pf = loadmat('matlab_pf_shift_dbg.mat')['pf'].T + # pf_fn = 'matlab_pf_shift_dbg.mat' + # logger.warning(f"FORCING MATLAB PF from {pf_fn}") + # matlab_pf = loadmat(pf_fn)['pf'].T # pf = matlab_pf - # breakpoint() - # %pf=[flipdim(pf(2:end,n_theta/2+1:end,:),1) ; pf(:,1:n_theta/2,:) ]; - # breakpoint() + # # %pf=[flipdim(pf(2:end,n_theta/2+1:end,:),1) ; pf(:,1:n_theta/2,:) ]; + # # breakpoint() + pf = np.concatenate( (np.flip(pf[:, n_theta_half:, 1:], axis=-1), pf[:, :n_theta_half, :]), axis=-1, @@ -610,7 +620,6 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): j = idx_j[shift_eq_idx] # get the common line indices based on the rotations from i and j images c_ij, c_ji = self._get_cl_indices(rotations, i, j, n_theta_half) - # breakpoint() # Extract the Fourier rays that correspond to the common line pf_i = pf[i, c_ij] @@ -627,12 +636,12 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): # perform bandpass filter, normalize each ray of each image, pf_i = pf_i * h - pf_i[r_max - 1 : r_max + 1] = 0 + pf_i[r_max - 1 : r_max + 2] = 0 pf_i = pf_i / np.linalg.norm(pf_i) pf_i = pf_i[:r_max] pf_j = pf_j * h - pf_j[r_max - 1 : r_max + 1] = 0 + pf_j[r_max - 1 : r_max + 2] = 0 pf_j = pf_j / np.linalg.norm(pf_j) pf_j = pf_j[:r_max] @@ -720,6 +729,7 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 memory_total = equations_factor * ( n_equations_total * 2 * n_img * self.dtype.itemsize ) + #breakpoint() if memory_total < (max_memory * 10**6): n_equations = int(np.ceil(equations_factor * n_equations_total)) else: From e7bf5bdf82b298320d581cdfffc1dcad66d132da Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 12 Nov 2025 15:52:58 -0500 Subject: [PATCH 12/91] initial cleanup --- src/aspire/abinitio/commonline_base.py | 46 ++++++-------------------- 1 file changed, 11 insertions(+), 35 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 3b2789f90a..bbc8f49170 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -73,11 +73,15 @@ def __init__( if str(full_width).lower() == "adaptive": full_width = -1 self.full_width = int(full_width) - self.max_shift = 15 # match MATLAB workflow for now math.ceil(max_shift * self.n_res) + self.max_shift = ( + 15 # match MATLAB workflow for now math.ceil(max_shift * self.n_res) + ) self.shift_step = shift_step self.offsets_max_shift = self.max_shift if offsets_max_shift is not None: - self.offsets_max_shift = 15 # match MATLAB workflow math.ceil(offsets_max_shift * self.n_res) + self.offsets_max_shift = ( + 15 # match MATLAB workflow math.ceil(offsets_max_shift * self.n_res) + ) self.offsets_shift_step = offsets_shift_step or self.shift_step self.mask = mask self._pf = None @@ -241,21 +245,6 @@ def build_clmatrix(self): Wrapper for cpu/gpu dispatch. 
""" - # # load matlab clstack for comparison - # from scipy.io import loadmat - # m_cl_fn = 'shift_dbg_clstack.mat' - # clstack = loadmat(m_cl_fn)['clstack'].astype(int,order='C') - 1 # convert to 0 based - # logger.warning(f"FORCING MATLAB Common Lines Matrix from {m_cl_fn}") - # # breakpoint() - # self.clmatrix = clstack - # return self.clmatrix - - # force_fn = "force_cl.npz" - # if os.path.exists(force_fn): - # logger.warning(f"FORCING Common Lines Matrix from {force_fn}") - # res = np.load(force_fn) - # self.clmatrix = res["clmatrix"] - # return self.clmatrix logger.info("Begin building Common Lines Matrix") @@ -268,10 +257,6 @@ def build_clmatrix(self): # Unpack result self._shifts_1d, self.clmatrix = res - # # save result - # logger.warning(f"Saving Common Lines to {force_fn}") - # np.savez(force_fn, clmatrix=self.clmatrix) - return self.clmatrix def build_clmatrix_host(self): @@ -516,7 +501,9 @@ def estimate_shifts(self, equations_factor=1, max_memory=10000): show = True # Estimate shifts. - est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, atol=1e-8, btol=1e-8, iter_lim=100, show=show)[0] + est_shifts = sparse.linalg.lsqr( + shift_equations, shift_b, atol=1e-8, btol=1e-8, iter_lim=100, show=show + )[0] self.shifts = est_shifts.reshape((self.n_img, 2)) return self.shifts @@ -567,14 +554,6 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): _pf = self.pf.copy() pf = self.m_pf - # # # compare - # from scipy.io import loadmat - # pf_fn = 'matlab_pf_shift_dbg.mat' - # logger.warning(f"FORCING MATLAB PF from {pf_fn}") - # matlab_pf = loadmat(pf_fn)['pf'].T - # pf = matlab_pf - # # %pf=[flipdim(pf(2:end,n_theta/2+1:end,:),1) ; pf(:,1:n_theta/2,:) ]; - # # breakpoint() pf = np.concatenate( (np.flip(pf[:, n_theta_half:, 1:], axis=-1), pf[:, :n_theta_half, :]), @@ -605,7 +584,6 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): _, shift_phases, h = self._m_generate_shift_phase_and_filter( r_max, self.offsets_max_shift, self.offsets_shift_step ) - # breakpoint() d_theta = np.pi / n_theta_half @@ -729,7 +707,7 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 memory_total = equations_factor * ( n_equations_total * 2 * n_img * self.dtype.itemsize ) - #breakpoint() + if memory_total < (max_memory * 10**6): n_equations = int(np.ceil(equations_factor * n_equations_total)) else: @@ -772,7 +750,6 @@ def _m_generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): h = np.sqrt(np.abs(rk)) * np.exp(-(rk**2) / (2 * (r_max / 4) ** 2)) - # breakpoint() # matchy matchy return None, shift_phases, h def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): @@ -817,8 +794,7 @@ def _generate_index_pairs(self, n_equations): idx_j = np.array(idx_j, dtype="int") # Select random pairs based on the size of n_equations - # rp = choice(np.arange(len(idx_j)), size=n_equations, replace=False) - rp = np.arange(n_equations, dtype=int) + rp = choice(np.arange(len(idx_j)), size=n_equations, replace=False) return idx_i[rp], idx_j[rp] From 49e9a976ef6b0613187f5b247165e582e8eca499 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 12 Nov 2025 16:15:46 -0500 Subject: [PATCH 13/91] apply sign conventions --- src/aspire/abinitio/commonline_base.py | 4 +++- src/aspire/operators/polar_ft.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index bbc8f49170..eaaf9b0ec9 100644 --- 
a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -504,7 +504,9 @@ def estimate_shifts(self, equations_factor=1, max_memory=10000): est_shifts = sparse.linalg.lsqr( shift_equations, shift_b, atol=1e-8, btol=1e-8, iter_lim=100, show=show )[0] - self.shifts = est_shifts.reshape((self.n_img, 2)) + est_shifts = est_shifts.reshape((self.n_img, 2)) + # Convert (XY) axes and negate estimated shift orientations + self.shifts = -est_shifts[:, ::-1] return self.shifts diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index ce3dd8ca9d..ac6ff91d03 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -193,7 +193,7 @@ def shift(self, pfx, shifts): # Broadcast and accumulate phase shifts freqs = xp.tile(xp.asarray(self.freqs), (n, 1, 1)) - phase_shifts = xp.exp(-1j * xp.sum(freqs * -shifts[:, :, None], axis=1)) + phase_shifts = xp.exp(-1j * xp.sum(freqs * shifts[:, :, None], axis=1)) # Reshape flat frequency grid back to (..., ntheta//2, self.nrad) phase_shifts = phase_shifts.reshape(n, self.ntheta // 2, self.nrad) From d949ca953057ce5868d4f0f3c56b6404ef5fc076 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 12 Nov 2025 16:52:51 -0500 Subject: [PATCH 14/91] tox --- src/aspire/abinitio/commonline_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index eaaf9b0ec9..307f71e1ca 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -554,8 +554,8 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): # `estimate_shifts()` requires that rotations have already been estimated. rotations = Rotation(self.rotations) - _pf = self.pf.copy() - pf = self.m_pf + # _pf = self.pf.copy() + pf = self.m_pf.copy() pf = np.concatenate( (np.flip(pf[:, n_theta_half:, 1:], axis=-1), pf[:, :n_theta_half, :]), From 410708ffe45a23d9655a75887ef332a615fe80d5 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 13 Nov 2025 08:22:57 -0500 Subject: [PATCH 15/91] resolve easy items --- src/aspire/abinitio/commonline_base.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 307f71e1ca..ab9e69adc3 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -630,8 +630,8 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): pf_i_stack = pf_i[:, None] * shift_phases pf_i_flipped_stack = pf_i_flipped[:, None] * shift_phases - c1 = 2 * np.real(np.dot(pf_i_stack.T.conj(), pf_j)) - c2 = 2 * np.real(np.dot(pf_i_flipped_stack.T.conj(), pf_j)) + c1 = 2 * np.dot(pf_i_stack.T.conj(), pf_j).real + c2 = 2 * np.dot(pf_i_flipped_stack.T.conj(), pf_j).real # find the indices for the maximum values # and apply corresponding shifts @@ -786,14 +786,9 @@ def _generate_index_pairs(self, n_equations): """ Generate two index lists for [i, j] pairs of images """ - idx_i = [] - idx_j = [] - for i in range(self.n_img - 1): - tmp_j = range(i + 1, self.n_img) - idx_i.extend([i] * len(tmp_j)) - idx_j.extend(tmp_j) - idx_i = np.array(idx_i, dtype="int") - idx_j = np.array(idx_j, dtype="int") + + # Generate the i,j tuples of indices representing the upper triangle above the diagonal. 
+ idx_i, idx_j = np.triu_indices(N, k=1) # Select random pairs based on the size of n_equations rp = choice(np.arange(len(idx_j)), size=n_equations, replace=False) From cea4ccc2c6c59fb19c6a6bee46fcd761b68cbbf4 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 13 Nov 2025 09:22:49 -0500 Subject: [PATCH 16/91] move equations_factor and max_memory to CL init --- src/aspire/abinitio/commonline_base.py | 70 +++++++++++--------------- 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index ab9e69adc3..9974d8d786 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -30,6 +30,8 @@ def __init__( shift_step=1, offsets_max_shift=None, offsets_shift_step=None, + offsets_max_memory=10000, + offsets_equations_factor=1, mask=True, ): """ @@ -54,6 +56,19 @@ def __init__( :param offsets_shift_step: Resolution of shift estimation for 2D offset estimation in pixels. Default `None` inherits from `shift_step`. + :param offsets_equations_factor: The factor to rescale the + number of shift equations (=1 in default) + :param offsets_max_memory: If there are N images and N_check + selected to check for common lines, then the exact system + of equations solved for the shifts is of size 2N x + N(N_check-1)/2 (2N unknowns and N(N_check-1)/2 equations). + This may be too big if N is large. The algorithm will use + `equations_factor` times the total number of equations if + the resulting total number of memory requirements is less + than `offsets_max_memory` (in megabytes); otherwise it will reduce + the number of equations by approximation to fit in + `offsets_max_memory`. For more information see the + references in `estimate_shifts`. Defaults to 10GB. :param hist_bin_width: Bin width in smoothing histogram (degrees). :param full_width: Selection width around smoothed histogram peak (degrees). `adaptive` will attempt to automatically find the smallest number of @@ -83,6 +98,8 @@ def __init__( 15 # match MATLAB workflow math.ceil(offsets_max_shift * self.n_res) ) self.offsets_shift_step = offsets_shift_step or self.shift_step + self.offsets_equations_factor = offsets_equations_factor + self.offsets_max_memory = int(offsets_max_memory) self.mask = mask self._pf = None self._m_pf = None @@ -461,7 +478,7 @@ def build_clmatrix_cu(self): # Note diagnostic 1d shifts are not computed in the CUDA implementation. return None, clmatrix - def estimate_shifts(self, equations_factor=1, max_memory=10000): + def estimate_shifts(self): """ Estimate 2D shifts in images @@ -478,22 +495,10 @@ def estimate_shifts(self, equations_factor=1, max_memory=10000): T. Vogt, W. Dahmen, and P. Binev (Eds.) Nanostructure Science and Technology Series, Springer, 2012, pp. 147–177 - - :param equations_factor: The factor to rescale the number of shift equations - (=1 in default) - :param max_memory: If there are N images and N_check selected to check - for common lines, then the exact system of equations solved for the shifts - is of size 2N x N(N_check-1)/2 (2N unknowns and N(N_check-1)/2 equations). - This may be too big if N is large. The algorithm will use `equations_factor` - times the total number of equations if the resulting total number of memory - requirements is less than `max_memory` (in megabytes); otherwise it will - reduce the number of equations by approximation to fit in `max_memory`. 
""" # Generate approximated shift equations from estimated rotations - shift_equations, shift_b = self._get_shift_equations_approx( - equations_factor, max_memory - ) + shift_equations, shift_b = self._get_shift_equations_approx() # Solve the linear equation, optionally printing numerical debug details. show = False @@ -522,7 +527,7 @@ def estimate(self, **kwargs): return self.rotations, self.shifts - def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): + def _get_shift_equations_approx(self): """ Generate approximated shift equations from estimated rotations @@ -535,16 +540,6 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): This function processes the (Fourier transformed) images exactly as the `build_clmatrix` function. - :param equations_factor: The factor to rescale the number of shift equations - (=1 in default) - :param max_memory: If there are N images and N_check selected to check - for common lines, then the exact system of equations solved for the shifts - is of size 2N x N(N_check-1)/2 (2N unknowns and N(N_check-1)/2 equations). - This may be too big if N is large. The algorithm will use `equations_factor` - times the total number of equations if the resulting total number of - memory requirements is less than `max_memory` (in megabytes); otherwise it - will reduce the number of equations to fit in `max_memory`. - :return; The left and right-hand side of shift equations """ @@ -563,9 +558,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): ) # Estimate number of equations that will be used to calculate the shifts - n_equations = self._estimate_num_shift_equations( - n_img, equations_factor, max_memory - ) + n_equations = self._estimate_num_shift_equations(n_img) # Allocate local variables for estimating 2D shifts based on the estimated number # of equations. The shift equations are represented using a sparse matrix, @@ -681,7 +674,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): return shift_equations, shift_b - def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=4000): + def _estimate_num_shift_equations(self, n_img): """ Estimate total number of shift equations in images @@ -689,15 +682,6 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 number of images and preselected memory factor. :param n_img: The total number of input images - :param equations_factor: The factor to rescale the number of shift equations - (=1 in default) - :param max_memory: If there are N images and N_check selected to check - for common lines, then the exact system of equations solved for the shifts - is of size 2N x N(N_check-1)/2 (2N unknowns and N(N_check-1)/2 equations). - This may be too big if N is large. The algorithm will use `equations_factor` - times the total number of equations if the resulting total number of - memory requirements is less than `max_memory` (in megabytes); otherwise it - will reduce the number of equations to fit in `max_memory`. :return: Estimated number of shift equations """ # Number of equations that will be used to estimation the shifts @@ -706,14 +690,16 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 # Estimated memory requirements for the full system of equation. # This ignores the sparsity of the system, since backslash seems to # ignore it. 
- memory_total = equations_factor * ( + memory_total = self.offsets_equations_factor * ( n_equations_total * 2 * n_img * self.dtype.itemsize ) - if memory_total < (max_memory * 10**6): - n_equations = int(np.ceil(equations_factor * n_equations_total)) + if memory_total < (self.offets_max_memory * 10**6): + n_equations = int( + np.ceil(self.offsets_equations_factor * n_equations_total) + ) else: - subsampling_factor = (max_memory * 10**6) / memory_total + subsampling_factor = (self.offsets_max_memory * 10**6) / memory_total subsampling_factor = min(1.0, subsampling_factor) n_equations = int(np.ceil(n_equations_total * subsampling_factor)) From 6c3ee50ebbd2be46415e4a49593c1652b34ada0c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 13 Nov 2025 11:02:08 -0500 Subject: [PATCH 17/91] move equations_factor and max_memory to CL init --- src/aspire/abinitio/commonline_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 9974d8d786..36f6f2d754 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -694,7 +694,7 @@ def _estimate_num_shift_equations(self, n_img): n_equations_total * 2 * n_img * self.dtype.itemsize ) - if memory_total < (self.offets_max_memory * 10**6): + if memory_total < (self.offsets_max_memory * 10**6): n_equations = int( np.ceil(self.offsets_equations_factor * n_equations_total) ) @@ -774,7 +774,7 @@ def _generate_index_pairs(self, n_equations): """ # Generate the i,j tuples of indices representing the upper triangle above the diagonal. - idx_i, idx_j = np.triu_indices(N, k=1) + idx_i, idx_j = np.triu_indices(self.n_img, k=1) # Select random pairs based on the size of n_equations rp = choice(np.arange(len(idx_j)), size=n_equations, replace=False) From 4244253b9e3de5c89dea515ad4835afc974423fb Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 13 Nov 2025 14:52:06 -0500 Subject: [PATCH 18/91] melding full MATLAB port backwards with py code --- src/aspire/abinitio/commonline_base.py | 35 ++++++-------------------- 1 file changed, 7 insertions(+), 28 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 36f6f2d754..522f1806bd 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -102,7 +102,6 @@ def __init__( self.offsets_max_memory = int(offsets_max_memory) self.mask = mask self._pf = None - self._m_pf = None # Sanity limit to match potential clmatrix dtype of int16. if self.n_img > (2**15 - 1): @@ -157,12 +156,6 @@ def pf(self): self._prepare_pf() return self._pf - @property - def m_pf(self): - if self._m_pf is None: - self._prepare_pf() - return self._m_pf - def _prepare_pf(self): """ Prepare the polar Fourier transform used for correlations. @@ -184,7 +177,6 @@ def _prepare_pf(self): (self.n_res, self.n_res), self.n_rad, self.n_theta, dtype=self.dtype ) pf = pft.transform(imgs) - self._m_pf = pft.half_to_full(pf) # We remove the DC the component. pf has size (n_img) x (n_theta/2) x (n_rad-1), # with pf[:, :, 0] containing low frequency content and pf[:, :, -1] containing @@ -549,13 +541,7 @@ def _get_shift_equations_approx(self): # `estimate_shifts()` requires that rotations have already been estimated. 
rotations = Rotation(self.rotations) - # _pf = self.pf.copy() - pf = self.m_pf.copy() - - pf = np.concatenate( - (np.flip(pf[:, n_theta_half:, 1:], axis=-1), pf[:, :n_theta_half, :]), - axis=-1, - ) + pf = self.pf.copy() # Estimate number of equations that will be used to calculate the shifts n_equations = self._estimate_num_shift_equations(n_img) @@ -575,8 +561,8 @@ def _get_shift_equations_approx(self): # The shift phases are pre-defined in a range of max_shift that can be # applied to maximize the common line calculation. The common-line filter # is also applied to the radial direction for easier detection. - r_max = (pf.shape[2] - 1) // 2 # pf is a different Matlab size - _, shift_phases, h = self._m_generate_shift_phase_and_filter( + r_max = pf.shape[2] + _, shift_phases, h = self._generate_shift_phase_and_filter( r_max, self.offsets_max_shift, self.offsets_shift_step ) @@ -608,20 +594,13 @@ def _get_shift_equations_approx(self): pf_j = pf[j, c_ji - n_theta_half] # perform bandpass filter, normalize each ray of each image, - pf_i = pf_i * h - pf_i[r_max - 1 : r_max + 2] = 0 - pf_i = pf_i / np.linalg.norm(pf_i) - pf_i = pf_i[:r_max] - - pf_j = pf_j * h - pf_j[r_max - 1 : r_max + 2] = 0 - pf_j = pf_j / np.linalg.norm(pf_j) - pf_j = pf_j[:r_max] + pf_i = self._apply_filter_and_norm("i, i -> i", pf_i, r_max, h) + pf_j = self._apply_filter_and_norm("i, i -> i", pf_j, r_max, h) # apply the shifts to images pf_i_flipped = np.conj(pf_i) - pf_i_stack = pf_i[:, None] * shift_phases - pf_i_flipped_stack = pf_i_flipped[:, None] * shift_phases + pf_i_stack = pf_i[:, None] * shift_phases.T + pf_i_flipped_stack = pf_i_flipped[:, None] * shift_phases.T c1 = 2 * np.dot(pf_i_stack.T.conj(), pf_j).real c2 = 2 * np.dot(pf_i_flipped_stack.T.conj(), pf_j).real From ace5c426aa348c828e3bb505d3588ae22550fcb0 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 13 Nov 2025 15:38:28 -0500 Subject: [PATCH 19/91] rm unused function from debugging --- src/aspire/abinitio/commonline_base.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 522f1806bd..1646486474 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -695,30 +695,6 @@ def _estimate_num_shift_equations(self, n_img): return n_equations - def _m_generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): - """ - Port the code from MATLAB first, grumble grumble, inside thoughts bro. - """ - - # Number of shifts to try - n_shifts = int(np.ceil(2 * max_shift / shift_step + 1)) - - # only half of ray, excluding the DC component. 
- rk = np.arange(-r_max, r_max + 1, dtype=self.dtype) - rk2 = rk[:r_max] - - shift_phases = np.zeros((r_max, n_shifts), dtype=np.complex128) - for shiftidx in range(n_shifts): - # zero based shiftidx - shift = -max_shift + shiftidx * shift_step - shift_phases[:, shiftidx] = np.exp( - -2 * np.pi * 1j * rk2 * shift / (2 * r_max + 1) - ) - - h = np.sqrt(np.abs(rk)) * np.exp(-(rk**2) / (2 * (r_max / 4) ** 2)) - - return None, shift_phases, h - def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): """ Prepare the shift phases and generate filter for common-line detection From fdec58f4faad615e4347adeda1f84400bc2b5a8d Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 14 Nov 2025 08:36:45 -0500 Subject: [PATCH 20/91] revert forcing the matlab offset search space --- src/aspire/abinitio/commonline_base.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 1646486474..3c457ed937 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -88,15 +88,11 @@ def __init__( if str(full_width).lower() == "adaptive": full_width = -1 self.full_width = int(full_width) - self.max_shift = ( - 15 # match MATLAB workflow for now math.ceil(max_shift * self.n_res) - ) + self.max_shift = math.ceil(max_shift * self.n_res) self.shift_step = shift_step self.offsets_max_shift = self.max_shift if offsets_max_shift is not None: - self.offsets_max_shift = ( - 15 # match MATLAB workflow math.ceil(offsets_max_shift * self.n_res) - ) + self.offsets_max_shift = math.ceil(offsets_max_shift * self.n_res) self.offsets_shift_step = offsets_shift_step or self.shift_step self.offsets_equations_factor = offsets_equations_factor self.offsets_max_memory = int(offsets_max_memory) From 242ac701b685346f2f086e22695771a26457a35a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 14 Nov 2025 14:32:59 -0500 Subject: [PATCH 21/91] pass through memory lim kwargs --- src/aspire/abinitio/commonline_sync.py | 2 ++ src/aspire/abinitio/commonline_sync3n.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/aspire/abinitio/commonline_sync.py b/src/aspire/abinitio/commonline_sync.py index 222173f318..6fcc8daddc 100644 --- a/src/aspire/abinitio/commonline_sync.py +++ b/src/aspire/abinitio/commonline_sync.py @@ -33,6 +33,7 @@ def __init__( hist_bin_width=3, full_width=6, mask=True, + **kwargs, ): """ Initialize an object for estimating 3D orientations using synchronization matrix @@ -60,6 +61,7 @@ def __init__( hist_bin_width=hist_bin_width, full_width=full_width, mask=mask, + **kwargs, ) self.syncmatrix = None diff --git a/src/aspire/abinitio/commonline_sync3n.py b/src/aspire/abinitio/commonline_sync3n.py index 25bfd63015..ed7ca94048 100644 --- a/src/aspire/abinitio/commonline_sync3n.py +++ b/src/aspire/abinitio/commonline_sync3n.py @@ -61,6 +61,7 @@ def __init__( J_weighting=False, hist_intervals=100, disable_gpu=False, + **kwargs, ): """ Initialize object for estimating 3D orientations. 
@@ -100,6 +101,7 @@ def __init__( hist_bin_width=hist_bin_width, full_width=full_width, mask=mask, + **kwargs, ) # Generate pair mappings From efcd82fcc144babcd0264612153e99f1d5504b75 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 17 Nov 2025 08:24:13 -0500 Subject: [PATCH 22/91] update tests --- tests/test_orient_sync_voting.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 4bea8df6c2..f9eaf77d3e 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -9,6 +9,7 @@ from aspire.abinitio import CLOrient3D, CLSyncVoting from aspire.commands.orient3d import orient3d +from aspire.downloader import emdb_2660 from aspire.noise import WhiteNoiseAdder from aspire.source import Simulation from aspire.utils import mean_aligned_angular_distance, rots_to_clmatrix @@ -51,23 +52,15 @@ def dtype(request): @pytest.fixture(scope="module") def source_orientation_objs(resolution, offsets, dtype): src = Simulation( - n=50, + n=500, L=resolution, - vols=AsymmetricVolume(L=resolution, C=1, K=100, seed=0, dtype=dtype).generate(), + vols=emdb_2660().downsample(resolution), offsets=offsets, amplitudes=1, seed=0, - ) + ).cache() - # Search for common lines over less shifts for 0 offsets. - max_shift = 1 / resolution - shift_step = 1 - if src.offsets.all() != 0: - max_shift = 0.20 - shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. - orient_est = CLSyncVoting( - src, max_shift=max_shift, shift_step=shift_step, mask=False - ) + orient_est = CLSyncVoting(src) # Estimate rotations once for all tests. orient_est.estimate_rotations() @@ -119,11 +112,12 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): # Calculate the mean 2D distance between estimates and ground truth. error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() - # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets + # Assert that on average estimated shifts are close to src.offsets if src.offsets.all() != 0: - np.testing.assert_array_less(mean_dist, 0.5) + np.testing.assert_array_less(mean_dist, 2) else: np.testing.assert_allclose(mean_dist, 0) @@ -138,9 +132,9 @@ def test_estimate_shifts_with_est_rots(source_orientation_objs): error = src.offsets - est_shifts mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() - # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets + # Assert that on average estimated shifts are close to src.offsets if src.offsets.all() != 0: - np.testing.assert_array_less(mean_dist, 0.5) + np.testing.assert_array_less(mean_dist, 2) else: np.testing.assert_allclose(mean_dist, 0) From 1fb53060305e4fb3cbe3496efaa240905e8909bb Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 17 Nov 2025 08:57:43 -0500 Subject: [PATCH 23/91] pass through kwargs --- src/aspire/abinitio/commonline_c2.py | 2 ++ src/aspire/abinitio/commonline_c3_c4.py | 2 ++ src/aspire/abinitio/commonline_cn.py | 2 ++ src/aspire/abinitio/commonline_d2.py | 2 ++ 4 files changed, 8 insertions(+) diff --git a/src/aspire/abinitio/commonline_c2.py b/src/aspire/abinitio/commonline_c2.py index 52b63cc9e8..7bd3495b84 100644 --- a/src/aspire/abinitio/commonline_c2.py +++ b/src/aspire/abinitio/commonline_c2.py @@ -48,6 +48,7 @@ def __init__( min_dist_cls=25, seed=None, mask=True, + **kwargs, ): """ Initialize object for estimating 3D orientations for molecules with C2 symmetry. 
@@ -77,6 +78,7 @@ def __init__( degree_res=degree_res, seed=seed, mask=mask, + **kwargs, ) self.min_dist_cls = min_dist_cls diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index af74dc6f6c..c581005c38 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -56,6 +56,7 @@ def __init__( degree_res=1, seed=None, mask=True, + **kwargs, ): """ Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry. @@ -81,6 +82,7 @@ def __init__( max_shift=max_shift, shift_step=shift_step, mask=mask, + **kwargs, ) self._check_symmetry(symmetry) diff --git a/src/aspire/abinitio/commonline_cn.py b/src/aspire/abinitio/commonline_cn.py index e0d69cb086..297bf04a6d 100644 --- a/src/aspire/abinitio/commonline_cn.py +++ b/src/aspire/abinitio/commonline_cn.py @@ -41,6 +41,7 @@ def __init__( equator_threshold=10, seed=None, mask=True, + **kwargs, ): """ Initialize object for estimating 3D orientations for molecules with Cn symmetry, n>4. @@ -74,6 +75,7 @@ def __init__( degree_res=degree_res, seed=seed, mask=mask, + **kwargs, ) self.n_points_sphere = n_points_sphere diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 8acbade2ca..f8022d3db9 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -37,6 +37,7 @@ def __init__( epsilon=0.01, seed=None, mask=True, + **kwargs, ): """ Initialize object for estimating 3D orientations for molecules with D2 symmetry. @@ -65,6 +66,7 @@ def __init__( max_shift=max_shift, shift_step=shift_step, mask=mask, + **kwargs, ) self.grid_res = grid_res From 80438d216135473613c63e9b34a4a805df0f31a3 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 17 Nov 2025 09:09:17 -0500 Subject: [PATCH 24/91] mark slow tests expensive --- tests/test_orient_sync_voting.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index f9eaf77d3e..b7c6c2d97b 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -68,6 +68,7 @@ def source_orientation_objs(resolution, offsets, dtype): return src, orient_est +@pytest.mark.expensive def test_build_clmatrix(source_orientation_objs): src, orient_est = source_orientation_objs @@ -90,6 +91,7 @@ def test_build_clmatrix(source_orientation_objs): assert within_5 / angle_diffs.size > tol +@pytest.mark.expensive def test_estimate_rotations(source_orientation_objs): src, orient_est = source_orientation_objs @@ -99,6 +101,7 @@ def test_estimate_rotations(source_orientation_objs): mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1) +@pytest.mark.expensive def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs @@ -122,6 +125,7 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): np.testing.assert_allclose(mean_dist, 0) +@pytest.mark.expensive def test_estimate_shifts_with_est_rots(source_orientation_objs): src, orient_est = source_orientation_objs @@ -139,6 +143,7 @@ def test_estimate_shifts_with_est_rots(source_orientation_objs): np.testing.assert_allclose(mean_dist, 0) +@pytest.mark.expensive def test_estimate_rotations_fuzzy_mask(): noisy_src = Simulation( n=35, From ff926434c625e26db2abbba02100897b0d2f138d Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 17 Nov 2025 09:41:34 -0500 Subject: [PATCH 25/91] add shift pass through test --- 
tests/test_orient_sync_voting.py | 75 ++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index b7c6c2d97b..0cdcaa0c22 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -7,11 +7,22 @@ import pytest from click.testing import CliRunner -from aspire.abinitio import CLOrient3D, CLSyncVoting +from aspire.abinitio import ( + CLOrient3D, + CLSymmetryC2, + CLSymmetryC3C4, + CLSymmetryCn, + CLSymmetryD2, + CLSync3N, + CLSyncVoting, + CommonlineIRLS, + CommonlineLUD, + CommonlineSDP, +) from aspire.commands.orient3d import orient3d from aspire.downloader import emdb_2660 from aspire.noise import WhiteNoiseAdder -from aspire.source import Simulation +from aspire.source import ArrayImageSource, Simulation from aspire.utils import mean_aligned_angular_distance, rots_to_clmatrix from aspire.volume import AsymmetricVolume @@ -33,6 +44,18 @@ pytest.param(np.float64, marks=pytest.mark.expensive), ] +CL_ALGOS = [ + CLSymmetryC2, + CLSymmetryC3C4, + CLSymmetryCn, + CLSymmetryD2, + CLSync3N, + CLSyncVoting, + CommonlineIRLS, + CommonlineLUD, + CommonlineSDP, +] + @pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module") def resolution(request): @@ -60,7 +83,15 @@ def source_orientation_objs(resolution, offsets, dtype): seed=0, ).cache() - orient_est = CLSyncVoting(src) + # Search for common lines over less shifts for 0 offsets. + max_shift = 1 / resolution + shift_step = 1 + if src.offsets.all() != 0: + max_shift = 0.20 + shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. + orient_est = CLSyncVoting( + src, max_shift=max_shift, shift_step=shift_step, mask=False + ) # Estimate rotations once for all tests. orient_est.estimate_rotations() @@ -225,3 +256,41 @@ def test_command_line(): ) # check that the command completed successfully assert result.exit_code == 0 + + +@pytest.mark.parametrize("cl_algo", CL_ALGOS) +def test_offset_param_passthrough(cl_algo): + """ + Systematically test that offset search configuration passes through all CL classes. 
+ """ + + src = ArrayImageSource(np.random.randn(4, 4), pixel_size=1.23) + + test_args = { + "offsets_max_shift": 0.5, + "offsets_shift_step": 0.1, + "offsets_equations_factor": 1, + "offsets_max_memory": 200, + } + + # Handle special case classes + if cl_algo == CLSymmetryC3C4: + test_args["symmetry"] = "C3" + elif cl_algo == CLSymmetryCn: + test_args["symmetry"] = "C17" + + # Instantiate the CL class under test + orient_est = cl_algo(src, **test_args) + + # Loop over the args and assert they are correctly assigned + for arg, val in test_args.items(): + + # Handle special case arguments + if arg == "offsets_max_shift": + # convert from ratio to pixels + val = np.ceil(val * src.L) + elif arg == "symmetry": + # convert from string `symmetry` to int `order` + arg, val = "order", int(val[1:]) + + assert getattr(orient_est, arg) == val From c63b6ca25662c38471a07f870594fd97ff7718ef Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 19 Nov 2025 09:33:49 -0500 Subject: [PATCH 26/91] xfail another flaky shift test for 1340 --- tests/test_commonline_sync3n.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_commonline_sync3n.py b/tests/test_commonline_sync3n.py index 8305aa85e9..3381a51869 100644 --- a/tests/test_commonline_sync3n.py +++ b/tests/test_commonline_sync3n.py @@ -93,6 +93,7 @@ def test_build_clmatrix(source_orientation_objs): assert within_5 / angle_diffs.size > tol +@pytest.mark.xfail(reason="Issue #1340") def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs @@ -115,6 +116,7 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): np.testing.assert_allclose(mean_dist, 0) +@pytest.mark.xfail(reason="Issue #1340") def test_estimate_shifts_with_est_rots(source_orientation_objs): src, orient_est = source_orientation_objs # Estimate shifts using estimated rotations. From e6dcc23c2cd5ab4c43623bf6229522396e9d1194 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 20 Nov 2025 11:55:30 -0500 Subject: [PATCH 27/91] Revert MATLAB PFT freq grid convention --- src/aspire/abinitio/commonline_base.py | 5 +++++ src/aspire/operators/polar_ft.py | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 3c457ed937..b179a1d7df 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -589,6 +589,11 @@ def _get_shift_equations_approx(self): else: pf_j = pf[j, c_ji - n_theta_half] + # Use ray from opposite side of origin. + # Correpsonds to `freqs` convention in PFT, + # where the legacy code used a negated frequency grid. + pf_i, pf_j = np.conj(pf_i), np.conj(pf_j) + # perform bandpass filter, normalize each ray of each image, pf_i = self._apply_filter_and_norm("i, i -> i", pf_i, r_max, h) pf_j = self._apply_filter_and_norm("i, i -> i", pf_j, r_max, h) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index ac6ff91d03..bd50da727f 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -135,8 +135,9 @@ def _transform(self, x): resolution = x.shape[-1] + # Note, `freqs` is negated from legacy MATLAB. 
# nufft call should return `pf` as array type (np or cp) of `x` - pf = nufft(x, -self.freqs) / resolution**2 + pf = nufft(x, self.freqs) / resolution**2 return pf.reshape(*stack_shape, self.ntheta // 2, self.nrad) @@ -193,7 +194,7 @@ def shift(self, pfx, shifts): # Broadcast and accumulate phase shifts freqs = xp.tile(xp.asarray(self.freqs), (n, 1, 1)) - phase_shifts = xp.exp(-1j * xp.sum(freqs * shifts[:, :, None], axis=1)) + phase_shifts = xp.exp(-1j * xp.sum(-freqs * shifts[:, :, None], axis=1)) # Reshape flat frequency grid back to (..., ntheta//2, self.nrad) phase_shifts = phase_shifts.reshape(n, self.ntheta // 2, self.nrad) From af19b1d1232d07adf9224877a5e9a8b17b318d18 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 25 Nov 2025 10:28:24 -0500 Subject: [PATCH 28/91] bump numpy>=1.25.0 --- environment-accelerate.yml | 2 +- environment-default.yml | 2 +- environment-intel.yml | 2 +- environment-openblas.yml | 2 +- environment-win64.yml | 2 +- pyproject.toml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/environment-accelerate.yml b/environment-accelerate.yml index 5248c3a953..0ff8097e49 100644 --- a/environment-accelerate.yml +++ b/environment-accelerate.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.24.1 + - numpy>=1.25.0 - scipy=1.10.1 - scikit-learn - scikit-image diff --git a/environment-default.yml b/environment-default.yml index f827ec55bc..2bcd3bc0ea 100644 --- a/environment-default.yml +++ b/environment-default.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.23.5 + - numpy>=1.25.0 - scipy=1.9.3 - scikit-learn - scikit-image diff --git a/environment-intel.yml b/environment-intel.yml index 840dd92f76..98c5feee7e 100644 --- a/environment-intel.yml +++ b/environment-intel.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.23.5 + - numpy>=1.25.0 - scipy=1.9.3 - scikit-learn - scikit-image diff --git a/environment-openblas.yml b/environment-openblas.yml index 088035f88d..41f761c696 100644 --- a/environment-openblas.yml +++ b/environment-openblas.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.23.5 + - numpy>=1.25.0 - scipy=1.9.3 - scikit-learn - scikit-image diff --git a/environment-win64.yml b/environment-win64.yml index 34ca5d9fa6..62e55ceda9 100644 --- a/environment-win64.yml +++ b/environment-win64.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy=1.23.5 + - numpy>=1.25.0 - scipy=1.9.3 - scikit-learn - scikit-image diff --git a/pyproject.toml b/pyproject.toml index 8a7696dcb9..2e132d8e1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "joblib", "matplotlib >= 3.2.0", "mrcfile", - "numpy>=1.21.5", + "numpy>=1.25.0", "packaging", "pooch>=1.7.0", "pillow", From f35b65b0cb8e82b143999d78ab0c980cbdd55c56 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 1 Dec 2025 09:11:27 -0500 Subject: [PATCH 29/91] pin env*.yml np version --- environment-accelerate.yml | 2 +- environment-default.yml | 2 +- environment-intel.yml | 2 +- environment-openblas.yml | 2 +- environment-win64.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/environment-accelerate.yml b/environment-accelerate.yml index 0ff8097e49..18a1c73204 100644 --- a/environment-accelerate.yml +++ b/environment-accelerate.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy>=1.25.0 + - numpy=1.25.0 - scipy=1.10.1 - scikit-learn - scikit-image diff --git a/environment-default.yml 
b/environment-default.yml index 2bcd3bc0ea..274292c723 100644 --- a/environment-default.yml +++ b/environment-default.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy>=1.25.0 + - numpy=1.25.0 - scipy=1.9.3 - scikit-learn - scikit-image diff --git a/environment-intel.yml b/environment-intel.yml index 98c5feee7e..c302df2846 100644 --- a/environment-intel.yml +++ b/environment-intel.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy>=1.25.0 + - numpy=1.25.0 - scipy=1.9.3 - scikit-learn - scikit-image diff --git a/environment-openblas.yml b/environment-openblas.yml index 41f761c696..4a216b9d9c 100644 --- a/environment-openblas.yml +++ b/environment-openblas.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy>=1.25.0 + - numpy=1.25.0 - scipy=1.9.3 - scikit-learn - scikit-image diff --git a/environment-win64.yml b/environment-win64.yml index 62e55ceda9..aa5d969e8c 100644 --- a/environment-win64.yml +++ b/environment-win64.yml @@ -7,7 +7,7 @@ channels: dependencies: - pip - python=3.9 - - numpy>=1.25.0 + - numpy=1.25.0 - scipy=1.9.3 - scikit-learn - scikit-image From c92fa1754e6ee7c39b1c602d752a1382edf21bed Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 1 Dec 2025 13:46:04 -0500 Subject: [PATCH 30/91] bump scipy accel envs --- environment-accelerate.yml | 2 +- environment-default.yml | 2 +- environment-intel.yml | 2 +- environment-openblas.yml | 2 +- environment-win64.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/environment-accelerate.yml b/environment-accelerate.yml index 18a1c73204..5738909d7d 100644 --- a/environment-accelerate.yml +++ b/environment-accelerate.yml @@ -8,7 +8,7 @@ dependencies: - pip - python=3.9 - numpy=1.25.0 - - scipy=1.10.1 + - scipy=1.13.1 - scikit-learn - scikit-image - libblas=*=*accelerate diff --git a/environment-default.yml b/environment-default.yml index 274292c723..2b1d542fd6 100644 --- a/environment-default.yml +++ b/environment-default.yml @@ -8,6 +8,6 @@ dependencies: - pip - python=3.9 - numpy=1.25.0 - - scipy=1.9.3 + - scipy=1.13.1 - scikit-learn - scikit-image diff --git a/environment-intel.yml b/environment-intel.yml index c302df2846..3295ff934a 100644 --- a/environment-intel.yml +++ b/environment-intel.yml @@ -8,7 +8,7 @@ dependencies: - pip - python=3.9 - numpy=1.25.0 - - scipy=1.9.3 + - scipy=1.13.1 - scikit-learn - scikit-image - mkl_fft diff --git a/environment-openblas.yml b/environment-openblas.yml index 4a216b9d9c..04f80266be 100644 --- a/environment-openblas.yml +++ b/environment-openblas.yml @@ -8,7 +8,7 @@ dependencies: - pip - python=3.9 - numpy=1.25.0 - - scipy=1.9.3 + - scipy=1.13.1 - scikit-learn - scikit-image - libblas=*=*openblas diff --git a/environment-win64.yml b/environment-win64.yml index aa5d969e8c..ea9ed840f6 100644 --- a/environment-win64.yml +++ b/environment-win64.yml @@ -8,7 +8,7 @@ dependencies: - pip - python=3.9 - numpy=1.25.0 - - scipy=1.9.3 + - scipy=1.13.1 - scikit-learn - scikit-image - mkl=2024.1.* # possible regression impacts eig solver in later versions up to 2025.0 From e77a68ba63df99ac3283207214b5f18653748637 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 Dec 2025 07:51:09 -0500 Subject: [PATCH 31/91] update class alignment attrs fixed bug introduced via batching --- src/aspire/classification/averager2d.py | 54 +++++++++++++++++-------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 
d982e746e5..301bc9bc37 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -84,7 +84,7 @@ def average( :param classes: class indices, refering to src. (src.n, n_nbor). :param reflections: Bool representing whether to reflect image in `classes`. - (n_clases, n_nbor) + (n_classes, n_nbor) :param coefs: Optional basis coefs (could avoid recomputing). (src.n, coef_count) :return: Stack of synthetic class average images as Image instance. @@ -152,6 +152,16 @@ def __init__( f"{self.__class__.__name__}'s composite_basis {self.composite_basis} must provide a `shift` method." ) + # Instantiate dicts to hold alignment results. + # Note dicts are used in place of arrays because: + # The entire set of `src.n` classes may not always need to be computed, + # and the order/batching where results are computed is potentially arbitrary. + # We may not know apriori how many nbors are in each class, + # and this may be variable with future methods. + self.rotations = dict() + self.shifts = dict() + self.dot_products = dict() + @abstractmethod def align(self, classes, reflections, basis_coefficients=None): """ @@ -192,9 +202,15 @@ def average( classes = np.atleast_2d(classes) reflections = np.atleast_2d(reflections) - self.rotations, self.shifts, self.dot_products = self.align( - classes, reflections, coefs - ) + rotations, shifts, dot_products = self.align(classes, reflections, coefs) + + # Assign batch results + src_indices = classes[:, 0] # First column of class table + for i, k in enumerate(src_indices): + self.rotations[k] = rotations[i] + if shifts is not None: + self.shifts[k] = shifts[i] + self.dot_products[k] = dot_products[i] n_classes, n_nbor = classes.shape @@ -212,22 +228,22 @@ def _innerloop(i): neighbors_imgs = Image(self._cls_images(classes[i])) # Do shifts - if self.shifts is not None: - neighbors_imgs = neighbors_imgs.shift(self.shifts[i]) + if shifts is not None: + neighbors_imgs = neighbors_imgs.shift(shifts[i]) neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) else: # Get the neighbors neighbors_ids = classes[i] neighbors_coefs = coefs[neighbors_ids] - if self.shifts is not None: + if shifts is not None: neighbors_coefs = self.composite_basis.shift( - neighbors_coefs, self.shifts[i] + neighbors_coefs, shifts[i] ) # Rotate in composite_basis neighbors_coefs = self.composite_basis.rotate( - neighbors_coefs, self.rotations[i], reflections[i] + neighbors_coefs, rotations[i], reflections[i] ) # Averaging in composite_basis @@ -580,9 +596,15 @@ def average( Otherwise is similar to `AligningAverager2D.average`. """ - self.rotations, self.shifts, self.dot_products = self.align( - classes, reflections, coefs - ) + rotations, shifts, dot_products = self.align(classes, reflections, coefs) + + # Assign batch results + src_indices = classes[:, 0] # First column of class table + for i, k in enumerate(src_indices): + self.rotations[k] = rotations[i] + if shifts is not None: + self.shifts[k] = shifts[i] + self.dot_products[k] = dot_products[i] n_classes, n_nbor = classes.shape @@ -601,14 +623,12 @@ def _innerloop(i): # Rotate in composite_basis neighbors_coefs = self.composite_basis.rotate( - neighbors_coefs, self.rotations[i], reflections[i] + neighbors_coefs, rotations[i], reflections[i] ) # Note shifts are after rotation for this approach! 
- if self.shifts is not None: - neighbors_coefs = self.composite_basis.shift( - neighbors_coefs, self.shifts[i] - ) + if shifts is not None: + neighbors_coefs = self.composite_basis.shift(neighbors_coefs, shifts[i]) # Averaging in composite_basis return self.image_stacker(neighbors_coefs.asnumpy()) From d3145a808fc77136f0c06c872d8e9bff25f04677 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 Dec 2025 09:27:29 -0500 Subject: [PATCH 32/91] leave progress False --- src/aspire/classification/averager2d.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 301bc9bc37..3bbaff05b2 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -249,8 +249,8 @@ def _innerloop(i): # Averaging in composite_basis return self.image_stacker(neighbors_coefs.asnumpy()) - desc = f"Stacking and evaluating class averages from {self.composite_basis.__class__.__name__} to Cartesian" - for start in trange(0, n_classes, self.batch_size, desc=desc): + desc = f"Stacking and evaluating batch of class averages from {self.composite_basis.__class__.__name__} to Cartesian" + for start in trange(0, n_classes, self.batch_size, desc=desc, leave=False): end = min(start + self.batch_size, n_classes) for i, cls in enumerate( trange(start, end, desc="Stacking batch", leave=False) @@ -378,7 +378,7 @@ def align(self, classes, reflections, basis_coefficients=None): # This is done primarily in case of a tie later, we would take unshifted. test_shifts = self._shift_search_grid(self.src.L, self.radius, roll_zero=True) - for k in trange(n_classes, desc="Rotationally aligning classes"): + for k in trange(n_classes, desc="Rotationally aligning classes", leave=False): # We want to locally cache the original images, # because we will mutate them with shifts in the next loop. # This avoids recomputing them before each shift @@ -580,7 +580,7 @@ def _innerloop(k): dtype=self.dtype, ) - for k in trange(n_classes, desc="Rotationally aligning classes"): + for k in trange(n_classes, desc="Rotationally aligning classes", leave=False): rotations[k], shifts[k], dot_products[k] = _innerloop(k) return rotations, shifts, dot_products @@ -633,7 +633,7 @@ def _innerloop(i): # Averaging in composite_basis return self.image_stacker(neighbors_coefs.asnumpy()) - for i in trange(n_classes, desc="Stacking class averages"): + for i in trange(n_classes, desc="Stacking class averages", leave=False): b_avgs[i] = _innerloop(i) # Now we convert the averaged images from Basis to Cartesian. @@ -752,7 +752,7 @@ def _innerloop(k): return _rotations, _shifts, _dot_products - for k in trange(n_classes, desc="Rotationally aligning classes"): + for k in trange(n_classes, desc="Rotationally aligning classes", leave=False): rotations[k], shifts[k], dot_products[k] = _innerloop(k) return rotations, shifts, dot_products @@ -900,7 +900,7 @@ def align(self, classes, reflections, basis_coefficients=None): ) _images = xp.empty((n_nbor - 1, self.src.L, self.src.L), dtype=self.dtype) - for k in trange(n_classes, desc="Rotationally aligning classes"): + for k in trange(n_classes, desc="Rotationally aligning classes", leave=False): # We want to locally cache the original images, # because we will mutate them with shifts in the next loop. 
# This avoids recomputing them before each shift From 9d4f8c02f2a27a27e22dfa252833ff18acea54fa Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 Dec 2025 09:27:40 -0500 Subject: [PATCH 33/91] add selection message --- src/aspire/denoising/class_avg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index 564549342d..646cdc9e2a 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -223,6 +223,7 @@ def _class_select(self): self._classify() # Perform class selection + logger.info("Performing class selection") _selection_indices = self.class_selector.select( self.class_indices, self.class_refl, From d7414d41004a371d76f110c313af4ca5829f362e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 08:26:22 -0500 Subject: [PATCH 34/91] Update class avg tutorial documentation --- gallery/tutorials/tutorials/class_averaging.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/gallery/tutorials/tutorials/class_averaging.py b/gallery/tutorials/tutorials/class_averaging.py index 7108a97846..063e10c1e0 100644 --- a/gallery/tutorials/tutorials/class_averaging.py +++ b/gallery/tutorials/tutorials/class_averaging.py @@ -226,6 +226,7 @@ est_shifts = avgs.averager.shifts est_dot_products = avgs.averager.dot_products +# These are dictionaries mapping each class to arrays of attributes. print(f"Estimated Rotations: {est_rotations}") print(f"Estimated Shifts: {est_shifts}") print(f"Estimated Dot Products: {est_dot_products}") @@ -241,7 +242,12 @@ original_img_nbr = noisy_src.images[original_img_nbr_idx].asnumpy()[0] # Rotate using estimated rotations. -angle = est_rotations[0, nbr] * 180 / np.pi +# First retrieve all angles for the `review_class` (original_img_0_idx), +# then lookup the specific neighbor `nbr` +assert ( + original_img_0_idx == review_class +), "DebugClassAvgSource should retain original source image ordering" +angle = est_rotations[original_img_0_idx][nbr] * 180 / np.pi if reflections[nbr]: print("Reflection reported.") original_img_nbr = np.flipud(original_img_nbr) From a823331fa7c427c9b32f51885cc917ea77f29274 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 10 Dec 2025 15:04:14 -0500 Subject: [PATCH 35/91] override scikit inverse_transform to allow for complex values --- src/aspire/numeric/complex_pca/complex_pca.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/aspire/numeric/complex_pca/complex_pca.py b/src/aspire/numeric/complex_pca/complex_pca.py index e38820c702..7c2f20fc03 100644 --- a/src/aspire/numeric/complex_pca/complex_pca.py +++ b/src/aspire/numeric/complex_pca/complex_pca.py @@ -15,6 +15,7 @@ import scipy.sparse as sp from sklearn.decomposition import PCA from sklearn.utils._array_api import get_namespace +from sklearn.utils.validation import check_is_fitted from .validation import check_array @@ -78,3 +79,26 @@ def _fit(self, X): raise ValueError( "Unrecognized svd_solver='{0}'" "".format(self._fit_svd_solver) ) + + def inverse_transform(self, X): + """Transform data back to its original space.""" + + xp, _ = get_namespace(X, self.components_, self.explained_variance_) + + check_is_fitted(self) + + X = check_array( + X, + dtype=[np.complex128, np.complex64, np.float64, np.float32], + ensure_2d=True, + copy=self.copy, + allow_complex=True, + ) + + if self.whiten: + scaled_components = ( + xp.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_ + ) + return X @ scaled_components + 
self.mean_ + else: + return X @ self.components_ + self.mean_ From bcf06c29126b4e5c4a8262d502042453a7d044d5 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 10 Dec 2025 15:33:13 -0500 Subject: [PATCH 36/91] update macos runner to 15 --- .github/workflows/workflow.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index 4eae25111b..4c52cefd00 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -156,7 +156,7 @@ jobs: shell: bash -el {0} strategy: matrix: - os: [ubuntu-latest, ubuntu-22.04, macOS-latest, macOS-13] + os: [ubuntu-latest, ubuntu-22.04, macOS-latest, macOS-15] backend: [default, openblas] python-version: ['3.9'] include: From 8ccdbc3583576e71915bc77561be88f6663d2522 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 11 Dec 2025 08:37:29 -0500 Subject: [PATCH 37/91] macos 14 --- .github/workflows/workflow.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index 4c52cefd00..2bf5a462dc 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -156,7 +156,7 @@ jobs: shell: bash -el {0} strategy: matrix: - os: [ubuntu-latest, ubuntu-22.04, macOS-latest, macOS-15] + os: [ubuntu-latest, ubuntu-22.04, macOS-latest, macOS-14] backend: [default, openblas] python-version: ['3.9'] include: From 12d94e8f6d7e33329029daf27c3cb2a9e611ec09 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 1 Dec 2025 14:40:18 -0500 Subject: [PATCH 38/91] CL reorg - squash --- src/aspire/abinitio/J_sync.py | 247 +++++++++ src/aspire/abinitio/__init__.py | 11 +- src/aspire/abinitio/commonline_base.py | 42 +- src/aspire/abinitio/commonline_c2.py | 70 +-- src/aspire/abinitio/commonline_c3_c4.py | 308 ++---------- src/aspire/abinitio/commonline_cn.py | 81 ++- src/aspire/abinitio/commonline_d2.py | 6 +- src/aspire/abinitio/commonline_sync.py | 11 +- src/aspire/abinitio/commonline_sync3n.py | 15 +- src/aspire/abinitio/commonline_utils.py | 305 +++++++++--- src/aspire/abinitio/sync_voting.py | 610 ++++++++++++----------- src/aspire/utils/__init__.py | 1 - tests/test_commonline_utils.py | 118 +++++ tests/test_orient_symmetric.py | 90 +--- 14 files changed, 1088 insertions(+), 827 deletions(-) create mode 100644 src/aspire/abinitio/J_sync.py create mode 100644 tests/test_commonline_utils.py diff --git a/src/aspire/abinitio/J_sync.py b/src/aspire/abinitio/J_sync.py new file mode 100644 index 0000000000..b8458d9807 --- /dev/null +++ b/src/aspire/abinitio/J_sync.py @@ -0,0 +1,247 @@ +import logging + +import numpy as np +from numpy.linalg import norm + +from aspire.utils import J_conjugate, all_pairs, all_triplets, tqdm +from aspire.utils.random import randn + +logger = logging.getLogger(__name__) + + +class JSync: + """ + Class for handling J-synchronization methods. + """ + + def __init__( + self, + n, + epsilon=1e-2, + max_iters=1000, + seed=None, + ): + """ + Initialize JSync object for estimating global handedness synchronization for a + set of relative rotations, Rij = Ri @ Rj.T, where i <= j = 0, 1, ..., n. + + :param n: Number of images/rotations. + :param epsilon: Tolerance for the power method. + :param max_iters: Maximum iterations for the power method. + :param seed: Optional seed for power method initial random vector. 
+ """ + self.n_img = n + self.epsilon = epsilon + self.max_iters = max_iters + self.seed = seed + + def global_J_sync(self, vijs): + """ + Global J-synchronization of all third row outer products. Given 3x3 matrices vijs, each + of which might contain a spurious J (ie. vij = J*vi*vj^T*J instead of vij = vi*vj^T), + we return vijs that all have either a spurious J or not. + + :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds an estimate for the corresponding + outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate + might have a spurious J independently of other estimates. + + :return: vijs, all of which have a spurious J or not. + """ + + # Determine relative handedness of vijs. + sign_ij_J = self.power_method(vijs) + + # Synchronize vijs + vijs_sync = vijs.copy() + for i, sign in enumerate(sign_ij_J): + if sign == -1: + vijs_sync[i] = J_conjugate(vijs[i]) + + return vijs_sync + + def power_method(self, vijs): + """ + Calculate the leading eigenvector of the J-synchronization matrix + using the power method. + + As the J-synchronization matrix is of size (n-choose-2)x(n-choose-2), we + use the power method to compute the eigenvalues and eigenvectors, + while constructing the matrix on-the-fly. + + :param vijs: (n-choose-2)x3x3 array of estimates of relative orientation matrices. + + :return: An array of length n-choose-2 consisting of 1 or -1, where the sign of the + i'th entry indicates whether the i'th relative orientation matrix will be J-conjugated. + """ + + # Set power method tolerance and maximum iterations. + epsilon = self.epsilon + max_iters = self.max_iters + + # Initialize candidate eigenvectors + n_vijs = vijs.shape[0] + vec = randn(n_vijs, seed=self.seed) + vec = vec / norm(vec) + residual = 1 + itr = 0 + + # Power method iterations + logger.info( + "Initiating power method to estimate J-synchronization matrix eigenvector." + ) + while itr < max_iters and residual > epsilon: + itr += 1 + # Note, this appears to need double precision for accuracy in the following division. + vec_new = self._signs_times_v(vijs, vec).astype(np.float64, copy=False) + vec_new = vec_new / norm(vec_new) + residual = norm(vec_new - vec) + vec = vec_new + logger.info( + f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" + ) + + # We need only the signs of the eigenvector + J_sync = np.sign(vec, dtype=vijs.dtype) + + return J_sync + + def sync_viis(self, vijs, viis): + """ + Given a set of synchronized pairwise outer products vijs, J-synchronize the set of + outer products viis. + + :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds an estimate for the corresponding + outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate + might have a spurious J independently of other estimates. + + :param viis: An n_imgx3x3 array where the i'th slice holds an estimate for the outer product vi*vi^T + between the third row of matrix Ri and itself. Each estimate might have a spurious J independently + of other estimates. + + :return: J-synchronized viis. + """ + + # Synchronize viis + # We use the fact that if v_ii and v_ij are of the same handedness, then v_ii @ v_ij = v_ij. + # If they are opposite handed then Jv_iiJ @ v_ij = v_ij. We compare each v_ii against all + # previously synchronized v_ij to get a consensus on the handedness of v_ii. 
+ _, pairs_to_linear = all_pairs(self.n_img, return_map=True) + for i in range(self.n_img): + vii = viis[i] + vii_J = J_conjugate(vii) + J_consensus = 0 + for j in range(self.n_img): + if j < i: + idx = pairs_to_linear[j, i] + vji = vijs[idx] + + err1 = norm(vji @ vii - vji) + err2 = norm(vji @ vii_J - vji) + + elif j > i: + idx = pairs_to_linear[i, j] + vij = vijs[idx] + + err1 = norm(vii @ vij - vij) + err2 = norm(vii_J @ vij - vij) + + else: + continue + + # Accumulate J consensus + if err1 < err2: + J_consensus -= 1 + else: + J_consensus += 1 + + if J_consensus > 0: + viis[i] = vii_J + return viis + + def _signs_times_v(self, vijs, vec): + """ + Multiplication of the J-synchronization matrix by a candidate eigenvector. + + The J-synchronization matrix is a matrix representation of the handedness graph, Gamma, whose set of + nodes consists of the estimates vijs and whose set of edges consists of the undirected edges between + all triplets of estimates vij, vjk, and vik, where i i: - idx = pairs_to_linear[i, j] - vij = vijs[idx] - - err1 = norm(vii @ vij - vij) - err2 = norm(vii_J @ vij - vij) - - else: - continue - - # Accumulate J consensus - if err1 < err2: - J_consensus -= 1 - else: - J_consensus += 1 - - if J_consensus > 0: - viis[i] = vii_J + vijs = self.J_sync.global_J_sync(vijs) + + # Determine relative handedness of viis, given synchronized vijs. + viis = self.J_sync.sync_viis(vijs, viis) + return vijs, viis ################################################# @@ -266,8 +230,8 @@ def _self_clmatrix_c3_c4(self): # Compute the correlation over all shifts. # Generate Shifts. r_max = pf.shape[-1] - shifts, shift_phases, _ = self._generate_shift_phase_and_filter( - r_max, max_shift_1d, shift_step + shifts, shift_phases, _ = _generate_shift_phase_and_filter( + r_max, max_shift_1d, shift_step, self.dtype ) n_shifts = len(shifts) @@ -356,17 +320,21 @@ def _estimate_all_Riis_c3_c4(self, sclmatrix): return Riis - def _estimate_all_Rijs_c3_c4(self, clmatrix): + def _estimate_all_Rijs_c3_c4(self): """ Estimate Rijs using the voting method. """ - n_img = self.n_img - n_theta = self.n_theta - pairs = all_pairs(n_img) + pairs = all_pairs(self.n_img) Rijs = np.zeros((len(pairs), 3, 3)) for idx, (i, j) in enumerate(pairs): - Rijs[idx] = self._syncmatrix_ij_vote_3n( - clmatrix, i, j, np.arange(n_img), n_theta + Rijs[idx] = _syncmatrix_ij_vote_3n( + self.clmatrix, + i, + j, + np.arange(self.n_img), + self.n_theta, + self.hist_bin_width, + self.full_width, ) return Rijs @@ -449,201 +417,3 @@ def _local_J_sync_c3_c4(self, Rijs, Riis): vijs[idx] = opts[min_idx] return vijs, viis - - ####################################### - # Secondary Methods for Global J Sync # - ####################################### - - def _J_sync_power_method(self, vijs): - """ - Calculate the leading eigenvector of the J-synchronization matrix - using the power method. - - As the J-synchronization matrix is of size (n-choose-2)x(n-choose-2), we - use the power method to compute the eigenvalues and eigenvectors, - while constructing the matrix on-the-fly. - - :param vijs: (n-choose-2)x3x3 array of estimates of relative orientation matrices. - - :return: An array of length n-choose-2 consisting of 1 or -1, where the sign of the - i'th entry indicates whether the i'th relative orientation matrix will be J-conjugated. - """ - - # Set power method tolerance and maximum iterations. 
- epsilon = self.epsilon - max_iters = self.max_iters - - # Initialize candidate eigenvectors - n_vijs = vijs.shape[0] - vec = randn(n_vijs, seed=self.seed) - vec = vec / norm(vec) - residual = 1 - itr = 0 - - # Power method iterations - logger.info( - "Initiating power method to estimate J-synchronization matrix eigenvector." - ) - while itr < max_iters and residual > epsilon: - itr += 1 - # Note, this appears to need double precision for accuracy in the following division. - vec_new = self._signs_times_v(vijs, vec).astype(np.float64, copy=False) - vec_new = vec_new / norm(vec_new) - residual = norm(vec_new - vec) - vec = vec_new - logger.info( - f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" - ) - - # We need only the signs of the eigenvector - J_sync = np.sign(vec) - - return J_sync - - def _signs_times_v(self, vijs, vec): - """ - Multiplication of the J-synchronization matrix by a candidate eigenvector. - - The J-synchronization matrix is a matrix representation of the handedness graph, Gamma, whose set of - nodes consists of the estimates vijs and whose set of edges consists of the undirected edges between - all triplets of estimates vij, vjk, and vik, where i4. @@ -65,22 +73,24 @@ def __init__( super().__init__( src, - symmetry=symmetry, n_rad=n_rad, n_theta=n_theta, max_shift=max_shift, shift_step=shift_step, - epsilon=epsilon, - max_iters=max_iters, - degree_res=degree_res, - seed=seed, mask=mask, **kwargs, ) + self._check_symmetry(symmetry) + self.epsilon = epsilon + self.max_iters = max_iters + self.degree_res = degree_res + self.seed = seed self.n_points_sphere = n_points_sphere self.equator_threshold = equator_threshold + self.J_sync = JSync(src.n, self.epsilon, self.max_iters, self.seed) + def _check_symmetry(self, symmetry): if symmetry is None: raise NotImplementedError( @@ -102,7 +112,27 @@ def estimate_rotations(self): :return: Array of rotation matrices, size n_imgx3x3. """ - super().estimate_rotations() + vijs, viis = self._estimate_relative_viewing_directions() + + logger.info("Performing global handedness synchronization.") + vijs, viis = self._global_J_sync(vijs, viis) + + logger.info("Estimating third rows of rotation matrices.") + vis = _estimate_third_rows(vijs, viis) + + logger.info("Estimating in-plane rotations and rotations matrices.") + Ris = _estimate_inplane_rotations( + vis, + self.pf, + self.max_shift, + self.shift_step, + self.order, + self.degree_res, + ) + + self.rotations = Ris + + return self.rotations def _estimate_relative_viewing_directions(self): logger.info(f"Estimating relative viewing directions for {self.n_img} images.") @@ -123,8 +153,8 @@ def _estimate_relative_viewing_directions(self): # Generate shift phases. r_max = pf.shape[-1] - shifts, shift_phases, _ = self._generate_shift_phase_and_filter( - r_max, self.max_shift, self.shift_step + shifts, shift_phases, _ = _generate_shift_phase_and_filter( + r_max, self.max_shift, self.shift_step, self.dtype ) n_shifts = len(shifts) @@ -286,6 +316,31 @@ def _compute_cls_inds(self, Ris_tilde, R_theta_ijs): cij_inds[i, j, :, 1] = c2s return cij_inds + def _global_J_sync(self, vijs, viis): + """ + Global J-synchronization of all third row outer products. Given 3x3 matrices vijs and viis, each + of which might contain a spurious J (ie. vij = J*vi*vj^T*J instead of vij = vi*vj^T), + we return vijs and viis that all have either a spurious J or not. 
+ + :param vijs: An (n-choose-2)x3x3 array where each 3x3 slice holds an estimate for the corresponding + outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate + might have a spurious J independently of other estimates. + + :param viis: An n_imgx3x3 array where the i'th slice holds an estimate for the outer product vi*vi^T + between the third row of matrix Ri and itself. Each estimate might have a spurious J independently + of other estimates. + + :return: vijs, viis all of which have a spurious J or not. + """ + + # Determine relative handedness of vijs. + vijs = self.J_sync.global_J_sync(vijs) + + # Determine relative handedness of viis, given synchronized vijs. + viis = self.J_sync.sync_viis(vijs, viis) + + return vijs, viis + @staticmethod def relative_rots_to_cl_indices(relative_rots, n_theta): """ @@ -300,8 +355,8 @@ def relative_rots_to_cl_indices(relative_rots, n_theta): c1s = np.array((-relative_rots[:, 1, 2], relative_rots[:, 0, 2])).T c2s = np.array((relative_rots[:, 2, 1], -relative_rots[:, 2, 0])).T - c1s = cl_angles_to_ind(c1s, n_theta) - c2s = cl_angles_to_ind(c2s, n_theta) + c1s = _cl_angles_to_ind(c1s, n_theta) + c2s = _cl_angles_to_ind(c2s, n_theta) inds = np.where(c1s >= n_theta // 2) c1s[inds] -= n_theta // 2 @@ -333,7 +388,7 @@ def generate_candidate_rots(n, equator_threshold, order, degree_res, seed): while counter < n: third_row = randn(3) third_row /= anorm(third_row, axes=(-1,)) - Ri_tilde = complete_third_row_to_rot(third_row) + Ri_tilde = _complete_third_row_to_rot(third_row) # Exclude candidates that represent equator images. Equator candidates # induce collinear self-common-lines, which always have perfect correlation. diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index f8022d3db9..e1730bf35e 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -10,6 +10,8 @@ from aspire.utils.random import randn from aspire.volume import DnSymmetryGroup +from .commonline_utils import _generate_shift_phase_and_filter + logger = logging.getLogger(__name__) @@ -131,8 +133,8 @@ def _compute_shifted_pf(self): # Generate shift phases. r_max = pf.shape[-1] max_shift_1d = np.ceil(2 * np.sqrt(2) * self.max_shift) - shifts, shift_phases, _ = self._generate_shift_phase_and_filter( - r_max, max_shift_1d, self.shift_step + shifts, shift_phases, _ = _generate_shift_phase_and_filter( + r_max, max_shift_1d, self.shift_step, self.dtype ) self.n_shifts = len(shifts) diff --git a/src/aspire/abinitio/commonline_sync.py b/src/aspire/abinitio/commonline_sync.py index 6fcc8daddc..6ed774958b 100644 --- a/src/aspire/abinitio/commonline_sync.py +++ b/src/aspire/abinitio/commonline_sync.py @@ -2,14 +2,15 @@ import numpy as np -from aspire.abinitio import CLOrient3D, SyncVotingMixin +from aspire.abinitio import CLOrient3D +from aspire.abinitio.sync_voting import _rotratio_eulerangle_vec, _vote_ij from aspire.utils import nearest_rotations from aspire.utils.matlab_compat import stable_eigsh logger = logging.getLogger(__name__) -class CLSyncVoting(CLOrient3D, SyncVotingMixin): +class CLSyncVoting(CLOrient3D): """ Define a class to estimate 3D orientations using synchronization matrix and voting method. 
@@ -201,9 +202,11 @@ def _syncmatrix_ij_vote(self, clmatrix, i, j, k_list, n_theta): :return: The (i,j) rotation block of the synchronization matrix """ - _, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list) + _, good_k = _vote_ij( + clmatrix, n_theta, i, j, k_list, self.hist_bin_width, self.full_width + ) - rots = self._rotratio_eulerangle_vec(clmatrix, i, j, good_k, n_theta) + rots = _rotratio_eulerangle_vec(clmatrix, i, j, good_k, n_theta) if rots is not None: rot_mean = np.mean(rots, 0) diff --git a/src/aspire/abinitio/commonline_sync3n.py b/src/aspire/abinitio/commonline_sync3n.py index ed7ca94048..7be4b37c2c 100644 --- a/src/aspire/abinitio/commonline_sync3n.py +++ b/src/aspire/abinitio/commonline_sync3n.py @@ -6,14 +6,15 @@ from numpy.linalg import norm from scipy.optimize import curve_fit -from aspire.abinitio import CLOrient3D, SyncVotingMixin +from aspire.abinitio import CLOrient3D +from aspire.abinitio.sync_voting import _syncmatrix_ij_vote_3n from aspire.utils import J_conjugate, all_pairs, nearest_rotations, random, tqdm, trange from aspire.utils.matlab_compat import stable_eigsh logger = logging.getLogger(__name__) -class CLSync3N(CLOrient3D, SyncVotingMixin): +class CLSync3N(CLOrient3D): """ Define a class to estimate 3D orientations using common lines Sync3N methods (2017). @@ -957,8 +958,14 @@ def _estimate_all_Rijs_host(self, clmatrix): Rijs = np.zeros((len(self._pairs), 3, 3)) for idx, (i, j) in enumerate(tqdm(self._pairs, desc="Estimate Rijs")): - Rijs[idx] = self._syncmatrix_ij_vote_3n( - clmatrix, i, j, np.arange(n_img), n_theta + Rijs[idx] = _syncmatrix_ij_vote_3n( + clmatrix, + i, + j, + np.arange(n_img), + n_theta, + self.hist_bin_width, + self.full_width, ) return Rijs diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py index bfd1e5ce5f..45352bf026 100644 --- a/src/aspire/abinitio/commonline_utils.py +++ b/src/aspire/abinitio/commonline_utils.py @@ -4,12 +4,12 @@ from numpy.linalg import eigh, norm from aspire.operators import PolarFT -from aspire.utils import Rotation, all_pairs, anorm, tqdm +from aspire.utils import J_conjugate, Rotation, all_pairs, anorm, cyclic_rotations, tqdm logger = logging.getLogger(__name__) -def estimate_third_rows(vijs, viis): +def _estimate_third_rows(vijs, viis): """ Find the third row of each rotation matrix given a collection of matrices representing the outer products of the third rows from each rotation matrix. @@ -59,39 +59,69 @@ def estimate_third_rows(vijs, viis): return vis -def estimate_inplane_rotations(cl_class, vis): +def _generate_shift_phase_and_filter(r_max, max_shift, shift_step, dtype): """ - Estimate the rotation matrices for each image by constructing arbitrary rotation matrices - populated with the given third rows, vis, and then rotating by an appropriate in-plane rotation. + Prepare the shift phases and generate filter for common-line detection - :cl_class: A commonlines class instance. - :param vis: An n_imgx3 array where the i'th row holds the estimate for the third row of - the i'th rotation matrix. + The shift phases are pre-defined in a range of max_shift that can be + applied to maximize the common line calculation. The common-line filter + is also applied to the radial direction for easier detection. + + :param r_max: Maximum index for common line detection. + :param max_shift: Maximum value of 1D shift (in pixels) to search. + :param shift_step: Resolution of shift estimation in pixels. + :param dtype: dtype for shift phases and filter. 
+ :return: shift phases matrix and common lines filter. + """ + + # Number of shifts to try + n_shifts = int(np.ceil(2 * max_shift / shift_step + 1)) + + # only half of ray, excluding the DC component. + rk = np.arange(1, r_max + 1, dtype=dtype) + + # Generate all shift phases + shifts = -max_shift + shift_step * np.arange(n_shifts, dtype=dtype) + shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1))) + # Set filter for common-line detection + h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2)) + + return shifts, shift_phases, h + +def _estimate_inplane_rotations(vis, pf, max_shift, shift_step, order, degree_res): + """ + Estimate the rotation matrices for each image of a cyclically symmetric molecule by + constructing arbitrary rotation matrices populated with the given third rows, vis, and + then rotating by an appropriate in-plane rotation. + + :param vis: An n_imgx3 array where the i'th row holds the estimate for the third row of + the i'th rotation matrix. + :param pf: The polar Fourier transform of the source images, shape (n_img, n_theta/2, n_rad). + :param max_shift: Maximum range for shifts (in pixels) for estimating in-plane rotations. + :param shift_step: Shift step (in pixels) for estimating in-plane rotations. + :param order: Cyclic order. + :param degree_res: Resolution (in degrees) of in-plane rotation to search over. :return: Rotation matrices Ris and in-plane rotation matrices R_thetas, both size n_imgx3x3. """ - pf = cl_class.pf - n_img = cl_class.n_img - n_theta = cl_class.n_theta - max_shift_1d = cl_class.max_shift - shift_step = cl_class.shift_step - order = cl_class.order - degree_res = cl_class.degree_res + n_img = vis.shape[0] + dtype = vis.dtype + n_theta = pf.shape[1] * 2 # Step 1: Construct all rotation matrices Ri_tildes whose third rows are equal to # the corresponding third rows vis. - Ri_tildes = complete_third_row_to_rot(vis) + Ri_tildes = _complete_third_row_to_rot(vis) # Step 2: Construct all in-plane rotation matrices, R_theta_ijs. max_angle = (360 // order) * order theta_ijs = np.arange(0, max_angle, degree_res) * np.pi / 180 - R_theta_ijs = Rotation.about_axis("z", theta_ijs, dtype=cl_class.dtype).matrices + R_theta_ijs = Rotation.about_axis("z", theta_ijs, dtype=dtype).matrices # Step 3: Compute the correlation over all shifts. # Generate shifts. r_max = pf.shape[-1] - shifts, shift_phases, _ = cl_class._generate_shift_phase_and_filter( - r_max, max_shift_1d, shift_step + shifts, shift_phases, _ = _generate_shift_phase_and_filter( + r_max, max_shift, shift_step, dtype ) n_shifts = len(shifts) @@ -100,7 +130,7 @@ def estimate_inplane_rotations(cl_class, vis): # and theta_i in [0, 2pi/order) is the in-plane rotation angle for the i'th image. Q = np.zeros((n_img, n_img), dtype=complex) - # Reconstruct the full polar Fourier for use in correlation. cl_class.pf only consists of + # Reconstruct the full polar Fourier for use in correlation. pf only consists of # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). 
pf = PolarFT.half_to_full(pf) @@ -108,87 +138,83 @@ def estimate_inplane_rotations(cl_class, vis): pf /= norm(pf, axis=-1)[..., np.newaxis] n_pairs = n_img * (n_img - 1) // 2 - with tqdm(total=n_pairs) as pbar: - idx = 0 - # Note: the ordering of i and j in these loops should not be changed as - # they correspond to the ordered tuples (i, j), for i 1e-12): - logger.warning( - f"Globally Consistent Angular Reconstruction (GCAR) exists" - f" numerical problem: abs(cos_phi2) > 1, with the" - f" difference of {np.abs(cos_phi2)-1}." - ) - cos_phi2 = np.clip(cos_phi2, -1, 1) - - # Store angles between i and j induced by each third image k. - phis = cos_phi2 - # Sore good indices of l in k_list of the image that creates that angle. - inds = k_list[good_idx] - - if phis.shape[0] == 0: - return None, [] - - # Parameters used to compute the smoothed angle histogram. - ntics = int(180 / self.hist_bin_width) - angles_grid = np.linspace(0, 180, ntics + 1, True) - - # Get angles between images i and j for computing the histogram - angles = np.arccos(phis[:]) * 180 / np.pi - - # Angles that are up to 10 degrees apart are considered - # similar. This sigma ensures that the width of the density - # estimation kernel is roughly 10 degrees. For 15 degrees, the - # value of the kernel is negligible. - sigma = getattr(self, "sigma", 3.0) # get from class if avail - - # Compute the histogram of the angles between images i and j - angles_distances = angles_grid[None, :] - angles[:, None] - angles_hist = np.sum(np.exp(-(angles_distances**2) / (2 * sigma**2)), axis=0) - - # We assume that at the location of the peak we get the true angle - # between images i and j. Find all third images k, that induce an - # angle between i and j that is at most 10 off the true angle. - # Even for debugging, don't put a value that is smaller than two - # tics, since the peak might move a little bit due to wrong k images - # that accidentally fall near the peak. - peak_idx = angles_hist.argmax() - - if self.full_width == -1: - # Adaptive width (MATLAB) - # Look for the estimations in the peak of the histogram - w_theta_needed = 0 - idx = [] - while sum(idx) == 0: - w_theta_needed += self.hist_bin_width # widen peak as needed - idx = np.abs(angles - angles_grid[peak_idx]) < w_theta_needed - if w_theta_needed > self.hist_bin_width: - logger.info( - f"Adaptive width {w_theta_needed} required for ({i},{j}), found {sum(idx)} indices." - ) - else: - # Fixed width - idx = np.abs(angles - angles_grid[peak_idx]) < self.full_width - - good_k = inds[idx] - alpha = np.arccos(phis[idx]) - - return alpha, good_k.astype("int") - - def _get_cos_phis(self, cl_diff1, cl_diff2, cl_diff3, n_theta, sync=False): - """ - Calculate cos values of rotation angles between i and j images - - Given C1, C2, and C3 are unit circles of image i, j, and k, compute - resulting cos values of rotation angles between i an j images when both - of them are intersecting with k. - - To ensure that the smallest singular value is big enough, controlled by - the determinant of the matrix, - C=[ 1 c1 c2 ; - c1 1 c3 ; - c2 c3 1 ], - we therefore use the condition below - 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 1.0e-5, - so the matrix is far from singular. 
- - :param cl_diff1: Difference of common line indices on C1 created by - its intersection with C3 and C2 - :param cl_diff2: Difference of common line indices on C2 created by - its intersection with C1 and C3 - :param cl_diff3: Difference of common line indices on C3 created by - its intersection with C2 and C1 - :param n_theta: The number of points in the theta direction (common lines) - :param sync: Perform 180 degree ambiguity synchronization. - :return: cos values of rotation angles between i and j images - and indices for good k - """ - - # Calculate the theta values from the differences of common line indices - # C1, C2, and C3 are unit circles of image i, j, and k - # theta1 is the angle on C1 created by its intersection with C3 and C2. - # theta2 is the angle on C2 created by its intersection with C1 and C3. - # theta3 is the angle on C3 created by its intersection with C2 and C1. - theta1 = cl_diff1 * 2 * np.pi / n_theta - theta2 = cl_diff2 * 2 * np.pi / n_theta - theta3 = cl_diff3 * 2 * np.pi / n_theta - - c1 = np.cos(theta1) - c2 = np.cos(theta2) - c3 = np.cos(theta3) - - # Each common-line corresponds to a point on the unit sphere. Denote the - # coordinates of these points by (Pix, Piy Piz), and put them in the matrix - # M=[ P1x P2x P3x ; - # P1y P2y P3y ; - # P1z P2z P3z ]. - # - # Then the matrix - # C=[ 1 c1 c2 ; - # c1 1 c3 ; - # c2 c3 1 ], - # where c1, c2, c3 are given above, is given by C = M.T @ M. - # For the points P1, P2, and P3 to form a triangle on the unit sphere, a - # necessary and sufficient condition is for C to be positive definite. This - # is equivalent to - # 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 0. - # However, this may result in a triangle that is too flat, that is, the - # angle between the projections is very close to zero. We therefore use the - # condition below - # 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 1.0e-5. - # This ensures that the smallest singular value (which is actually - # controlled by the determinant of C) is big enough, so the matrix is far - # from singular. This condition is equivalent to computing the singular - # values of C, followed by checking that the smallest one is big enough. - - cond = 1 + 2 * c1 * c2 * c3 - (np.square(c1) + np.square(c2) + np.square(c3)) - good_idx = np.nonzero(cond > 1e-5)[0] - - # Calculated cos values of angle between i and j images - if sync: - # MATLAB - cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( - np.sqrt(1 - c1[good_idx] ** 2) * np.sqrt(1 - c2[good_idx] ** 2) - ) + :return: The rotation matrix that takes image i to image j for good index of k. + """ - # Some synchronization must be applied when common line is - # out by 180 degrees. - # Here fix the angles between c_ij(c_ji) and c_ik(c_jk) to be smaller than pi/2, - # otherwise there will be an ambiguity between alpha and pi-alpha. - TOL_idx = 1e-12 + if i == j: + return [] - # Select only good_idx - theta1 = theta1[good_idx] - theta2 = theta2[good_idx] - theta3 = theta3[good_idx] + # Prepare the theta values from the differences of common line indices + # C1, C2, and C3 are unit circles of image i, j, and k + # cl_diff1 is for the angle on C1 created by its intersection with C3 and C2. + # cl_diff2 is for the angle on C2 created by its intersection with C1 and C3. + # cl_diff3 is for the angle on C3 created by its intersection with C2 and C1. 
+ cl_diff1 = clmatrix[i, good_k] - clmatrix[i, j] # for theta1 + cl_diff2 = clmatrix[j, good_k] - clmatrix[j, i] # for theta2 + cl_diff3 = clmatrix[good_k, j] - clmatrix[good_k, i] # for theta3 - # Check sync conditions - ind1 = (theta1 > (np.pi + TOL_idx)) | ( - (theta1 < -TOL_idx) & (theta1 > -np.pi) - ) - ind2 = (theta2 > (np.pi + TOL_idx)) | ( - (theta2 < -TOL_idx) & (theta2 > -np.pi) - ) - align180 = (ind1 & ~ind2) | (~ind1 & ind2) - - # Apply sync - cos_phi2[align180] = -cos_phi2[align180] - else: - # Python - cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( - np.sin(theta1[good_idx]) * np.sin(theta2[good_idx]) + # Calculate the cos values of rotation angles between i an j images for good k images + c_alpha, good_idx = _get_cos_phis(cl_diff1, cl_diff2, cl_diff3, n_theta, sync=False) + + if len(c_alpha) == 0: + return None + alpha = np.arccos(c_alpha) + + # Convert the Euler angles with ZYZ conversion to rotation matrices + angles = np.zeros((alpha.shape[0], 3)) + angles[:, 0] = clmatrix[i, j] * 2 * np.pi / n_theta + np.pi / 2 + angles[:, 1] = alpha + angles[:, 2] = -np.pi / 2 - clmatrix[j, i] * 2 * np.pi / n_theta + r = Rotation.from_euler(angles).matrices + + return r[good_idx, :, :] + + +def _vote_ij( + clmatrix, n_theta, i, j, k_list, hist_bin_width, full_width, sigma=3.0, sync=False +): + """ + Apply the voting algorithm for images i and j. + + clmatrix is the common lines matrix, constructed using angular resolution, + n_theta. k_list are the images to be used for voting of the pair of images + (i ,j). + + :param clmatrix: The common lines matrix + :param n_theta: The number of points in the theta direction (common lines) + :param i: The i image + :param j: The j image + :param k_list: The list of images for the third image for voting algorithm + :param hist_bin_width: Bin width in smoothing histogram (degrees). + :param full_width: Selection width around smoothed histogram peak (degrees). + `adaptive` will attempt to automatically find the smallest number of + `hist_bin_width`s required to find at least one valid image index. + :param sigma: Voting contribution smoothing factor. Default is 3.0. + :param sync: Perform 180 degree ambiguity synchronization. + + :return: (alpha, good_k), angles and list of all third images + in the peak of the histogram corresponding to the pair of + images (i,j) + """ + + if i == j or clmatrix[i, j] == -1: + return None, [] + + # Some of the entries in clmatrix may be zero if we cleared + # them due to small correlation, or if for each image + # we compute intersections with only some of the other images. + # + # Note that as long as the diagonal of the common lines matrix is + # -1, the conditions (i != j) && (j != k) are not needed, since + # if i == j then clmatrix[i, k] == -1 and similarly for i == k or + # j == k. Thus, the previous voting code (from the JSB paper) is + # correct even though it seems that we should test also that + # (i != j) && (i != k) && (j != k), and only (i != j) && (i != k) + # as tested there. + cl_idx12 = clmatrix[i, j] + cl_idx21 = clmatrix[j, i] + k_list = k_list[ + (k_list != i) & (clmatrix[i, k_list] != -1) & (clmatrix[j, k_list] != -1) + ] + cl_idx13 = clmatrix[i, k_list] + cl_idx31 = clmatrix[k_list, i] + cl_idx23 = clmatrix[j, k_list] + cl_idx32 = clmatrix[k_list, j] + + # Prepare the theta values from the differences of common line indices + # C1, C2, and C3 are unit circles of image i, j, and k + # cl_diff1 is for the angle on C1 created by its intersection with C3 and C2. 
+ # cl_diff2 is for the angle on C2 created by its intersection with C1 and C3. + # cl_diff3 is for the angle on C3 created by its intersection with C2 and C1. + cl_diff1 = cl_idx13 - cl_idx12 + cl_diff2 = cl_idx23 - cl_idx21 + cl_diff3 = cl_idx32 - cl_idx31 + + # Calculate the cos values of rotation angles between i an j images for good k images + cos_phi2, good_idx = _get_cos_phis(cl_diff1, cl_diff2, cl_diff3, n_theta, sync=sync) + + if np.any(np.abs(cos_phi2) - 1 > 1e-12): + logger.warning( + f"Globally Consistent Angular Reconstruction (GCAR) exists" + f" numerical problem: abs(cos_phi2) > 1, with the" + f" difference of {np.abs(cos_phi2)-1}." + ) + cos_phi2 = np.clip(cos_phi2, -1, 1) + + # Store angles between i and j induced by each third image k. + phis = cos_phi2 + # Sore good indices of l in k_list of the image that creates that angle. + inds = k_list[good_idx] + + if phis.shape[0] == 0: + return None, [] + + # Parameters used to compute the smoothed angle histogram. + ntics = int(180 / hist_bin_width) + angles_grid = np.linspace(0, 180, ntics + 1, True) + + # Get angles between images i and j for computing the histogram + angles = np.arccos(phis[:]) * 180 / np.pi + + # Angles that are up to 10 degrees apart are considered + # similar. `sigma` ensures that the width of the density + # estimation kernel is roughly 10 degrees. For 15 degrees, the + # value of the kernel is negligible. + + # Compute the histogram of the angles between images i and j + angles_distances = angles_grid[None, :] - angles[:, None] + angles_hist = np.sum(np.exp(-(angles_distances**2) / (2 * sigma**2)), axis=0) + + # We assume that at the location of the peak we get the true angle + # between images i and j. Find all third images k, that induce an + # angle between i and j that is at most 10 off the true angle. + # Even for debugging, don't put a value that is smaller than two + # tics, since the peak might move a little bit due to wrong k images + # that accidentally fall near the peak. + peak_idx = angles_hist.argmax() + + if full_width == -1: + # Adaptive width (MATLAB) + # Look for the estimations in the peak of the histogram + w_theta_needed = 0 + idx = [] + while sum(idx) == 0: + w_theta_needed += hist_bin_width # widen peak as needed + idx = np.abs(angles - angles_grid[peak_idx]) < w_theta_needed + if w_theta_needed > hist_bin_width: + logger.info( + f"Adaptive width {w_theta_needed} required for ({i},{j}), found {sum(idx)} indices." ) + else: + # Fixed width + idx = np.abs(angles - angles_grid[peak_idx]) < full_width + + good_k = inds[idx] + alpha = np.arccos(phis[idx]) + + return alpha, good_k.astype("int") + + +def _get_cos_phis(cl_diff1, cl_diff2, cl_diff3, n_theta, sync=False): + """ + Calculate cos values of rotation angles between i and j images + + Given C1, C2, and C3 are unit circles of image i, j, and k, compute + resulting cos values of rotation angles between i an j images when both + of them are intersecting with k. + + To ensure that the smallest singular value is big enough, controlled by + the determinant of the matrix, + C=[ 1 c1 c2 ; + c1 1 c3 ; + c2 c3 1 ], + we therefore use the condition below + 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 1.0e-5, + so the matrix is far from singular. 
+ + :param cl_diff1: Difference of common line indices on C1 created by + its intersection with C3 and C2 + :param cl_diff2: Difference of common line indices on C2 created by + its intersection with C1 and C3 + :param cl_diff3: Difference of common line indices on C3 created by + its intersection with C2 and C1 + :param n_theta: The number of points in the theta direction (common lines) + :param sync: Perform 180 degree ambiguity synchronization. + + :return: cos values of rotation angles between i and j images + and indices for good k + """ + + # Calculate the theta values from the differences of common line indices + # C1, C2, and C3 are unit circles of image i, j, and k + # theta1 is the angle on C1 created by its intersection with C3 and C2. + # theta2 is the angle on C2 created by its intersection with C1 and C3. + # theta3 is the angle on C3 created by its intersection with C2 and C1. + theta1 = cl_diff1 * 2 * np.pi / n_theta + theta2 = cl_diff2 * 2 * np.pi / n_theta + theta3 = cl_diff3 * 2 * np.pi / n_theta + + c1 = np.cos(theta1) + c2 = np.cos(theta2) + c3 = np.cos(theta3) + + # Each common-line corresponds to a point on the unit sphere. Denote the + # coordinates of these points by (Pix, Piy Piz), and put them in the matrix + # M=[ P1x P2x P3x ; + # P1y P2y P3y ; + # P1z P2z P3z ]. + # + # Then the matrix + # C=[ 1 c1 c2 ; + # c1 1 c3 ; + # c2 c3 1 ], + # where c1, c2, c3 are given above, is given by C = M.T @ M. + # For the points P1, P2, and P3 to form a triangle on the unit sphere, a + # necessary and sufficient condition is for C to be positive definite. This + # is equivalent to + # 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 0. + # However, this may result in a triangle that is too flat, that is, the + # angle between the projections is very close to zero. We therefore use the + # condition below + # 1+2*c1*c2*c3-(c1^2+c2^2+c3^2) > 1.0e-5. + # This ensures that the smallest singular value (which is actually + # controlled by the determinant of C) is big enough, so the matrix is far + # from singular. This condition is equivalent to computing the singular + # values of C, followed by checking that the smallest one is big enough. + + cond = 1 + 2 * c1 * c2 * c3 - (np.square(c1) + np.square(c2) + np.square(c3)) + good_idx = np.nonzero(cond > 1e-5)[0] + + # Calculated cos values of angle between i and j images + if sync: + # MATLAB + cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( + np.sqrt(1 - c1[good_idx] ** 2) * np.sqrt(1 - c2[good_idx] ** 2) + ) + + # Some synchronization must be applied when common line is + # out by 180 degrees. + # Here fix the angles between c_ij(c_ji) and c_ik(c_jk) to be smaller than pi/2, + # otherwise there will be an ambiguity between alpha and pi-alpha. 
+ TOL_idx = 1e-12 + + # Select only good_idx + theta1 = theta1[good_idx] + theta2 = theta2[good_idx] + theta3 = theta3[good_idx] + + # Check sync conditions + ind1 = (theta1 > (np.pi + TOL_idx)) | ((theta1 < -TOL_idx) & (theta1 > -np.pi)) + ind2 = (theta2 > (np.pi + TOL_idx)) | ((theta2 < -TOL_idx) & (theta2 > -np.pi)) + align180 = (ind1 & ~ind2) | (~ind1 & ind2) + + # Apply sync + cos_phi2[align180] = -cos_phi2[align180] + else: + # Python + cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( + np.sin(theta1[good_idx]) * np.sin(theta2[good_idx]) + ) - return cos_phi2, good_idx + return cos_phi2, good_idx diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index ae781d823f..ae896ebeb8 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -1,5 +1,4 @@ from .types import complex_type, real_type, utest_tolerance # isort:skip - from .coor_trans import ( # isort:skip mean_aligned_angular_distance, cart2pol, diff --git a/tests/test_commonline_utils.py b/tests/test_commonline_utils.py new file mode 100644 index 0000000000..64d015d1dd --- /dev/null +++ b/tests/test_commonline_utils.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest + +from aspire.abinitio import JSync +from aspire.abinitio.commonline_utils import ( + _complete_third_row_to_rot, + _estimate_third_rows, + build_outer_products, +) +from aspire.utils import J_conjugate, Rotation, randn, utest_tolerance + +DTYPES = [np.float32, np.float64] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +def test_estimate_third_rows(dtype): + """ + Test we accurately estimate a set of 3rd rows of rotation matrices + given the 3rd row outer products vijs = vi @ vj.T and viis = vi @ vi.T. + """ + n_img = 20 + + # `build_outer_products` generates a set of ground truth 3rd rows + # of rotation matrices, then forms the outer products vijs = vi @ vj.T + # and viis = vi @ vi.T. + vijs, viis, gt_vis = build_outer_products(n_img, dtype) + + # Estimate third rows from outer products. + # Due to factorization of V, these might be negated third rows. + vis = _estimate_third_rows(vijs, viis) + + # Check if all-close up to difference of sign + ground_truth = np.sign(gt_vis[0, 0]) * gt_vis + estimate = np.sign(vis[0, 0]) * vis + np.testing.assert_allclose(ground_truth, estimate, rtol=1e-05, atol=1e-08) + + # Check dtype passthrough + assert vis.dtype == dtype + + +def test_complete_third_row(dtype): + """ + Test that `complete_third_row_to_rot` produces a proper rotations + given a set of 3rd rows. + """ + # Build random third rows. + r3 = randn(10, 3, seed=123).astype(dtype) + r3 /= np.linalg.norm(r3, axis=1)[..., np.newaxis] + + # Set first row to be identical with z-axis. + r3[0] = np.array([0, 0, 1], dtype=dtype) + + # Generate rotations. + R = _complete_third_row_to_rot(r3) + + # Check dtype passthrough + assert R.dtype == dtype + + # Assert that first rotation is the identity matrix. + np.testing.assert_allclose(R[0], np.eye(3, dtype=dtype)) + + # Assert that each rotation is orthogonal with determinant 1. 
+ assert np.allclose( + R @ R.transpose((0, 2, 1)), np.eye(3, dtype=dtype), atol=utest_tolerance(dtype) + ) + assert np.allclose(np.linalg.det(R), 1) + + +def test_J_sync(dtype): + """ + Test that the J_sync `power_method` returns a set of signs indicating + the set of relative rotations that need to be J-conjugated to attain + global handedness consistency, and that `global_J_sync` returns the + ground truth rotations up to a spurious J-conjugation. + """ + n = 25 + rots = Rotation.generate_random_rotations(n, dtype=dtype).matrices + + # Generate ground truth and randomly J-conjugate relative rotations, + # keeping track of the signs associated with J-conjugated rotations. + n_choose_2 = (n * (n - 1)) // 2 + signs = np.random.randint(0, 2, n_choose_2) * 2 - 1 + Rijs_gt = np.zeros((n_choose_2, 3, 3), dtype=dtype) + Rijs_conjugated = np.zeros((n_choose_2, 3, 3), dtype=dtype) + ij = 0 + for i in range(n - 1): + Ri = rots[i] + for j in range(i + 1, n): + Rj = rots[j] + Rijs_gt[ij] = Rij = Ri.T @ Rj + if signs[ij] == -1: + Rij = J_conjugate(Rij) + Rijs_conjugated[ij] = Rij + ij += 1 + + # Initialize JSync instance with default params. + J_sync = JSync(n) + + # Perform power method and check that signs are correct up to + # multilication by -1. Also check dtype pass-through. + signs_est = J_sync.power_method(Rijs_conjugated) + np.testing.assert_allclose(signs[0] * signs, signs_est[0] * signs_est) + assert signs_est.dtype == dtype + + # Perform global J sync and check that rotations are correct up to + # a spurious J conjugation. Also check dtype pass-through. + Rijs_sync = J_sync.global_J_sync(Rijs_conjugated) + + # If the first is off by a J, J-conjugate the whole set. + if np.allclose(Rijs_gt[0], J_conjugate(Rijs_sync[0])): + Rijs_sync = J_conjugate(Rijs_sync) + + np.testing.assert_allclose(Rijs_sync, Rijs_gt) + assert Rijs_sync.dtype == dtype diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index d7c4c4716f..ed9c5a6904 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -1,17 +1,15 @@ import numpy as np import pytest -from numpy import pi, random -from numpy.linalg import det, norm from aspire.abinitio import ( CLSymmetryC2, CLSymmetryC3C4, CLSymmetryCn, - cl_angles_to_ind, - complete_third_row_to_rot, - estimate_third_rows, + build_outer_products, + g_sync, ) from aspire.abinitio.commonline_cn import MeanOuterProductEstimator +from aspire.abinitio.commonline_utils import _cl_angles_to_ind from aspire.source import Simulation from aspire.utils import ( J_conjugate, @@ -19,8 +17,6 @@ all_pairs, cyclic_rotations, mean_aligned_angular_distance, - randn, - utest_tolerance, ) from aspire.volume import CnSymmetricVolume @@ -127,7 +123,7 @@ def test_estimate_rotations(n_img, L, order, dtype): rots_gt = src.rotations # g-synchronize ground truth rotations. - rots_gt_sync = cl_symm.g_sync(rots_est, order, rots_gt) + rots_gt_sync = g_sync(rots_est, order, rots_gt) # Register estimates to ground truth rotations and check that the # mean angular distance between them is less than 3 degrees. @@ -141,9 +137,7 @@ def test_relative_rotations(n_img, L, order, dtype): src, cl_symm = source_orientation_objs(n_img, L, order, dtype) # Estimate relative viewing directions. - cl_symm.build_clmatrix() - cl = cl_symm.clmatrix - Rijs = cl_symm._estimate_all_Rijs_c3_c4(cl) + Rijs = cl_symm._estimate_all_Rijs_c3_c4() # Each Rij belongs to the set {Ri.Tg_n^sRj, JRi.Tg_n^sRjJ}, # s = 1, 2, ..., order. 
We find the mean squared error over @@ -326,8 +320,8 @@ def test_self_commonlines(n_img, L, order, dtype): # Get angle difference between scl_gt and scl. scl_diff1 = scl_gt - scl scl_diff2 = scl_gt - np.flip(scl, 1) # Order of indices might be switched. - scl_diff1_angle = scl_diff1 * 2 * pi / n_theta - scl_diff2_angle = scl_diff2 * 2 * pi / n_theta + scl_diff1_angle = scl_diff1 * 2 * np.pi / n_theta + scl_diff2_angle = scl_diff2 * 2 * np.pi / n_theta # cosine is invariant to 2pi, and abs is invariant to +-pi due to J-conjugation. # We take the mean deviation wrt to the two lines in each image. @@ -339,7 +333,7 @@ def test_self_commonlines(n_img, L, order, dtype): min_mean_angle_diff = scl_idx.choose(scl_diff_angle_mean) # Assert scl detection rate is 100% for 5 degree angle tolerance - angle_tol_err = 5 * pi / 180 + angle_tol_err = 5 * np.pi / 180 detection_rate = np.count_nonzero(min_mean_angle_diff < angle_tol_err) / len(scl) assert np.allclose(detection_rate, 1.0) @@ -484,45 +478,6 @@ def test_global_J_sync(n_img, dtype): assert np.allclose(viis, viis_sync) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_estimate_third_rows(dtype): - n_img = 20 - - # Build outer products vijs, viis, and get ground truth third rows. - vijs, viis, gt_vis = build_outer_products(n_img, dtype) - - # Estimate third rows from outer products. - # Due to factorization of V, these might be negated third rows. - vis = estimate_third_rows(vijs, viis) - - # Check if all-close up to difference of sign - ground_truth = np.sign(gt_vis[0, 0]) * gt_vis - estimate = np.sign(vis[0, 0]) * vis - assert np.allclose(ground_truth, estimate) - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_complete_third_row(dtype): - # Build random third rows. - r3 = randn(10, 3, seed=123).astype(dtype) - r3 /= norm(r3, axis=1)[..., np.newaxis] - - # Set first row to be identical with z-axis. - r3[0] = np.array([0, 0, 1], dtype=dtype) - - # Generate rotations. - R = complete_third_row_to_rot(r3) - - # Assert that first rotation is the identity matrix. - assert np.allclose(R[0], np.eye(3, dtype=dtype)) - - # Assert that each rotation is orthogonal with determinant 1. - assert np.allclose( - R @ R.transpose((0, 2, 1)), np.eye(3, dtype=dtype), atol=utest_tolerance(dtype) - ) - assert np.allclose(det(R), 1) - - @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_dtype_pass_through(dtype): L = 16 @@ -558,31 +513,6 @@ def build_self_commonlines_matrix(n_theta, rots, order): return scl_gt -def build_outer_products(n_img, dtype): - # Build random third rows, ground truth vis (unit vectors) - gt_vis = np.zeros((n_img, 3), dtype=dtype) - for i in range(n_img): - random.seed(i) - v = random.randn(3) - gt_vis[i] = v / norm(v) - - # Find outer products viis and vijs for i Date: Thu, 16 Oct 2025 16:00:37 -0400 Subject: [PATCH 39/91] initial ImageSource.save with optics block --- src/aspire/source/image.py | 89 +++++++++++++++++++++++++++++++-- src/aspire/storage/starfile.py | 1 + tests/test_coordinate_source.py | 2 +- 3 files changed, 87 insertions(+), 5 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index f1dff74404..3395dd854d 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1289,6 +1289,79 @@ def _populate_local_metadata(self): """ return [] + @staticmethod + def _prepare_relion_optics_blocks(metadata): + """ + Split metadata into RELION>=3.1 style `data_optics` and `data_particles` blocks. 
+ + The optics block has one row per optics group with: + `_rlnOpticsGroup`, `_rlnOpticsGroupName`, and optics metadata columns. + The particle block keeps the remaining columns and includes a per-particle + `_rlnOpticsGroup` that references the optics block. + """ + # All possible optics group fields + # TODO: Check we have all of them + all_optics_fields = [ + "_rlnImagePixelSize", + "_rlnMicrographPixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnImageSize", + "_rlnAmplitudeContrast", + ] + + # TODO: Need to add _rlnImageSize here and above + required_fields = ["_rlnSphericalAberration", "_rlnVoltage"] + pixel_fields = ["_rlnImagePixelSize", "_rlnMicrographPixelSize"] + + has_required = all(field in metadata for field in required_fields) + has_pixel_field = any(field in metadata for field in pixel_fields) + + # TODO: Add warning? + if not (has_required and has_pixel_field): + # Optics metadata incomplete, fall back to legacy single block. + return None, metadata + + # Collect all optics group fields present in metadata and determine unique optics groups. + optics_value_fields = [ + field for field in all_optics_fields if field in metadata + ] + n_rows = len(metadata["_rlnImagePixelSize"]) + + group_lookup = OrderedDict() # Stores distinct optics groups + optics_groups = np.empty(n_rows, dtype=int) + + for idx in range(n_rows): + signature = tuple(metadata[field][idx] for field in optics_value_fields) + if signature not in group_lookup: + group_lookup[signature] = len(group_lookup) + 1 + optics_groups[idx] = group_lookup[signature] + + metadata["_rlnOpticsGroup"] = optics_groups + + optics_block = OrderedDict() + optics_block["_rlnOpticsGroup"] = [] + optics_block["_rlnOpticsGroupName"] = [] + for field in optics_value_fields: + optics_block[field] = [] + + for signature, group_id in group_lookup.items(): + optics_block["_rlnOpticsGroup"].append(group_id) + optics_block["_rlnOpticsGroupName"].append(f"opticsGroup{group_id}") + for field, value in zip(optics_value_fields, signature): + optics_block[field].append(value) + + # Collect particle_block metadata + particle_block = OrderedDict() + if "_rlnOpticsGroup" in metadata: + particle_block["_rlnOpticsGroup"] = metadata["_rlnOpticsGroup"] + for key, value in metadata.items(): + if key in optics_value_fields or key == "_rlnOpticsGroup": + continue + particle_block[key] = value + + return optics_block, particle_block + def save_metadata(self, starfile_filepath, batch_size=512, save_mode=None): """ Save updated metadata to a STAR file @@ -1324,12 +1397,20 @@ def save_metadata(self, starfile_filepath, batch_size=512, save_mode=None): for x in np.char.split(metadata["_rlnImageName"].astype(np.str_), sep="@") ] + # Separate metadata into optics and particle blocks + optics_block, particle_block = self._prepare_relion_optics_blocks(metadata) + # initialize the star file object and save it odict = OrderedDict() - # since our StarFile only has one block, the convention is to save it with the header "data_", i.e. its name is blank - # if we had a block called "XYZ" it would be saved as "XYZ" - # thus we index the metadata block with "" - odict[""] = metadata + + # StarFile uses the `odict` keys to label the starfile block headers "data_(key)". Following RELION>=3.1 + # convention we label the blocks "data_optics" and "data_particles". 
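+        # `optics_block` is None when `_prepare_relion_optics_blocks` found the optics metadata incomplete.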
+ if optics_block is None: + odict["particles"] = particle_block + else: + odict["optics"] = optics_block + odict["particles"] = particle_block + out_star = StarFile(blocks=odict) out_star.write(starfile_filepath) return filename_indices diff --git a/src/aspire/storage/starfile.py b/src/aspire/storage/starfile.py index 98b3219607..9b55610935 100644 --- a/src/aspire/storage/starfile.py +++ b/src/aspire/storage/starfile.py @@ -135,6 +135,7 @@ def write(self, filepath): # create an empty Document _doc = cif.Document() filepath = str(filepath) + for name, block in self.blocks.items(): # construct new empty block _block = _doc.add_new_block(name) diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index d2f93d5f65..783790f145 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -535,7 +535,7 @@ def testSave(self): self.assertTrue(np.array_equal(imgs.asnumpy()[i], saved_mrcs_stack[i])) # assert that the star file has the correct metadata self.assertEqual( - list(saved_star[""].keys()), + list(saved_star["particles"].keys()), [ "_rlnImagePixelSize", "_rlnSymmetryGroup", From 0c167788b29fa0482f8125be60de06a8e1dbab8e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 23 Oct 2025 16:02:32 -0400 Subject: [PATCH 40/91] add _rlnImageSize column to metadata when saving. --- src/aspire/source/image.py | 7 +++++-- src/aspire/storage/starfile.py | 1 - 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 3395dd854d..b9d38d67cc 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1310,8 +1310,7 @@ def _prepare_relion_optics_blocks(metadata): "_rlnAmplitudeContrast", ] - # TODO: Need to add _rlnImageSize here and above - required_fields = ["_rlnSphericalAberration", "_rlnVoltage"] + required_fields = ["_rlnSphericalAberration", "_rlnVoltage", "_rlnImageSize"] pixel_fields = ["_rlnImagePixelSize", "_rlnMicrographPixelSize"] has_required = all(field in metadata for field in required_fields) @@ -1397,6 +1396,10 @@ def save_metadata(self, starfile_filepath, batch_size=512, save_mode=None): for x in np.char.split(metadata["_rlnImageName"].astype(np.str_), sep="@") ] + # Populate _rlnImageSize column, required for optics_block below + if "_rlnImageSize" not in metadata: + metadata["_rlnImageSize"] = np.full(self.n, self.L, dtype=int) + # Separate metadata into optics and particle blocks optics_block, particle_block = self._prepare_relion_optics_blocks(metadata) diff --git a/src/aspire/storage/starfile.py b/src/aspire/storage/starfile.py index 9b55610935..98b3219607 100644 --- a/src/aspire/storage/starfile.py +++ b/src/aspire/storage/starfile.py @@ -135,7 +135,6 @@ def write(self, filepath): # create an empty Document _doc = cif.Document() filepath = str(filepath) - for name, block in self.blocks.items(): # construct new empty block _block = _doc.add_new_block(name) From 421fcf59229a213f65f1307213cba5d967d04b99 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 24 Oct 2025 08:50:09 -0400 Subject: [PATCH 41/91] Add _rlnImageSize to coordinate source test. 
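
`save_metadata` now populates an `_rlnImageSize` column when one is not already
present, so the coordinate-source round-trip test must expect the extra field in
the particles block. A minimal sketch of the kind of check involved (the output
path and the `src` object are placeholders, not part of this patch):

    from aspire.storage import StarFile

    src.save("out.star", overwrite=True)   # any ImageSource subclass
    star = StarFile("out.star")            # re-read the saved STAR file
    # The particles block should now carry the image size column.
    assert "_rlnImageSize" in star["particles"]
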
--- tests/test_coordinate_source.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index 783790f145..c8fb20ca8e 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -542,6 +542,7 @@ def testSave(self): "_rlnImageName", "_rlnCoordinateX", "_rlnCoordinateY", + "_rlnImageSize", ], ) # assert that all the correct coordinates were saved From d89b033c557592143dee76e8661913be6d6c0b71 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 24 Oct 2025 09:53:37 -0400 Subject: [PATCH 42/91] use rlnImageName to get n_rows --- src/aspire/source/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index b9d38d67cc..cdfa4205be 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1325,7 +1325,7 @@ def _prepare_relion_optics_blocks(metadata): optics_value_fields = [ field for field in all_optics_fields if field in metadata ] - n_rows = len(metadata["_rlnImagePixelSize"]) + n_rows = len(metadata["_rlnImageName"]) group_lookup = OrderedDict() # Stores distinct optics groups optics_groups = np.empty(n_rows, dtype=int) From fc511a1764d63b8c79cbc6e5f1150ea39d1a52cd Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 24 Oct 2025 16:02:54 -0400 Subject: [PATCH 43/91] test for sim save with optics block --- src/aspire/utils/relion_interop.py | 1 + tests/test_simulation.py | 55 +++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/src/aspire/utils/relion_interop.py b/src/aspire/utils/relion_interop.py index 996d321366..90e9a555b2 100644 --- a/src/aspire/utils/relion_interop.py +++ b/src/aspire/utils/relion_interop.py @@ -21,6 +21,7 @@ "_rlnCtfFigureOfMerit": float, "_rlnMagnification": float, "_rlnImagePixelSize": float, + "_rlnImageSize": int, "_rlnAmplitudeContrast": float, "_rlnImageName": str, "_rlnOriginalName": str, diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 5859aae934..082ae8e59a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -9,7 +9,7 @@ from aspire.noise import WhiteNoiseAdder from aspire.operators import RadialCTFFilter from aspire.source import RelionSource, Simulation, _LegacySimulation -from aspire.utils import utest_tolerance +from aspire.utils import RelionStarFile, utest_tolerance from aspire.volume import LegacyVolume, SymmetryGroup, Volume from .test_utils import matplotlib_dry_run @@ -627,6 +627,59 @@ def testSimulationSaveFile(self): ) +def test_simulation_save_optics_block(tmp_path): + res = 32 + + # Radial CTF Filters. 
Should make 3 distinct optics blocks + kv_min, kv_max, kv_ct = 200, 300, 3 + voltages = np.linspace(kv_min, kv_max, kv_ct) + ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + + # Generate and save Simulation + sim = Simulation( + n=9, L=res, C=1, unique_filters=ctf_filters, pixel_size=1.34 + ).cache() + starpath = tmp_path / "sim.star" + sim.save(starpath, overwrite=True) + + star = RelionStarFile(str(starpath)) + assert star.relion_version == "3.1" + assert set(star.blocks.keys()) == {"optics", "particles"} + + optics = star["optics"] + expected_optics_fields = [ + "_rlnOpticsGroup", + "_rlnOpticsGroupName", + "_rlnImagePixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnImageSize", + "_rlnAmplitudeContrast", + ] + for field in expected_optics_fields: + assert field in optics + + np.testing.assert_array_equal( + optics["_rlnOpticsGroup"], np.arange(1, kv_ct + 1, dtype=int) + ) + np.testing.assert_array_equal( + optics["_rlnOpticsGroupName"], + np.array([f"opticsGroup{i}" for i in range(1, kv_ct + 1)], dtype=object), + ) + np.testing.assert_array_equal(optics["_rlnImageSize"], np.full(kv_ct, res)) + + # Depending on Simulation random indexing, voltages will be unordered + np.testing.assert_allclose(np.sort(optics["_rlnVoltage"]), voltages) + + particles = star["particles"] + assert "_rlnOpticsGroup" in particles + assert len(particles["_rlnOpticsGroup"]) == sim.n + np.testing.assert_array_equal( + np.sort(np.unique(particles["_rlnOpticsGroup"])), + np.arange(1, kv_ct + 1, dtype=int), + ) + + def test_default_symmetry_group(): # Check that default is "C1". sim = Simulation() From d0c8c4a3854e84c9c7b4a77dd837028a94230376 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 27 Oct 2025 10:48:42 -0400 Subject: [PATCH 44/91] add _rlnImageDimensionality --- src/aspire/source/image.py | 16 +++++++++++++--- src/aspire/utils/relion_interop.py | 1 + tests/test_simulation.py | 12 +++++++++++- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index cdfa4205be..cf369c584b 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1306,11 +1306,18 @@ def _prepare_relion_optics_blocks(metadata): "_rlnMicrographPixelSize", "_rlnSphericalAberration", "_rlnVoltage", - "_rlnImageSize", "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", ] - required_fields = ["_rlnSphericalAberration", "_rlnVoltage", "_rlnImageSize"] + required_fields = [ + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", + ] pixel_fields = ["_rlnImagePixelSize", "_rlnMicrographPixelSize"] has_required = all(field in metadata for field in required_fields) @@ -1396,10 +1403,13 @@ def save_metadata(self, starfile_filepath, batch_size=512, save_mode=None): for x in np.char.split(metadata["_rlnImageName"].astype(np.str_), sep="@") ] - # Populate _rlnImageSize column, required for optics_block below + # Populate _rlnImageSize, _rlnImageDimensionality columns, required for optics_block below if "_rlnImageSize" not in metadata: metadata["_rlnImageSize"] = np.full(self.n, self.L, dtype=int) + if "_rlnImageDimensionality" not in metadata: + metadata["_rlnImageDimensionality"] = np.full(self.n, 2, dtype=int) + # Separate metadata into optics and particle blocks optics_block, particle_block = self._prepare_relion_optics_blocks(metadata) diff --git a/src/aspire/utils/relion_interop.py b/src/aspire/utils/relion_interop.py index 
90e9a555b2..c807c7d863 100644 --- a/src/aspire/utils/relion_interop.py +++ b/src/aspire/utils/relion_interop.py @@ -20,6 +20,7 @@ "_rlnDetectorPixelSize": float, "_rlnCtfFigureOfMerit": float, "_rlnMagnification": float, + "_rlnImageDimensionality": int, "_rlnImagePixelSize": float, "_rlnImageSize": int, "_rlnAmplitudeContrast": float, diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 082ae8e59a..c1f273e082 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -653,12 +653,16 @@ def test_simulation_save_optics_block(tmp_path): "_rlnImagePixelSize", "_rlnSphericalAberration", "_rlnVoltage", - "_rlnImageSize", "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", ] + + # Check all required fields are present for field in expected_optics_fields: assert field in optics + # Optics group and group name should 1-indexed np.testing.assert_array_equal( optics["_rlnOpticsGroup"], np.arange(1, kv_ct + 1, dtype=int) ) @@ -666,11 +670,17 @@ def test_simulation_save_optics_block(tmp_path): optics["_rlnOpticsGroupName"], np.array([f"opticsGroup{i}" for i in range(1, kv_ct + 1)], dtype=object), ) + + # Check image size and image dimensionality np.testing.assert_array_equal(optics["_rlnImageSize"], np.full(kv_ct, res)) + optics_dim = np.array(optics["_rlnImageDimensionality"], dtype=int) + np.testing.assert_array_equal(optics_dim, np.full(len(optics_dim), 2)) + # Depending on Simulation random indexing, voltages will be unordered np.testing.assert_allclose(np.sort(optics["_rlnVoltage"]), voltages) + # Check that each row of the data_particles block has an associated optics group particles = star["particles"] assert "_rlnOpticsGroup" in particles assert len(particles["_rlnOpticsGroup"]) == sim.n From 2650d6bdd3de432c2dcbcea6c9d956ba963f3ba8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 27 Oct 2025 11:39:52 -0400 Subject: [PATCH 45/91] test save/load roundtrip w/ phase_flip. --- tests/test_simulation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index c1f273e082..46b9c5603d 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -689,6 +689,12 @@ def test_simulation_save_optics_block(tmp_path): np.arange(1, kv_ct + 1, dtype=int), ) + # Test phase_flip after save/load round trip to ensure correct optics group mapping + rln_src = RelionSource(starpath) + np.testing.assert_allclose( + sim.phase_flip().images[:], rln_src.phase_flip().images[:] + ) + def test_default_symmetry_group(): # Check that default is "C1". 
From fb74873e35808188ba2166cb4c801488af725f4a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 27 Oct 2025 13:14:35 -0400 Subject: [PATCH 46/91] update coord source test --- tests/test_coordinate_source.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index c8fb20ca8e..2a63c7b0b9 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -543,6 +543,7 @@ def testSave(self): "_rlnCoordinateX", "_rlnCoordinateY", "_rlnImageSize", + "_rlnImageDimensionality", ], ) # assert that all the correct coordinates were saved From 7c6b7f97cb2790a0ffe4c2930c895b454e60bfd9 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 27 Oct 2025 15:48:37 -0400 Subject: [PATCH 47/91] cleanup --- tests/test_simulation.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 46b9c5603d..70fd75a339 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -644,7 +644,7 @@ def test_simulation_save_optics_block(tmp_path): star = RelionStarFile(str(starpath)) assert star.relion_version == "3.1" - assert set(star.blocks.keys()) == {"optics", "particles"} + assert star.blocks.keys() == {"optics", "particles"} optics = star["optics"] expected_optics_fields = [ @@ -668,16 +668,16 @@ def test_simulation_save_optics_block(tmp_path): ) np.testing.assert_array_equal( optics["_rlnOpticsGroupName"], - np.array([f"opticsGroup{i}" for i in range(1, kv_ct + 1)], dtype=object), + np.array([f"opticsGroup{i}" for i in range(1, kv_ct + 1)]), ) - # Check image size and image dimensionality + # Check image size (res) and image dimensionality (2) np.testing.assert_array_equal(optics["_rlnImageSize"], np.full(kv_ct, res)) + np.testing.assert_array_equal( + optics["_rlnImageDimensionality"], np.full(len(optics_dim), 2) + ) - optics_dim = np.array(optics["_rlnImageDimensionality"], dtype=int) - np.testing.assert_array_equal(optics_dim, np.full(len(optics_dim), 2)) - - # Depending on Simulation random indexing, voltages will be unordered + # Due to Simulation random indexing, voltages will be unordered np.testing.assert_allclose(np.sort(optics["_rlnVoltage"]), voltages) # Check that each row of the data_particles block has an associated optics group From f18f91e5a76ec4ef460c3864328596f5d2d90c25 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 28 Oct 2025 09:14:24 -0400 Subject: [PATCH 48/91] remove unused var --- tests/test_simulation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 70fd75a339..944136c999 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -673,9 +673,7 @@ def test_simulation_save_optics_block(tmp_path): # Check image size (res) and image dimensionality (2) np.testing.assert_array_equal(optics["_rlnImageSize"], np.full(kv_ct, res)) - np.testing.assert_array_equal( - optics["_rlnImageDimensionality"], np.full(len(optics_dim), 2) - ) + np.testing.assert_array_equal(optics["_rlnImageDimensionality"], np.full(kv_ct, 2)) # Due to Simulation random indexing, voltages will be unordered np.testing.assert_allclose(np.sort(optics["_rlnVoltage"]), voltages) From eaaad6c0fd60c519945c398a416da877d18c4e7f Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 28 Oct 2025 10:27:40 -0400 Subject: [PATCH 49/91] missing optics fields warning. warning test. code comments. 
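
When a required optics column is missing, `_prepare_relion_optics_blocks` now logs
a warning before falling back to a single data_particles block. The grouping logic
itself is unchanged: each particle row is keyed by the tuple of its optics values,
and every distinct tuple becomes a 1-based optics group. A rough standalone sketch
of that idea, with shortened signatures standing in for the real metadata columns:

    from collections import OrderedDict

    # One (voltage, Cs) signature per particle row; values are illustrative only.
    rows = [(300.0, 2.7), (200.0, 2.7), (300.0, 2.7)]

    group_lookup = OrderedDict()  # distinct signature -> 1-based optics group id
    optics_groups = []
    for signature in rows:
        if signature not in group_lookup:
            group_lookup[signature] = len(group_lookup) + 1
        optics_groups.append(group_lookup[signature])

    # optics_groups == [1, 2, 1]; group_lookup yields one data_optics row per group.
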
--- src/aspire/source/image.py | 16 ++++++++++------ tests/test_relion_source.py | 24 +++++++++++++++++++++++- tests/test_simulation.py | 2 +- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index cf369c584b..02f4942d74 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1299,8 +1299,7 @@ def _prepare_relion_optics_blocks(metadata): The particle block keeps the remaining columns and includes a per-particle `_rlnOpticsGroup` that references the optics block. """ - # All possible optics group fields - # TODO: Check we have all of them + # Columns that belong in RELION's optics table. all_optics_fields = [ "_rlnImagePixelSize", "_rlnMicrographPixelSize", @@ -1311,6 +1310,7 @@ def _prepare_relion_optics_blocks(metadata): "_rlnImageDimensionality", ] + # We only write a data_optics block when every required field is present. required_fields = [ "_rlnSphericalAberration", "_rlnVoltage", @@ -1323,18 +1323,21 @@ def _prepare_relion_optics_blocks(metadata): has_required = all(field in metadata for field in required_fields) has_pixel_field = any(field in metadata for field in pixel_fields) - # TODO: Add warning? if not (has_required and has_pixel_field): # Optics metadata incomplete, fall back to legacy single block. + logger.warning( + "Optics metadata incomplete, writing only data_particles block." + ) return None, metadata - # Collect all optics group fields present in metadata and determine unique optics groups. + # Restrict to the optics columns that are actually present on this source. optics_value_fields = [ field for field in all_optics_fields if field in metadata ] n_rows = len(metadata["_rlnImageName"]) - group_lookup = OrderedDict() # Stores distinct optics groups + # Map each unique optics tuple to a 1-based group ID in order encountered. + group_lookup = OrderedDict() optics_groups = np.empty(n_rows, dtype=int) for idx in range(n_rows): @@ -1345,6 +1348,7 @@ def _prepare_relion_optics_blocks(metadata): metadata["_rlnOpticsGroup"] = optics_groups + # Build the optics block rows and assign group names. optics_block = OrderedDict() optics_block["_rlnOpticsGroup"] = [] optics_block["_rlnOpticsGroupName"] = [] @@ -1357,7 +1361,7 @@ def _prepare_relion_optics_blocks(metadata): for field, value in zip(optics_value_fields, signature): optics_block[field].append(value) - # Collect particle_block metadata + # Everything not lifted into the optics block stays with the particle metadata. particle_block = OrderedDict() if "_rlnOpticsGroup" in metadata: particle_block["_rlnOpticsGroup"] = metadata["_rlnOpticsGroup"] diff --git a/tests/test_relion_source.py b/tests/test_relion_source.py index 009ecd321d..d0b996c795 100644 --- a/tests/test_relion_source.py +++ b/tests/test_relion_source.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from aspire.source import RelionSource, Simulation +from aspire.source import ImageSource, RelionSource, Simulation from aspire.utils import RelionStarFile from aspire.volume import SymmetryGroup @@ -61,6 +61,28 @@ def test_symmetry_group(caplog): assert str(src_override_sym.symmetry_group) == "C6" +def test_prepare_relion_optics_blocks_warns(caplog): + """ + Test we warn when optics group metadata is missing. 
+ """ + # metadata dict with no CTF values + metadata = { + "_rlnImagePixelSize": np.array([1.234]), + "_rlnImageSize": np.array([32]), + "_rlnImageDimensionality": np.array([2]), + } + + caplog.clear() + with caplog.at_level(logging.WARNING): + optics_block, particle_block = ImageSource._prepare_relion_optics_blocks( + metadata.copy() + ) + + assert optics_block is None + assert particle_block == metadata + assert "Optics metadata incomplete" in caplog.text + + def test_pixel_size(caplog): """ Instantiate RelionSource from starfiles containing the following pixel size diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 944136c999..2f323d3490 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -8,7 +8,7 @@ from aspire.noise import WhiteNoiseAdder from aspire.operators import RadialCTFFilter -from aspire.source import RelionSource, Simulation, _LegacySimulation +from aspire.source import ImageSource, RelionSource, Simulation, _LegacySimulation from aspire.utils import RelionStarFile, utest_tolerance from aspire.volume import LegacyVolume, SymmetryGroup, Volume From 58d0f14c0be8f3bd8dd23c7bd449aaec047b5c2a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 28 Oct 2025 10:28:34 -0400 Subject: [PATCH 50/91] removed unused import. --- tests/test_simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 2f323d3490..944136c999 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -8,7 +8,7 @@ from aspire.noise import WhiteNoiseAdder from aspire.operators import RadialCTFFilter -from aspire.source import ImageSource, RelionSource, Simulation, _LegacySimulation +from aspire.source import RelionSource, Simulation, _LegacySimulation from aspire.utils import RelionStarFile, utest_tolerance from aspire.volume import LegacyVolume, SymmetryGroup, Volume From c0c89b99919e87204307ec28a0ea3ca46fa46eec Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 30 Oct 2025 11:24:10 -0400 Subject: [PATCH 51/91] deprecate _rlnAmplitude in favor of private __amplitude field. --- src/aspire/source/image.py | 16 +++++++++++----- src/aspire/utils/relion_interop.py | 1 + tests/test_array_image_source.py | 4 ++-- tests/test_simulation.py | 2 ++ 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 02f4942d74..fa1615c372 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -483,15 +483,21 @@ def offsets(self, values): @property def amplitudes(self): - return np.atleast_1d( - self.get_metadata( - "_rlnAmplitude", default_value=np.array(1.0, dtype=self.dtype) + if self.has_metadata("__amplitude"): + values = self.get_metadata("__amplitude") + else: + values = self.get_metadata( + "_rlnAmplitude", + default_value=np.array(1.0, dtype=np.float64), ) - ) + return np.atleast_1d(np.asarray(values, dtype=np.float64)) @amplitudes.setter def amplitudes(self, values): - return self.set_metadata("_rlnAmplitude", np.array(values, dtype=self.dtype)) + values = np.asarray(values, dtype=np.float64) + self.set_metadata("__amplitude", values) + # Drop the legacy field if we encountered it while loading a STAR file. 
+ self._metadata.pop("_rlnAmplitude", None) @property def angles(self): diff --git a/src/aspire/utils/relion_interop.py b/src/aspire/utils/relion_interop.py index c807c7d863..7f9916bdc2 100644 --- a/src/aspire/utils/relion_interop.py +++ b/src/aspire/utils/relion_interop.py @@ -12,6 +12,7 @@ # of certain key fields used in the codebase, # which are originally read from Relion STAR files. relion_metadata_fields = { + "__amplitude": float, "_rlnVoltage": float, "_rlnDefocusU": float, "_rlnDefocusV": float, diff --git a/tests/test_array_image_source.py b/tests/test_array_image_source.py index 8dc2a28fb7..a2c3ee4f0c 100644 --- a/tests/test_array_image_source.py +++ b/tests/test_array_image_source.py @@ -323,10 +323,10 @@ def test_dtype_passthrough(dtype): # Check dtypes np.testing.assert_equal(src.dtype, dtype) np.testing.assert_equal(src.images[:].dtype, dtype) - np.testing.assert_equal(src.amplitudes.dtype, dtype) - # offsets are always stored as doubles + # offsets and amplitudes are always stored as doubles np.testing.assert_equal(src.offsets.dtype, np.float64) + np.testing.assert_equal(src.amplitudes.dtype, np.float64) def test_stack_1d_only(): diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 944136c999..a066f6b67d 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -882,6 +882,8 @@ def check_metadata(sim_src, relion_src): those in a RelionSource. """ for k, v in sim_src._metadata.items(): + if k.startswith("__"): + continue try: np.testing.assert_array_equal(v, relion_src._metadata[k]) except AssertionError: From 4e691093ed384c9353161a8eb3c26f0e402921c8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 31 Oct 2025 11:26:22 -0400 Subject: [PATCH 52/91] use _aspireAmplitude to retain ampitudes for saved files --- src/aspire/source/image.py | 15 +++++---------- src/aspire/source/relion.py | 4 ++++ src/aspire/utils/relion_interop.py | 3 ++- tests/test_simulation.py | 2 -- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index fa1615c372..280528367d 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -483,21 +483,16 @@ def offsets(self, values): @property def amplitudes(self): - if self.has_metadata("__amplitude"): - values = self.get_metadata("__amplitude") - else: - values = self.get_metadata( - "_rlnAmplitude", - default_value=np.array(1.0, dtype=np.float64), - ) + values = self.get_metadata( + "_aspireAmplitude", + default_value=np.array(1.0, dtype=np.float64), + ) return np.atleast_1d(np.asarray(values, dtype=np.float64)) @amplitudes.setter def amplitudes(self, values): values = np.asarray(values, dtype=np.float64) - self.set_metadata("__amplitude", values) - # Drop the legacy field if we encountered it while loading a STAR file. - self._metadata.pop("_rlnAmplitude", None) + self.set_metadata("_aspireAmplitude", values) @property def angles(self): diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index c3620c9bdb..f9851eb69f 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -195,6 +195,10 @@ def populate_metadata(self): metadata = RelionStarFile(self.filepath).get_merged_data_block() + # Promote legacy _rlnAmplitude column to the ASPIRE-specific name. + if "_rlnAmplitude" in metadata and "_aspireAmplitude" not in metadata: + metadata["_aspireAmplitude"] = metadata.pop("_rlnAmplitude") + # particle locations are stored as e.g. '000001@first_micrograph.mrcs' # in the _rlnImageName column. 
here, we're splitting this information # so we can get the particle's index in the .mrcs stack as an int diff --git a/src/aspire/utils/relion_interop.py b/src/aspire/utils/relion_interop.py index 7f9916bdc2..f7774689df 100644 --- a/src/aspire/utils/relion_interop.py +++ b/src/aspire/utils/relion_interop.py @@ -12,7 +12,8 @@ # of certain key fields used in the codebase, # which are originally read from Relion STAR files. relion_metadata_fields = { - "__amplitude": float, + "_aspireAmplitude": float, + "_rlnAmplitude": float, "_rlnVoltage": float, "_rlnDefocusU": float, "_rlnDefocusV": float, diff --git a/tests/test_simulation.py b/tests/test_simulation.py index a066f6b67d..944136c999 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -882,8 +882,6 @@ def check_metadata(sim_src, relion_src): those in a RelionSource. """ for k, v in sim_src._metadata.items(): - if k.startswith("__"): - continue try: np.testing.assert_array_equal(v, relion_src._metadata[k]) except AssertionError: From 3165bbc1dfaaa8b7bd94b1289c35d2f78e51dede Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 3 Nov 2025 10:29:34 -0500 Subject: [PATCH 53/91] ensure pixel size is added to mrc header --- src/aspire/source/image.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 280528367d..afdbdbce16 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1499,6 +1499,8 @@ def save_images( # for large arrays. stats.update_header(mrc) + # Add pixel size to header + mrc.voxel_size = self.pixel_size else: # save all images into multiple mrc files in batch size for i_start in np.arange(0, self.n, batch_size): @@ -1512,6 +1514,7 @@ def save_images( f"Saving ImageSource[{i_start}-{i_end-1}] to {mrcs_filepath}" ) im = self.images[i_start:i_end] + im.pixel_size = self.pixel_size im.save(mrcs_filepath, overwrite=overwrite) def estimate_signal_mean_energy( From c7338c973241da5bdc2baae350f055b7b43a93ab Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 3 Nov 2025 11:20:15 -0500 Subject: [PATCH 54/91] test mrc pixel_size in header --- tests/test_simulation.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 944136c999..f326612aa1 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -3,6 +3,7 @@ import tempfile from unittest import TestCase +import mrcfile import numpy as np import pytest @@ -876,6 +877,25 @@ def test_save_overwrite(caplog): check_metadata(sim2, sim2_loaded_renamed) +@pytest.mark.parametrize("batch_size", [1, 6]) +def test_simulation_save_sets_voxel_size(tmp_path, batch_size): + """ + Test we save with pixel_size appended to the mrcfile header. + """ + # Note, n=6 and batch_size=6 exercises save_mode=='single' branch. 
+ sim = Simulation(n=6, L=24, pixel_size=1.37) + info = sim.save(tmp_path / "pixel_size.star", batch_size=batch_size, overwrite=True) + + for stack_name in info["mrcs"]: + stack_path = tmp_path / stack_name + with mrcfile.open(stack_path, permissive=True) as f: + vs = f.voxel_size + header_vals = np.array( + [float(vs.x), float(vs.y), float(vs.z)], dtype=np.float64 + ) + np.testing.assert_allclose(header_vals, sim.pixel_size) + + def check_metadata(sim_src, relion_src): """ Helper function to test if metadata fields in a Simulation match From f709882f593e325744b30751f2a13a37019119ca Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 5 Nov 2025 10:11:15 -0500 Subject: [PATCH 55/91] Use defaultdict --- src/aspire/source/image.py | 8 ++------ tests/test_simulation.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index afdbdbce16..d7af691abc 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -3,7 +3,7 @@ import logging import os.path from abc import ABC, abstractmethod -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections.abc import Iterable import mrcfile @@ -1350,11 +1350,7 @@ def _prepare_relion_optics_blocks(metadata): metadata["_rlnOpticsGroup"] = optics_groups # Build the optics block rows and assign group names. - optics_block = OrderedDict() - optics_block["_rlnOpticsGroup"] = [] - optics_block["_rlnOpticsGroupName"] = [] - for field in optics_value_fields: - optics_block[field] = [] + optics_block = defaultdict(list) for signature, group_id in group_lookup.items(): optics_block["_rlnOpticsGroup"].append(group_id) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index f326612aa1..83c0932c17 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -663,7 +663,7 @@ def test_simulation_save_optics_block(tmp_path): for field in expected_optics_fields: assert field in optics - # Optics group and group name should 1-indexed + # Optics group and group name should be 1-indexed np.testing.assert_array_equal( optics["_rlnOpticsGroup"], np.arange(1, kv_ct + 1, dtype=int) ) From d335d297771edcc928fd653ce043757273305e14 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 5 Nov 2025 10:35:14 -0500 Subject: [PATCH 56/91] Remove old file compatibility --- src/aspire/source/relion.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index f9851eb69f..c3620c9bdb 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -195,10 +195,6 @@ def populate_metadata(self): metadata = RelionStarFile(self.filepath).get_merged_data_block() - # Promote legacy _rlnAmplitude column to the ASPIRE-specific name. - if "_rlnAmplitude" in metadata and "_aspireAmplitude" not in metadata: - metadata["_aspireAmplitude"] = metadata.pop("_rlnAmplitude") - # particle locations are stored as e.g. '000001@first_micrograph.mrcs' # in the _rlnImageName column. here, we're splitting this information # so we can get the particle's index in the .mrcs stack as an int From 79780d1059558396495569c497ea338c14b60249 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 5 Nov 2025 15:44:58 -0500 Subject: [PATCH 57/91] Remove _aspireAmplitude from relion_metadata_fields dict. 
--- src/aspire/utils/relion_interop.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/aspire/utils/relion_interop.py b/src/aspire/utils/relion_interop.py index f7774689df..c807c7d863 100644 --- a/src/aspire/utils/relion_interop.py +++ b/src/aspire/utils/relion_interop.py @@ -12,8 +12,6 @@ # of certain key fields used in the codebase, # which are originally read from Relion STAR files. relion_metadata_fields = { - "_aspireAmplitude": float, - "_rlnAmplitude": float, "_rlnVoltage": float, "_rlnDefocusU": float, "_rlnDefocusV": float, From 6fd4a56c3d83b525339be083c045015db6556719 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 7 Nov 2025 11:49:23 -0500 Subject: [PATCH 58/91] Test save on source slices --- tests/test_simulation.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 83c0932c17..191fca0810 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -695,6 +695,44 @@ def test_simulation_save_optics_block(tmp_path): ) +def test_simulation_slice_save_roundtrip(tmp_path): + # Radial CTF Filters + kv_min, kv_max, kv_ct = 200, 300, 3 + voltages = np.linspace(kv_min, kv_max, kv_ct) + ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + + # Generate and save slice of Simulation + sim = Simulation(n=9, L=16, C=1, unique_filters=ctf_filters, pixel_size=1.34) + sliced_sim = sim[::2] + save_path = tmp_path / "sliced_sim.star" + sliced_sim.save(save_path, overwrite=True) + + # Load saved slice and compare to original + reloaded = RelionSource(save_path) + + # Check images + np.testing.assert_allclose( + reloaded.images[:].asnumpy(), + sliced_sim.images[:].asnumpy(), + ) + + # Check metadata related to optics block + metadata_fields = [ + "_rlnVoltage", + "_rlnDefocusU", + "_rlnDefocusV", + "_rlnDefocusAngle", + "_rlnSphericalAberration", + "_rlnAmplitudeContrast", + "_rlnImagePixelSize", + ] + for field in metadata_fields: + np.testing.assert_allclose( + reloaded.get_metadata(field), + sliced_sim.get_metadata(field), + ) + + def test_default_symmetry_group(): # Check that default is "C1". sim = Simulation() From 47a5343d15ff132c175cbae30107395cea068869 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 13 Nov 2025 15:06:36 -0500 Subject: [PATCH 59/91] Always write an optics block --- src/aspire/source/image.py | 31 +++++++++++++------------------ src/aspire/source/relion.py | 3 ++- tests/test_coordinate_source.py | 17 +++++++++++++++-- tests/test_relion_source.py | 19 ++++++++++++++++--- 4 files changed, 46 insertions(+), 24 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index d7af691abc..3b0db192cd 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1311,31 +1311,26 @@ def _prepare_relion_optics_blocks(metadata): "_rlnImageDimensionality", ] - # We only write a data_optics block when every required field is present. - required_fields = [ - "_rlnSphericalAberration", - "_rlnVoltage", - "_rlnAmplitudeContrast", - "_rlnImageSize", - "_rlnImageDimensionality", - ] - pixel_fields = ["_rlnImagePixelSize", "_rlnMicrographPixelSize"] + # Some optics group fields might not always be present, but are necessary + # for reading the file in Relion. We ensure these fields exist and populate + # with a dummy value if not. 
+ n_rows = len(metadata["_rlnImageName"]) - has_required = all(field in metadata for field in required_fields) - has_pixel_field = any(field in metadata for field in pixel_fields) + def _ensure_column(field, value): + if field not in metadata: + logger.warning( + f"Optics field {field} not found, populating with default value {value}" + ) + metadata[field] = np.full(n_rows, value) - if not (has_required and has_pixel_field): - # Optics metadata incomplete, fall back to legacy single block. - logger.warning( - "Optics metadata incomplete, writing only data_particles block." - ) - return None, metadata + _ensure_column("_rlnSphericalAberration", 2.0) + _ensure_column("_rlnVoltage", 300) + _ensure_column("_rlnAmplitudeContrast", 0.1) # Restrict to the optics columns that are actually present on this source. optics_value_fields = [ field for field in all_optics_fields if field in metadata ] - n_rows = len(metadata["_rlnImageName"]) # Map each unique optics tuple to a 1-based group ID in order encountered. group_lookup = OrderedDict() diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index c3620c9bdb..7d7cc3cd54 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -168,7 +168,8 @@ def __init__( f"Found partially populated CTF Params." f" To automatically populate CTFFilters provide {CTF_params}" ) - + self.unique_filters = [IdentityFilter()] + self.filter_indices = np.zeros(self.n, dtype=int) # If no CTF info in STAR, we initialize the filter values of metadata with default values else: self.unique_filters = [IdentityFilter()] diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index 2a63c7b0b9..13219dcaa3 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -526,7 +526,7 @@ def testSave(self): # load saved particle stack saved_star = StarFile(star_path) # we want to read the saved mrcs file from the STAR file - image_name_column = saved_star.get_block_by_index(0)["_rlnImageName"] + image_name_column = saved_star.get_block_by_index(1)["_rlnImageName"] # we're reading a string of the form 0000X@mrcs_path.mrcs _particle, mrcs_path = image_name_column[0].split("@") saved_mrcs_stack = mrcfile.open(os.path.join(self.data_folder, mrcs_path)).data @@ -537,15 +537,28 @@ def testSave(self): self.assertEqual( list(saved_star["particles"].keys()), [ - "_rlnImagePixelSize", + "_rlnOpticsGroup", "_rlnSymmetryGroup", "_rlnImageName", "_rlnCoordinateX", "_rlnCoordinateY", + ], + ) + + self.assertEqual( + list(saved_star["optics"].keys()), + [ + "_rlnOpticsGroup", + "_rlnOpticsGroupName", + "_rlnImagePixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnAmplitudeContrast", "_rlnImageSize", "_rlnImageDimensionality", ], ) + # assert that all the correct coordinates were saved for i in range(10): self.assertEqual( diff --git a/tests/test_relion_source.py b/tests/test_relion_source.py index d0b996c795..31638bdb7e 100644 --- a/tests/test_relion_source.py +++ b/tests/test_relion_source.py @@ -70,6 +70,7 @@ def test_prepare_relion_optics_blocks_warns(caplog): "_rlnImagePixelSize": np.array([1.234]), "_rlnImageSize": np.array([32]), "_rlnImageDimensionality": np.array([2]), + "_rlnImageName": np.array(["000001@stack.mrcs"]), } caplog.clear() @@ -78,9 +79,21 @@ def test_prepare_relion_optics_blocks_warns(caplog): metadata.copy() ) - assert optics_block is None - assert particle_block == metadata - assert "Optics metadata incomplete" in caplog.text + # We should get and optics block + assert optics_block 
is not None + + # Verify defaults were injected. + np.testing.assert_allclose(optics_block["_rlnImagePixelSize"], [1.234]) + np.testing.assert_array_equal(optics_block["_rlnImageSize"], [32]) + np.testing.assert_array_equal(optics_block["_rlnImageDimensionality"], [2]) + np.testing.assert_allclose(optics_block["_rlnVoltage"], [300.0]) + np.testing.assert_allclose(optics_block["_rlnSphericalAberration"], [2.0]) + np.testing.assert_allclose(optics_block["_rlnAmplitudeContrast"], [0.1]) + + # Caplog should contain the warnings about the three missing fields. + assert "Optics field _rlnSphericalAberration not found" in caplog.text + assert "Optics field _rlnVoltage not found" in caplog.text + assert "Optics field _rlnAmplitudeContrast not found" in caplog.text def test_pixel_size(caplog): From d7d03c8d2f87761648ed91e984b3d67e46c702dd Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 14 Nov 2025 08:10:13 -0500 Subject: [PATCH 60/91] Gallery example simulation -> relion --- .../save_simulation_relion_reconstruct.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 gallery/experiments/save_simulation_relion_reconstruct.py diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py new file mode 100644 index 0000000000..babeffc351 --- /dev/null +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -0,0 +1,87 @@ +""" +Simulated Stack → RELION Reconstruction +======================================= + +This experiment shows how to: + +1. build a synthetic dataset with ASPIRE, +2. write the stack via ``ImageSource.save`` so RELION can consume it, and +3. call :code:`relion_reconstruct` on the saved STAR file. +""" + +# %% +# Imports +# ------- + +import logging +from pathlib import Path + +import numpy as np + +from aspire.downloader import emdb_2660 +from aspire.noise import WhiteNoiseAdder +from aspire.operators import RadialCTFFilter +from aspire.source import RelionSource, Simulation + +logger = logging.getLogger(__name__) + + +# %% +# Configuration +# ------------- +# We set a few parameters to initialize the Simulation. +# You can safely alter ``n_particles`` (or change the voltages, etc.) when +# trying this interactively; the defaults here are chosen for demonstrative purposes. + +output_dir = Path("relion_save_demo") +output_dir.mkdir(exist_ok=True) + +n_particles = 512 +snr = 0.25 +voltages = np.linspace(200, 300, 3) # kV settings for the radial CTF filters +star_path = output_dir / f"sim_n{n_particles}.star" + + +# %% +# Volume and Filters +# ------------------ +# Start from the EMDB-2660 ribosome map and build a small set of radial CTF filters +# that RELION will recover as optics groups. + +vol = emdb_2660() +ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + + +# %% +# Simulate, Add Noise, Save +# ------------------------- +# Initialize the Simulation: +# mix the CTFs across the stack, add white noise at a target SNR, +# and write the particles and metadata to a RELION-compatible STAR/MRC stack. + +sim = Simulation( + n=n_particles, + vols=vol, + unique_filters=ctf_filters, + noise_adder=WhiteNoiseAdder.from_snr(snr), +) +sim.save(star_path, overwrite=True) + + +# %% +# Running ``relion_reconstruct`` +# ------------------------------ +# ``relion_reconstruct`` is an external RELION command, so we just show the call. +# Run this in a RELION-enabled shell after generating the STAR file above. 
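# Optionally, the saved STAR file can be sanity-checked by loading it back
# with ASPIRE before calling RELION; a minimal check (sketch, not required
# for the reconstruction) might look like:
#
#   from aspire.source import RelionSource
#   src = RelionSource(str(star_path))
#   assert src.n == n_particles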
+ +relion_cmd = [ + "relion_reconstruct", + "--i", + str(star_path), + "--o", + str(output_dir / "relion_recon.mrc"), + "--ctf", +] + +print(" ".join(relion_cmd)) + From e18279abb6c7dffc049a3e69720168b7b15ffbd1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 14 Nov 2025 08:25:15 -0500 Subject: [PATCH 61/91] clean up gallery --- gallery/experiments/save_simulation_relion_reconstruct.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py index babeffc351..767abf6bfc 100644 --- a/gallery/experiments/save_simulation_relion_reconstruct.py +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -21,7 +21,7 @@ from aspire.downloader import emdb_2660 from aspire.noise import WhiteNoiseAdder from aspire.operators import RadialCTFFilter -from aspire.source import RelionSource, Simulation +from aspire.source import Simulation logger = logging.getLogger(__name__) @@ -83,5 +83,4 @@ "--ctf", ] -print(" ".join(relion_cmd)) - +logger.info(" ".join(relion_cmd)) From ecce5a9e0f7dba215f55505ea21512fbad12373e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 14 Nov 2025 15:14:26 -0500 Subject: [PATCH 62/91] _aspireMetadata 'no_ctf' tag for missing optics fields. Detect ASPIRE-generated dummy values when loading. Unit test. --- src/aspire/source/image.py | 12 +++++++++--- src/aspire/source/relion.py | 17 +++++++++++++++-- tests/test_simulation.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 3b0db192cd..aa1b85a1b3 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1316,16 +1316,22 @@ def _prepare_relion_optics_blocks(metadata): # with a dummy value if not. n_rows = len(metadata["_rlnImageName"]) + missing_fields = [] + def _ensure_column(field, value): if field not in metadata: + missing_fields.append(field) logger.warning( f"Optics field {field} not found, populating with default value {value}" ) metadata[field] = np.full(n_rows, value) - _ensure_column("_rlnSphericalAberration", 2.0) - _ensure_column("_rlnVoltage", 300) - _ensure_column("_rlnAmplitudeContrast", 0.1) + _ensure_column("_rlnSphericalAberration", 0) + _ensure_column("_rlnVoltage", 0) + _ensure_column("_rlnAmplitudeContrast", 0) + + if missing_fields: + metadata["_aspireMetadata"] = np.full(n_rows, "no_ctf", dtype=object) # Restrict to the optics columns that are actually present on this source. optics_value_fields = [ diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index 7d7cc3cd54..c063c3b7b2 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -125,6 +125,12 @@ def __init__( for key in offset_keys: del self._metadata[key] + # Detect ASPIRE-generated dummy variables + aspire_metadata = metadata.get("_aspireMetadata") + dummy_ctf = isinstance(aspire_metadata, (list, np.ndarray)) and np.all( + np.asarray(aspire_metadata) == "no_ctf" + ) + # CTF estimation parameters coming from Relion CTF_params = [ "_rlnVoltage", @@ -162,14 +168,21 @@ def __init__( # self.unique_filters of the filter that should be applied self.filter_indices = filter_indices + # If we detect ASPIRE added dummy variables, log and initialize identity filter + elif dummy_ctf: + logger.info( + "Detected ASPIRE-generated dummy optics; initializing identity filters." 
+ ) + self.unique_filters = [IdentityFilter()] + self.filter_indices = np.zeros(self.n, dtype=int) + # We have provided some, but not all the required params elif any(param in metadata for param in CTF_params): logger.warning( f"Found partially populated CTF Params." f" To automatically populate CTFFilters provide {CTF_params}" ) - self.unique_filters = [IdentityFilter()] - self.filter_indices = np.zeros(self.n, dtype=int) + # If no CTF info in STAR, we initialize the filter values of metadata with default values else: self.unique_filters = [IdentityFilter()] diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 191fca0810..c053b2eb0f 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -915,6 +915,34 @@ def test_save_overwrite(caplog): check_metadata(sim2, sim2_loaded_renamed) +def test_save_load_dummy_ctf_values(tmp_path, caplog): + """ + Test we populate optics group field with dummy values when none + are present. These values should be detected upon reloading the source. + """ + star_path = tmp_path / "no_ctf.star" + sim = Simulation(n=8, L=16) # no unique_filters, ie. no CTF info + sim.save(star_path, overwrite=True) + + # STAR file should contain our fallback tag + star = RelionStarFile(star_path) + particles_block = star.get_block_by_index(1) + np.testing.assert_array_equal( + particles_block["_aspireMetadata"], np.full(sim.n, "no_ctf", dtype=object) + ) + + # Tag should survive round-trip + caplog.clear() + reloaded = RelionSource(star_path) + np.testing.assert_array_equal( + reloaded._metadata["_aspireMetadata"], + np.full(reloaded.n, "no_ctf", dtype=object), + ) + + # Check message is logged about detecting dummy variables + assert "Detected ASPIRE-generated dummy optics" in caplog.text + + @pytest.mark.parametrize("batch_size", [1, 6]) def test_simulation_save_sets_voxel_size(tmp_path, batch_size): """ From 74fe99babbd4861142f7efee558d33059eacdb58 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 14 Nov 2025 15:56:50 -0500 Subject: [PATCH 63/91] test cleanup --- tests/test_coordinate_source.py | 1 + tests/test_relion_source.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index 13219dcaa3..8baf7d1338 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -542,6 +542,7 @@ def testSave(self): "_rlnImageName", "_rlnCoordinateX", "_rlnCoordinateY", + "_aspireMetadata", ], ) diff --git a/tests/test_relion_source.py b/tests/test_relion_source.py index 31638bdb7e..64703ed491 100644 --- a/tests/test_relion_source.py +++ b/tests/test_relion_source.py @@ -86,9 +86,9 @@ def test_prepare_relion_optics_blocks_warns(caplog): np.testing.assert_allclose(optics_block["_rlnImagePixelSize"], [1.234]) np.testing.assert_array_equal(optics_block["_rlnImageSize"], [32]) np.testing.assert_array_equal(optics_block["_rlnImageDimensionality"], [2]) - np.testing.assert_allclose(optics_block["_rlnVoltage"], [300.0]) - np.testing.assert_allclose(optics_block["_rlnSphericalAberration"], [2.0]) - np.testing.assert_allclose(optics_block["_rlnAmplitudeContrast"], [0.1]) + np.testing.assert_allclose(optics_block["_rlnVoltage"], [0]) + np.testing.assert_allclose(optics_block["_rlnSphericalAberration"], [0]) + np.testing.assert_allclose(optics_block["_rlnAmplitudeContrast"], [0]) # Caplog should contain the warnings about the three missing fields. 
assert "Optics field _rlnSphericalAberration not found" in caplog.text From 5e21f68d8777451e9753219fa9e846d1779d44f7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 18 Nov 2025 08:25:34 -0500 Subject: [PATCH 64/91] optics group level not_ctf flag --- src/aspire/source/image.py | 7 ++++--- src/aspire/source/relion.py | 7 ++----- tests/test_coordinate_source.py | 2 +- tests/test_simulation.py | 11 +++-------- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index aa1b85a1b3..fa8aa2fe25 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1330,9 +1330,6 @@ def _ensure_column(field, value): _ensure_column("_rlnVoltage", 0) _ensure_column("_rlnAmplitudeContrast", 0) - if missing_fields: - metadata["_aspireMetadata"] = np.full(n_rows, "no_ctf", dtype=object) - # Restrict to the optics columns that are actually present on this source. optics_value_fields = [ field for field in all_optics_fields if field in metadata @@ -1359,6 +1356,10 @@ def _ensure_column(field, value): for field, value in zip(optics_value_fields, signature): optics_block[field].append(value) + # Tag dummy optics if we had to synthesize any fields + if missing_fields: + optics_block["_aspireNoCTF"] = ["." for _ in range(len(group_lookup))] + # Everything not lifted into the optics block stays with the particle metadata. particle_block = OrderedDict() if "_rlnOpticsGroup" in metadata: diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index c063c3b7b2..963e7ae427 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -126,10 +126,7 @@ def __init__( del self._metadata[key] # Detect ASPIRE-generated dummy variables - aspire_metadata = metadata.get("_aspireMetadata") - dummy_ctf = isinstance(aspire_metadata, (list, np.ndarray)) and np.all( - np.asarray(aspire_metadata) == "no_ctf" - ) + no_ctf_flag = "_aspireNoCTF" in metadata # CTF estimation parameters coming from Relion CTF_params = [ @@ -169,7 +166,7 @@ def __init__( self.filter_indices = filter_indices # If we detect ASPIRE added dummy variables, log and initialize identity filter - elif dummy_ctf: + elif no_ctf_flag: logger.info( "Detected ASPIRE-generated dummy optics; initializing identity filters." 
) diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index 8baf7d1338..3b0f29f1ee 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -542,7 +542,6 @@ def testSave(self): "_rlnImageName", "_rlnCoordinateX", "_rlnCoordinateY", - "_aspireMetadata", ], ) @@ -557,6 +556,7 @@ def testSave(self): "_rlnAmplitudeContrast", "_rlnImageSize", "_rlnImageDimensionality", + "_aspireNoCTF", ], ) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index c053b2eb0f..58bea32128 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -926,18 +926,13 @@ def test_save_load_dummy_ctf_values(tmp_path, caplog): # STAR file should contain our fallback tag star = RelionStarFile(star_path) - particles_block = star.get_block_by_index(1) - np.testing.assert_array_equal( - particles_block["_aspireMetadata"], np.full(sim.n, "no_ctf", dtype=object) - ) + optics_block = star.get_block_by_index(0) + assert "_aspireNoCTF" in optics_block # Tag should survive round-trip caplog.clear() reloaded = RelionSource(star_path) - np.testing.assert_array_equal( - reloaded._metadata["_aspireMetadata"], - np.full(reloaded.n, "no_ctf", dtype=object), - ) + assert "_aspireNoCTF" in reloaded._metadata # Check message is logged about detecting dummy variables assert "Detected ASPIRE-generated dummy optics" in caplog.text From d4c2b4da25f0176ce4df37e5457dc5d7b3acda36 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 20 Nov 2025 08:22:43 -0500 Subject: [PATCH 65/91] clean up gallery --- .../save_simulation_relion_reconstruct.py | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py index 767abf6bfc..eb8cdda740 100644 --- a/gallery/experiments/save_simulation_relion_reconstruct.py +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -1,6 +1,6 @@ """ -Simulated Stack → RELION Reconstruction -======================================= +Simulated Stack to RELION Reconstruction +======================================== This experiment shows how to: @@ -37,10 +37,10 @@ output_dir.mkdir(exist_ok=True) n_particles = 512 -snr = 0.25 -voltages = np.linspace(200, 300, 3) # kV settings for the radial CTF filters -star_path = output_dir / f"sim_n{n_particles}.star" - +snr = 0.5 +defocus = np.linspace(15000, 25000, 7) # defocus values for the radial CTF filters (angstroms) +star_file = f"sim_n{n_particles}.star" +star_path = output_dir / star_file # %% # Volume and Filters @@ -49,7 +49,7 @@ # that RELION will recover as optics groups. vol = emdb_2660() -ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] +ctf_filters = [RadialCTFFilter(defocus=d) for d in defocus] # %% @@ -72,15 +72,10 @@ # Running ``relion_reconstruct`` # ------------------------------ # ``relion_reconstruct`` is an external RELION command, so we just show the call. -# Run this in a RELION-enabled shell after generating the STAR file above. - -relion_cmd = [ - "relion_reconstruct", - "--i", - str(star_path), - "--o", - str(output_dir / "relion_recon.mrc"), - "--ctf", -] - -logger.info(" ".join(relion_cmd)) +# Run this, for the output directory, in a RELION-enabled shell after generating +# the STAR file above. 
+ +logger.info( + f"relion_reconstruct --i {star_file} " + f"--o 'relion_recon.mrc' --ctf" +) From 9bcba119ab4651edc1cb6222e6b1d5c5c9411860 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 20 Nov 2025 08:25:30 -0500 Subject: [PATCH 66/91] tox --- .../experiments/save_simulation_relion_reconstruct.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py index eb8cdda740..1432d89bc3 100644 --- a/gallery/experiments/save_simulation_relion_reconstruct.py +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -38,7 +38,9 @@ n_particles = 512 snr = 0.5 -defocus = np.linspace(15000, 25000, 7) # defocus values for the radial CTF filters (angstroms) +defocus = np.linspace( + 15000, 25000, 7 +) # defocus values for the radial CTF filters (angstroms) star_file = f"sim_n{n_particles}.star" star_path = output_dir / star_file @@ -75,7 +77,4 @@ # Run this, for the output directory, in a RELION-enabled shell after generating # the STAR file above. -logger.info( - f"relion_reconstruct --i {star_file} " - f"--o 'relion_recon.mrc' --ctf" -) +logger.info(f"relion_reconstruct --i {star_file} " f"--o 'relion_recon.mrc' --ctf") From 215204aa1b858e0a1f15f842334eadce91ce60ab Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 20 Nov 2025 10:44:18 -0500 Subject: [PATCH 67/91] typo --- gallery/experiments/save_simulation_relion_reconstruct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py index 1432d89bc3..d7e3405eb4 100644 --- a/gallery/experiments/save_simulation_relion_reconstruct.py +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -74,7 +74,7 @@ # Running ``relion_reconstruct`` # ------------------------------ # ``relion_reconstruct`` is an external RELION command, so we just show the call. -# Run this, for the output directory, in a RELION-enabled shell after generating +# Run this, from the output directory, in a RELION-enabled shell after generating # the STAR file above. logger.info(f"relion_reconstruct --i {star_file} " f"--o 'relion_recon.mrc' --ctf") From 700cedf2dd256253ce086a052364ba3375c11639 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 20 Nov 2025 11:15:07 -0500 Subject: [PATCH 68/91] gallery comment update --- gallery/experiments/save_simulation_relion_reconstruct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py index d7e3405eb4..e52f6bc2a8 100644 --- a/gallery/experiments/save_simulation_relion_reconstruct.py +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -30,7 +30,7 @@ # Configuration # ------------- # We set a few parameters to initialize the Simulation. -# You can safely alter ``n_particles`` (or change the voltages, etc.) when +# You can safely alter ``n_particles`` (or change the defocus values, etc.) when # trying this interactively; the defaults here are chosen for demonstrative purposes. 
output_dir = Path("relion_save_demo") From 8cb721c9be13101716dcd86747bb5672468912b1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 12 Dec 2025 10:03:49 -0500 Subject: [PATCH 69/91] float64 amplitude code comment --- src/aspire/source/image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index fa8aa2fe25..19d7ee311c 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -491,6 +491,7 @@ def amplitudes(self): @amplitudes.setter def amplitudes(self, values): + # Keep amplitudes float64 so downstream filters/metadata retain precision. values = np.asarray(values, dtype=np.float64) self.set_metadata("_aspireAmplitude", values) From 333a608332ab7fdcf325c9e6f36ffab7851d8bb6 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 09:05:32 -0500 Subject: [PATCH 70/91] init add faasrot --- src/aspire/image/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 431dedf87a..5a5a007e6b 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -16,3 +16,4 @@ SigmaRejectionImageStacker, WinsorizedImageStacker, ) +from .faasrot.py import faasrot From 243358b2cd7f716a56b63e39dde003891557b67f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 09:14:38 -0500 Subject: [PATCH 71/91] cp faasrot dbg code --- src/aspire/image/__init__.py | 2 +- src/aspire/image/faasrot.py | 173 +++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 src/aspire/image/faasrot.py diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 5a5a007e6b..6240be96da 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,3 +1,4 @@ +from .faasrot.py import faasrot from .image import ( BasisImage, BispecImage, @@ -16,4 +17,3 @@ SigmaRejectionImageStacker, WinsorizedImageStacker, ) -from .faasrot.py import faasrot diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py new file mode 100644 index 0000000000..00ad94879b --- /dev/null +++ b/src/aspire/image/faasrot.py @@ -0,0 +1,173 @@ +import numpy as np + +from aspire.numeric import xp + + +def _pre_rotate(theta): + """ + Given angle `theta` (degrees) return nearest rotation of 90 + degrees required to place angle within [-45,45) and residual + rotation in degrees. + """ + + theta = np.mod(theta, 360) + + # 0 < 45 + rot90 = 0 + residual = theta + + if theta >= 45 and theta < 135: + rot90 = 1 + residual = theta - 90 + elif theta >= 135 and theta < 225: + rot90 = 2 + residual = theta - 180 + elif theta >= 215 and theta < 315: + rot90 = 3 + residual = theta - 270 + elif theta >= 315 and theta < 360: + rot90 = 0 + residual = theta - 360 + + return residual, rot90 + + +def _shift_center(n): + """ + Given `n` pixels return center pixel and shift amount, 0 or 1/2. 
+ """ + if n % 2 == 0: + c = n // 2 # center + s = 1 / 2 # shift + else: + c = n // 2 + s = 0 + + return c, s + + +def _pre_compute(theta, nx, ny): + """ + Retuns M = (Mx, My, rot90) + """ + theta, mult90 = _pre_rotate(theta) + + theta = np.pi * theta / 180 + theta = -theta # Yaroslavsky rotated CW + + cy, sy = _shift_center(ny) + cx, sx = _shift_center(nx) + + # Floating point epsilon + eps = np.finfo(np.float64).eps + + # Precompute Y interpolation tables + My = np.zeros((nx, ny), dtype=np.complex128) + r = np.arange(cy + 1, dtype=int) + u = (1 - np.cos(theta)) / np.sin(theta + eps) + # print("u", u) + alpha1 = 2 * np.pi * 1j * r / ny + + # print("alpha1", alpha1) + + linds = np.arange(ny - 1, cy, -1, dtype=int) + # print('aaa', ny-1, cy, -1) + rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) + # print(linds,rinds) + # This can be broadcast, but leaving loop since would be close to CUDA... + for x in range(nx): + Ux = u * (x - cx + sx + 2) + # print("Ux",Ux) + My[x, r] = np.exp(alpha1 * Ux) + My[x, linds] = np.conj(My[x, rinds]) + + # Precompute X interpolation tables + Mx = np.zeros((ny, nx), dtype=np.complex128) + r = np.arange(cx + 1, dtype=int) + u = -np.sin(theta) + alpha2 = 2 * np.pi * 1j * r / nx + + linds = np.arange(nx - 1, cx, -1, dtype=int) + rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) + # This can be broadcast, but leaving loop since would be close to CUDA... + for y in range(ny): + Uy = u * (y - cy + sy + 2) + Mx[y, r] = np.exp(alpha2 * Uy) + Mx[y, linds] = np.conj(Mx[y, rinds]) + + # After building, transpose to (nx, ny). + Mx = Mx.T + + return Mx, My, mult90 + + +def _rot90(img): + return np.flipud(img.T) + + +def _rot180(img): + return np.flipud(np.fliplr(img)) + + +def _rot270(img): + return np.fliplr(img.T) + + +def faastrotate(images, theta, M=None): + + # Make a stack of 1 + if images.ndim == 2: + images = images[None, :, :] + + n, px0, px1 = images.shape + assert px0 == px1, "Currently only implemented for square images." 
+ + if M is None: + M = _pre_compute(theta, px0, px1) + Mx, My, Mrot90 = M + + result = np.empty((n, px0, px1), dtype=np.float64) + + for i in range(n): + + img = images[i] + + # Pre rotate by multiples of 90 + if Mrot90 == 1: + img = _rot90(img) + elif Mrot90 == 2: + img = _rot180(img) + elif Mrot90 == 3: + img = _rot270(img) + + # Shear 1 + img_k = np.fft.fft(img, axis=-1) + # okay print("\nfft1(img_k):\n", img_k,"\n") + print("\nMy:\n", My, "\n") + img_k = img_k * My + print("\nmult (img_k):\n", img_k, "\n") # okay + + # for _i in range(16): + # #print(f'A[{_i}].x = {img_k.flatten()[_i].real};') + # #print(f'A[{_i}].y = {img_k.flatten()[_i].imag};') + # print(f'FA[{_i}] = {img_k.flatten()[_i]};') + + # breakpoint() + result[i] = np.real(np.fft.ifft(img_k, axis=-1)) + print("\nstage1\n", result[i] * 4, "\n") + + # Shear 2 + img_k = np.fft.fft(result[i], axis=0) + img_k = img_k * Mx + result[i] = np.real(np.fft.ifft(img_k, axis=0)) + + print("\nstage2\n", result * 4 * 4) + + # Shear 3 + img_k = np.fft.fft(result[i], axis=-1) + img_k = img_k * My + result[i] = np.real(np.fft.ifft(img_k, axis=-1)) + + print("\nstage3\n", result * 4 * 4 * 4, "\n") + + return result From 77006f76a2abc47bd0253c77b9bfe2f13d057d7f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 10:42:07 -0500 Subject: [PATCH 72/91] initial faasrot code/test add [skip ci] --- src/aspire/image/__init__.py | 2 +- src/aspire/image/faasrot.py | 34 +++++++++++----------------------- src/aspire/image/image.py | 15 +++++++++++++-- tests/test_image.py | 24 +++++++++++++++++++++++- 4 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 6240be96da..264656ff56 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,4 +1,4 @@ -from .faasrot.py import faasrot +from .faasrot import faasrotate from .image import ( BasisImage, BispecImage, diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py index 00ad94879b..0b6a7ab71b 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/faasrot.py @@ -65,19 +65,13 @@ def _pre_compute(theta, nx, ny): My = np.zeros((nx, ny), dtype=np.complex128) r = np.arange(cy + 1, dtype=int) u = (1 - np.cos(theta)) / np.sin(theta + eps) - # print("u", u) alpha1 = 2 * np.pi * 1j * r / ny - # print("alpha1", alpha1) - linds = np.arange(ny - 1, cy, -1, dtype=int) - # print('aaa', ny-1, cy, -1) rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) - # print(linds,rinds) # This can be broadcast, but leaving loop since would be close to CUDA... for x in range(nx): Ux = u * (x - cx + sx + 2) - # print("Ux",Ux) My[x, r] = np.exp(alpha1 * Ux) My[x, linds] = np.conj(My[x, rinds]) @@ -113,7 +107,17 @@ def _rot270(img): return np.fliplr(img.T) -def faastrotate(images, theta, M=None): +def faasrotate(images, theta, M=None): + """ + Rotate `images` array by `theta` radians ccw. 
+ + :param images: (n , px, px) array of image data + :param theta: rotation angle in radians + :param M: optional precomputed shearing table + :return: (n, px, px) array fo rotated image data + """ + # Convert to degrees + theta = np.rad2deg(theta) # Make a stack of 1 if images.ndim == 2: @@ -142,32 +146,16 @@ def faastrotate(images, theta, M=None): # Shear 1 img_k = np.fft.fft(img, axis=-1) - # okay print("\nfft1(img_k):\n", img_k,"\n") - print("\nMy:\n", My, "\n") img_k = img_k * My - print("\nmult (img_k):\n", img_k, "\n") # okay - - # for _i in range(16): - # #print(f'A[{_i}].x = {img_k.flatten()[_i].real};') - # #print(f'A[{_i}].y = {img_k.flatten()[_i].imag};') - # print(f'FA[{_i}] = {img_k.flatten()[_i]};') - - # breakpoint() result[i] = np.real(np.fft.ifft(img_k, axis=-1)) - print("\nstage1\n", result[i] * 4, "\n") # Shear 2 img_k = np.fft.fft(result[i], axis=0) img_k = img_k * Mx result[i] = np.real(np.fft.ifft(img_k, axis=0)) - print("\nstage2\n", result * 4 * 4) - # Shear 3 img_k = np.fft.fft(result[i], axis=-1) img_k = img_k * My - result[i] = np.real(np.fft.ifft(img_k, axis=-1)) - - print("\nstage3\n", result * 4 * 4 * 4, "\n") return result diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 8212033f72..2f1ed6012f 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -10,6 +10,7 @@ import aspire.sinogram import aspire.volume +from aspire.image import faasrotate from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import ( @@ -635,8 +636,18 @@ def filter(self, filter): original_stack_shape ) - def rotate(self): - raise NotImplementedError + def rotate(self, theta): + """ + Rotate by `theta` radians. + """ + original_stack_shape = self.stack_shape + im = self.stack_reshape(-1) + + im = faasrotate(im._data, theta) + + return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( + original_stack_shape + ) def save(self, mrcs_filepath, overwrite=None): """ diff --git a/tests/test_image.py b/tests/test_image.py index 70f5906a64..f788654d93 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -10,9 +10,10 @@ from PIL import Image as PILImage from pytest import raises from scipy.datasets import face +from scipy.ndimage import rotate from aspire.image import Image -from aspire.utils import Rotation, powerset, utest_tolerance +from aspire.utils import Rotation, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup from .test_utils import matplotlib_dry_run @@ -564,3 +565,24 @@ def test_save_load_pixel_size(get_images, dtype): np.testing.assert_almost_equal( im2.pixel_size, im.pixel_size, err_msg="Image pixel_size incorrect save-load" ) + + +def test_faasrotate(get_images, dtype): + im_np, im = get_images + + mask = grid_2d(im_np.shape[-1])["r"] < 1 + + for theta in np.linspace(0, 2 * np.pi, 100): + im_rot = im.rotate(theta) + + # reference to scipy + ref = rotate( + im_np, + np.rad2deg(theta), + reshape=False, + ) + + # mask off ears + masked_diff = (im_rot - ref) * mask + + np.testing.assert_allclose(masked_diff, 0, atol=1e-7) From 9cc71a88e476773c590a5e5d231685e6d312fe04 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 11:38:58 -0500 Subject: [PATCH 73/91] dbg --- src/aspire/image/faasrot.py | 5 +++-- tests/test_image.py | 10 +++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py index 0b6a7ab71b..9c7a5f465b 100644 --- a/src/aspire/image/faasrot.py +++ 
b/src/aspire/image/faasrot.py @@ -150,12 +150,13 @@ def faasrotate(images, theta, M=None): result[i] = np.real(np.fft.ifft(img_k, axis=-1)) # Shear 2 - img_k = np.fft.fft(result[i], axis=0) + img_k = np.fft.fft(result[i], axis=-2) img_k = img_k * Mx - result[i] = np.real(np.fft.ifft(img_k, axis=0)) + result[i] = np.real(np.fft.ifft(img_k, axis=-2)) # Shear 3 img_k = np.fft.fft(result[i], axis=-1) img_k = img_k * My + result[i] = np.real(np.fft.ifft(img_k, axis=-1)) return result diff --git a/tests/test_image.py b/tests/test_image.py index f788654d93..608c83bbab 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -572,7 +572,8 @@ def test_faasrotate(get_images, dtype): mask = grid_2d(im_np.shape[-1])["r"] < 1 - for theta in np.linspace(0, 2 * np.pi, 100): + #for theta in np.linspace(0, 2 * np.pi, 100): + for theta in [np.pi/4]: im_rot = im.rotate(theta) # reference to scipy @@ -580,8 +581,15 @@ def test_faasrotate(get_images, dtype): im_np, np.rad2deg(theta), reshape=False, + axes=(-1,-2), ) + peek = np.empty((3, *im_np.shape[-2:])) + peek[0] = im_np + peek[1] = im_rot + peek[2] = ref + Image(peek).show() + # mask off ears masked_diff = (im_rot - ref) * mask From 29ff4a1ecc8643938ed48c8c88d6faa4cf1dc59f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 14:34:18 -0500 Subject: [PATCH 74/91] add smoke test and xp [skip ci] --- src/aspire/image/faasrot.py | 70 +++++++++++++++++++------------------ tests/test_image.py | 30 ++++++++++------ 2 files changed, 56 insertions(+), 44 deletions(-) diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py index 9c7a5f465b..9e46c87afe 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/faasrot.py @@ -1,15 +1,18 @@ import numpy as np -from aspire.numeric import xp +from aspire.numeric import fft, xp def _pre_rotate(theta): """ - Given angle `theta` (degrees) return nearest rotation of 90 - degrees required to place angle within [-45,45) and residual - rotation in degrees. + Given angle `theta` (radians) return nearest rotation of 90 + degrees required to place angle within [-45,45) degrees and residual + rotation (radians). """ + # todo + theta = np.rad2deg(theta) + theta = np.mod(theta, 360) # 0 < 45 @@ -29,7 +32,7 @@ def _pre_rotate(theta): rot90 = 0 residual = theta - 360 - return residual, rot90 + return np.deg2rad(residual), rot90 def _shift_center(n): @@ -49,10 +52,11 @@ def _shift_center(n): def _pre_compute(theta, nx, ny): """ Retuns M = (Mx, My, rot90) + + :param theta: angle in radians """ theta, mult90 = _pre_rotate(theta) - theta = np.pi * theta / 180 theta = -theta # Yaroslavsky rotated CW cy, sy = _shift_center(ny) @@ -62,32 +66,32 @@ def _pre_compute(theta, nx, ny): eps = np.finfo(np.float64).eps # Precompute Y interpolation tables - My = np.zeros((nx, ny), dtype=np.complex128) - r = np.arange(cy + 1, dtype=int) - u = (1 - np.cos(theta)) / np.sin(theta + eps) - alpha1 = 2 * np.pi * 1j * r / ny + My = xp.zeros((nx, ny), dtype=xp.complex128) + r = xp.arange(cy + 1, dtype=int) + u = (1 - xp.cos(theta)) / xp.sin(theta + eps) + alpha1 = 2 * xp.pi * 1j * r / ny - linds = np.arange(ny - 1, cy, -1, dtype=int) - rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) + linds = xp.arange(ny - 1, cy, -1, dtype=int) + rinds = xp.arange(1, cy - 2 * sy + 1, dtype=int) # This can be broadcast, but leaving loop since would be close to CUDA... 
for x in range(nx): Ux = u * (x - cx + sx + 2) - My[x, r] = np.exp(alpha1 * Ux) - My[x, linds] = np.conj(My[x, rinds]) + My[x, r] = xp.exp(alpha1 * Ux) + My[x, linds] = My[x, rinds].conj() # Precompute X interpolation tables - Mx = np.zeros((ny, nx), dtype=np.complex128) - r = np.arange(cx + 1, dtype=int) - u = -np.sin(theta) - alpha2 = 2 * np.pi * 1j * r / nx + Mx = xp.zeros((ny, nx), dtype=xp.complex128) + r = xp.arange(cx + 1, dtype=int) + u = -xp.sin(theta) + alpha2 = 2 * xp.pi * 1j * r / nx - linds = np.arange(nx - 1, cx, -1, dtype=int) - rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) + linds = xp.arange(nx - 1, cx, -1, dtype=int) + rinds = xp.arange(1, cx - 2 * sx + 1, dtype=int) # This can be broadcast, but leaving loop since would be close to CUDA... for y in range(ny): Uy = u * (y - cy + sy + 2) - Mx[y, r] = np.exp(alpha2 * Uy) - Mx[y, linds] = np.conj(Mx[y, rinds]) + Mx[y, r] = xp.exp(alpha2 * Uy) + Mx[y, linds] = Mx[y, rinds].conj() # After building, transpose to (nx, ny). Mx = Mx.T @@ -96,15 +100,15 @@ def _pre_compute(theta, nx, ny): def _rot90(img): - return np.flipud(img.T) + return xp.flipud(img.T) def _rot180(img): - return np.flipud(np.fliplr(img)) + return xp.flipud(xp.fliplr(img)) def _rot270(img): - return np.fliplr(img.T) + return xp.fliplr(img.T) def faasrotate(images, theta, M=None): @@ -116,8 +120,6 @@ def faasrotate(images, theta, M=None): :param M: optional precomputed shearing table :return: (n, px, px) array fo rotated image data """ - # Convert to degrees - theta = np.rad2deg(theta) # Make a stack of 1 if images.ndim == 2: @@ -134,7 +136,7 @@ def faasrotate(images, theta, M=None): for i in range(n): - img = images[i] + img = xp.asarray(images[i]) # Pre rotate by multiples of 90 if Mrot90 == 1: @@ -145,18 +147,18 @@ def faasrotate(images, theta, M=None): img = _rot270(img) # Shear 1 - img_k = np.fft.fft(img, axis=-1) + img_k = fft.fft(img, axis=-1) img_k = img_k * My - result[i] = np.real(np.fft.ifft(img_k, axis=-1)) + result[i] = fft.ifft(img_k, axis=-1).real # Shear 2 - img_k = np.fft.fft(result[i], axis=-2) + img_k = fft.fft(result[i], axis=-2) img_k = img_k * Mx - result[i] = np.real(np.fft.ifft(img_k, axis=-2)) + result[i] = fft.ifft(img_k, axis=-2).real # Shear 3 - img_k = np.fft.fft(result[i], axis=-1) + img_k = fft.fft(result[i], axis=-1) img_k = img_k * My - result[i] = np.real(np.fft.ifft(img_k, axis=-1)) + result[i] = fft.ifft(img_k, axis=-1).real return result diff --git a/tests/test_image.py b/tests/test_image.py index 608c83bbab..0d4e9cbc7f 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -570,10 +570,10 @@ def test_save_load_pixel_size(get_images, dtype): def test_faasrotate(get_images, dtype): im_np, im = get_images - mask = grid_2d(im_np.shape[-1])["r"] < 1 + mask = grid_2d(im_np.shape[-1])["r"] < 0.9 - #for theta in np.linspace(0, 2 * np.pi, 100): - for theta in [np.pi/4]: + for theta in np.linspace(0, 2 * np.pi, 100): + # for theta in [np.pi/4]: im_rot = im.rotate(theta) # reference to scipy @@ -581,16 +581,26 @@ def test_faasrotate(get_images, dtype): im_np, np.rad2deg(theta), reshape=False, - axes=(-1,-2), + axes=(-1, -2), ) - peek = np.empty((3, *im_np.shape[-2:])) - peek[0] = im_np - peek[1] = im_rot - peek[2] = ref - Image(peek).show() + # peek = np.empty((5, *im_np.shape[-2:])) + # peek[0] = im_np + # peek[1] = im_rot + # peek[2] = ref + # peek[3] = im_rot - ref + + # # print('origin', np.sum(np.abs(im_np*mask))) + # # print('im_rot', np.sum(np.abs(im_rot*mask))) + # # print('ref', np.sum(np.abs(ref*mask))) # mask off 
ears masked_diff = (im_rot - ref) * mask - np.testing.assert_allclose(masked_diff, 0, atol=1e-7) + # #masked_diff[:,mask] = masked_diff.asnumpy()[:,mask] / ref[:,mask] + # #peek[4] = np.nan_to_num(masked_diff) + # peek[4] = masked_diff + # Image(peek*mask).show() + + # mean masked pixel value is ~0.5, so this is ~2% + np.testing.assert_allclose(masked_diff, 0, atol=1) From 98a7101dea24d3fa2b97cd797b7ac396d678167f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 25 Nov 2025 14:43:24 -0500 Subject: [PATCH 75/91] xp cleanup [skip ci] --- src/aspire/image/faasrot.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/faasrot.py index 9e46c87afe..b2ab00a48e 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/faasrot.py @@ -66,31 +66,31 @@ def _pre_compute(theta, nx, ny): eps = np.finfo(np.float64).eps # Precompute Y interpolation tables - My = xp.zeros((nx, ny), dtype=xp.complex128) - r = xp.arange(cy + 1, dtype=int) - u = (1 - xp.cos(theta)) / xp.sin(theta + eps) - alpha1 = 2 * xp.pi * 1j * r / ny + My = np.zeros((nx, ny), dtype=np.complex128) + r = np.arange(cy + 1, dtype=int) + u = (1 - np.cos(theta)) / np.sin(theta + eps) + alpha1 = 2 * np.pi * 1j * r / ny - linds = xp.arange(ny - 1, cy, -1, dtype=int) - rinds = xp.arange(1, cy - 2 * sy + 1, dtype=int) + linds = np.arange(ny - 1, cy, -1, dtype=int) + rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) # This can be broadcast, but leaving loop since would be close to CUDA... for x in range(nx): Ux = u * (x - cx + sx + 2) - My[x, r] = xp.exp(alpha1 * Ux) + My[x, r] = np.exp(alpha1 * Ux) My[x, linds] = My[x, rinds].conj() # Precompute X interpolation tables - Mx = xp.zeros((ny, nx), dtype=xp.complex128) - r = xp.arange(cx + 1, dtype=int) - u = -xp.sin(theta) - alpha2 = 2 * xp.pi * 1j * r / nx + Mx = np.zeros((ny, nx), dtype=np.complex128) + r = np.arange(cx + 1, dtype=int) + u = -np.sin(theta) + alpha2 = 2 * np.pi * 1j * r / nx - linds = xp.arange(nx - 1, cx, -1, dtype=int) - rinds = xp.arange(1, cx - 2 * sx + 1, dtype=int) + linds = np.arange(nx - 1, cx, -1, dtype=int) + rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) # This can be broadcast, but leaving loop since would be close to CUDA... for y in range(ny): Uy = u * (y - cy + sy + 2) - Mx[y, r] = xp.exp(alpha2 * Uy) + Mx[y, r] = np.exp(alpha2 * Uy) Mx[y, linds] = Mx[y, rinds].conj() # After building, transpose to (nx, ny). 
@@ -132,8 +132,9 @@ def faasrotate(images, theta, M=None): M = _pre_compute(theta, px0, px1) Mx, My, Mrot90 = M - result = np.empty((n, px0, px1), dtype=np.float64) + Mx, My = xp.asarray(Mx), xp.asarray(My) + result = xp.empty((n, px0, px1), dtype=np.float64) for i in range(n): img = xp.asarray(images[i]) @@ -161,4 +162,4 @@ def faasrotate(images, theta, M=None): img_k = img_k * My result[i] = fft.ifft(img_k, axis=-1).real - return result + return xp.asnumpy(result) From 04efe7527cb2a4c74034fcdbabc5dfee2137a2a7 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 Dec 2025 15:49:44 -0500 Subject: [PATCH 76/91] cleanup and extension --- src/aspire/image/__init__.py | 2 +- .../image/{faasrot.py => fastrotate.py} | 20 ++++-- src/aspire/image/image.py | 61 ++++++++++++++++++- 3 files changed, 74 insertions(+), 9 deletions(-) rename src/aspire/image/{faasrot.py => fastrotate.py} (83%) diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index 264656ff56..ff01f2d553 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,4 +1,4 @@ -from .faasrot import faasrotate +from .fastrotate import compute_fastrotate_interp_tables, fastrotate from .image import ( BasisImage, BispecImage, diff --git a/src/aspire/image/faasrot.py b/src/aspire/image/fastrotate.py similarity index 83% rename from src/aspire/image/faasrot.py rename to src/aspire/image/fastrotate.py index b2ab00a48e..52092e735e 100644 --- a/src/aspire/image/faasrot.py +++ b/src/aspire/image/fastrotate.py @@ -49,11 +49,13 @@ def _shift_center(n): return c, s -def _pre_compute(theta, nx, ny): +def compute_fastrotate_interp_tables(theta, nx, ny): """ Retuns M = (Mx, My, rot90) :param theta: angle in radians + :param nx: Number pixels first axis + :param ny: Number pixels second axis """ theta, mult90 = _pre_rotate(theta) @@ -111,14 +113,22 @@ def _rot270(img): return xp.fliplr(img.T) -def faasrotate(images, theta, M=None): +def fastrotate(images, theta, M=None): """ - Rotate `images` array by `theta` radians ccw. + Rotate `images` array by `theta` radians ccw using shearing algorithm. + + Note that this algorithm may have artifacts near the rotation boundary + and will have artifacts outside the rotation boundary. + Users can avoid these by zero padding the input image then + cropping the rotated image and/or masking. + + For reference and notes: + `https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/common/fastrotate/fastrotate.m` :param images: (n , px, px) array of image data :param theta: rotation angle in radians :param M: optional precomputed shearing table - :return: (n, px, px) array fo rotated image data + :return: (n, px, px) array of rotated image data """ # Make a stack of 1 @@ -129,7 +139,7 @@ def faasrotate(images, theta, M=None): assert px0 == px1, "Currently only implemented for square images." 
if M is None: - M = _pre_compute(theta, px0, px1) + M = compute_fastrotate_interp_tables(theta, px0, px1) Mx, My, Mrot90 = M Mx, My = xp.asarray(Mx), xp.asarray(My) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 2f1ed6012f..9bdae9f7f5 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -10,7 +10,7 @@ import aspire.sinogram import aspire.volume -from aspire.image import faasrotate +from aspire.image import fastrotate from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import ( @@ -159,6 +159,8 @@ class Image: ".tif": load_tiff, ".tiff": load_tiff, } + # Available image rotation functions + rotation_methods = {"fastrotate": fastrotate} def __init__(self, data, pixel_size=None, dtype=None): """ @@ -636,15 +638,68 @@ def filter(self, filter): original_stack_shape ) - def rotate(self, theta): + def rotate(self, theta, method="fastrotate", mask=1): """ Rotate by `theta` radians. + + :param theta: Scalar or array of length `n_images` + :param mask: Optional scalar or array mask matching `Image` shape. + Scalar will create a circular mask of prescribed radius `(0,1]`. + Array mask will be applied via elementwise multiplication. + `None` disables masking. + :returns: `Image` containing Rotated image data. """ original_stack_shape = self.stack_shape im = self.stack_reshape(-1) - im = faasrotate(im._data, theta) + # Resolve rotation method + if method not in self.rotation_methods: + raise NotImplementedError( + f"Requested `Image.rotation` method={method} not found." + f" Select from {self.rotation_methods.keys()}" + ) + # otherwise, assign the function + rotation_function = self.rotation_methods[method] + + # Handle both scalar and arrays of rotation angles. + # `theta` arrays are checked to match length of images when stacks axis are flattened. + theta = np.array(theta).flatten() + if len(theta) == 1: + im = rotation_function(im._data, theta) + elif len(theta) == im.n_images: + rot_im = np.empty_like(im._data) + for i in range(im.n_images): + rot_im[i] = rotation_function(im._data[i], theta[i]) + im = rot_im + else: + raise RuntimeError( + f"Length of `theta` {len(theta)} and `Image` data {im.n_images} inconsistent." + ) + + # Masking, scalar case + if mask is not None: + if np.size(mask) == 1: + # Confirm `mask` value is a sane radius + if not (0 < mask <= 1): + raise ValueError( + f"Mask radius must be scalar between (0,1]. Received {mask}" + ) + # Construct a boolean `mask` to apply in next code block as a 2D `mask` + mask = ( + grid_2d(im.shape[-1], normalized=True, dtype=np.float64)["r"] < mask + ) + mask = mask.astype(im.dtype) + + # Masking, 2D case + # Confirm `mask` size is consistent + if mask.shape == im.shape[-2:]: + im = im * mask[None, :, :] + else: + raise RuntimeError( + f"Shape of `mask` {mask.shape} inconsistent with `Image` data shape {im.shape[-2:]}" + ) + # Restore original stack shape and metadata. 
return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( original_stack_shape ) From ac9a4459e9ecb8ddef700dd938ed99be545c9c47 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 11:07:59 -0500 Subject: [PATCH 77/91] polishing fastrotate --- src/aspire/image/fastrotate.py | 144 +++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 62 deletions(-) diff --git a/src/aspire/image/fastrotate.py b/src/aspire/image/fastrotate.py index 52092e735e..37eaa4e1b8 100644 --- a/src/aspire/image/fastrotate.py +++ b/src/aspire/image/fastrotate.py @@ -5,39 +5,46 @@ def _pre_rotate(theta): """ - Given angle `theta` (radians) return nearest rotation of 90 - degrees required to place angle within [-45,45) degrees and residual - rotation (radians). + Given `theta` radians return nearest rotation of pi/2 + required to place angle within [-pi/4,pi/4) and the residual + rotation in radians. + + :param theta: Rotation in radians + :returns: + - Residual angle in radians + - Number of pi/2 rotations """ - # todo - theta = np.rad2deg(theta) + theta = np.mod(theta, 2 * np.pi) - theta = np.mod(theta, 360) - - # 0 < 45 - rot90 = 0 + # 0 < pi/4 + rots = 0 residual = theta - if theta >= 45 and theta < 135: - rot90 = 1 - residual = theta - 90 - elif theta >= 135 and theta < 225: - rot90 = 2 - residual = theta - 180 - elif theta >= 215 and theta < 315: - rot90 = 3 - residual = theta - 270 - elif theta >= 315 and theta < 360: - rot90 = 0 - residual = theta - 360 + if theta >= np.pi / 4 and theta < 3 * np.pi / 4: + rots = 1 + residual = theta - np.pi / 2 + elif theta >= 3 * np.pi / 4 and theta < 5 * np.pi / 4: + rots = 2 + residual = theta - np.pi + elif theta >= 5 * np.pi / 4 and theta < 7 * np.pi / 4: + rots = 3 + residual = theta - 3 * np.pi / 2 + elif theta >= 7 * np.pi / 4 and theta < 2 * np.pi: + rots = 0 + residual = theta - 2 * np.pi - return np.deg2rad(residual), rot90 + return residual, rots def _shift_center(n): """ Given `n` pixels return center pixel and shift amount, 0 or 1/2. + + :param n: Number of pixels + :returns: + - center pixel + - shift amount """ if n % 2 == 0: c = n // 2 # center @@ -51,7 +58,7 @@ def _shift_center(n): def compute_fastrotate_interp_tables(theta, nx, ny): """ - Retuns M = (Mx, My, rot90) + Retuns iterpolation tables as tuple M = (Mx, My, rots). :param theta: angle in radians :param nx: Number pixels first axis @@ -59,7 +66,8 @@ def compute_fastrotate_interp_tables(theta, nx, ny): """ theta, mult90 = _pre_rotate(theta) - theta = -theta # Yaroslavsky rotated CW + # Reverse rotation, Yaroslavsky rotated CW + theta = -theta cy, sy = _shift_center(ny) cx, sx = _shift_center(nx) @@ -75,11 +83,10 @@ def compute_fastrotate_interp_tables(theta, nx, ny): linds = np.arange(ny - 1, cy, -1, dtype=int) rinds = np.arange(1, cy - 2 * sy + 1, dtype=int) - # This can be broadcast, but leaving loop since would be close to CUDA... - for x in range(nx): - Ux = u * (x - cx + sx + 2) - My[x, r] = np.exp(alpha1 * Ux) - My[x, linds] = My[x, rinds].conj() + + Ux = u * (np.arange(nx) - cx + sx + 2) + My[:, r] = np.exp(alpha1[None, :] * Ux[:, None]) + My[:, linds] = My[:, rinds].conj() # Precompute X interpolation tables Mx = np.zeros((ny, nx), dtype=np.complex128) @@ -89,11 +96,10 @@ def compute_fastrotate_interp_tables(theta, nx, ny): linds = np.arange(nx - 1, cx, -1, dtype=int) rinds = np.arange(1, cx - 2 * sx + 1, dtype=int) - # This can be broadcast, but leaving loop since would be close to CUDA... 
- for y in range(ny): - Uy = u * (y - cy + sy + 2) - Mx[y, r] = np.exp(alpha2 * Uy) - Mx[y, linds] = Mx[y, rinds].conj() + + Uy = u * (np.arange(ny) - cy + sy + 2) + Mx[:, r] = np.exp(alpha2[None, :] * Uy[:, None]) + Mx[:, linds] = Mx[:, rinds].conj() # After building, transpose to (nx, ny). Mx = Mx.T @@ -101,16 +107,25 @@ def compute_fastrotate_interp_tables(theta, nx, ny): return Mx, My, mult90 +# The following helper utilities are written to work with +# `img` data of dimension 2 or more where the data is expected to be +# in the (-2,-1) dimensions with any other dims as stack axes. def _rot90(img): - return xp.flipud(img.T) + """Rotate image array by 90 degrees.""" + # stack broadcast of flipud(img.T) + return xp.flip(xp.swapaxes(img, -1, -2), axis=-2) def _rot180(img): - return xp.flipud(xp.fliplr(img)) + """Rotate image array by 180 degrees.""" + # stack broadcast of flipud(fliplr) + return xp.flip(img, axis=(-1, -2)) def _rot270(img): - return xp.fliplr(img.T) + """Rotate image array by 90 degrees.""" + # stack broadcast of fliplr(img.T) + return xp.flip(xp.swapaxes(img, -1, -2), axis=-1) def fastrotate(images, theta, M=None): @@ -140,36 +155,41 @@ def fastrotate(images, theta, M=None): if M is None: M = compute_fastrotate_interp_tables(theta, px0, px1) - Mx, My, Mrot90 = M + Mx, My, Mrots = M + + Mx, My = xp.asarray(Mx, dtype=images.dtype), xp.asarray(My, dtype=images.dtype) - Mx, My = xp.asarray(Mx), xp.asarray(My) + # Store if `images` data was provide on host (np.darray) + _host = isinstance(images, np.ndarray) - result = xp.empty((n, px0, px1), dtype=np.float64) - for i in range(n): + # If needed copy image array to device + images = xp.asarray(images) - img = xp.asarray(images[i]) + # Pre rotate by multiples of 90 (pi/2) + if Mrots == 1: + images = _rot90(images) + elif Mrots == 2: + images = _rot180(images) + elif Mrots == 3: + images = _rot270(images) - # Pre rotate by multiples of 90 - if Mrot90 == 1: - img = _rot90(img) - elif Mrot90 == 2: - img = _rot180(img) - elif Mrot90 == 3: - img = _rot270(img) + # Shear 1 + img_k = fft.fft(images, axis=-1) + img_k = img_k * My + images = fft.ifft(img_k, axis=-1).real - # Shear 1 - img_k = fft.fft(img, axis=-1) - img_k = img_k * My - result[i] = fft.ifft(img_k, axis=-1).real + # Shear 2 + img_k = fft.fft(images, axis=-2) + img_k = img_k * Mx + images = fft.ifft(img_k, axis=-2).real - # Shear 2 - img_k = fft.fft(result[i], axis=-2) - img_k = img_k * Mx - result[i] = fft.ifft(img_k, axis=-2).real + # Shear 3 + img_k = fft.fft(images, axis=-1) + img_k = img_k * My + images = fft.ifft(img_k, axis=-1).real - # Shear 3 - img_k = fft.fft(result[i], axis=-1) - img_k = img_k * My - result[i] = fft.ifft(img_k, axis=-1).real + # Return to host if needed + if _host: + images = xp.asnumpy(images) - return xp.asnumpy(result) + return images From 433f916a6b3dee6aa11ba60d4546411998f0eabf Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 11:20:56 -0500 Subject: [PATCH 78/91] first pass polish Image.rotate --- src/aspire/image/image.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 9bdae9f7f5..c9734b72e8 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -638,17 +638,25 @@ def filter(self, filter): original_stack_shape ) - def rotate(self, theta, method="fastrotate", mask=1): + def rotate(self, theta, method="fastrotate", mask=1, **kwargs): """ - Rotate by `theta` radians. 
+ Return `Image` rotated by `theta` radians using `method`. + + Optionally applies `mask`. Note that some methods may + introduce edge artifacts, in which case users may consider + using a tighter mask (eg 0.9) or a combination of pad-crop. + + Any additional kwargs will be passed to `method`. :param theta: Scalar or array of length `n_images` :param mask: Optional scalar or array mask matching `Image` shape. Scalar will create a circular mask of prescribed radius `(0,1]`. Array mask will be applied via elementwise multiplication. `None` disables masking. - :returns: `Image` containing Rotated image data. + :param method: Optionally specify a rotation method. + :return: `Image` containing rotated image data. """ + original_stack_shape = self.stack_shape im = self.stack_reshape(-1) @@ -658,18 +666,19 @@ def rotate(self, theta, method="fastrotate", mask=1): f"Requested `Image.rotation` method={method} not found." f" Select from {self.rotation_methods.keys()}" ) - # otherwise, assign the function + # Assign the rotation method's function + # Any rotation method is expected to handle image data as a 2D array or 3D array (single stack axis). rotation_function = self.rotation_methods[method] # Handle both scalar and arrays of rotation angles. # `theta` arrays are checked to match length of images when stacks axis are flattened. theta = np.array(theta).flatten() if len(theta) == 1: - im = rotation_function(im._data, theta) + im = rotation_function(im._data, theta, **kwargs) elif len(theta) == im.n_images: rot_im = np.empty_like(im._data) for i in range(im.n_images): - rot_im[i] = rotation_function(im._data[i], theta[i]) + rot_im[i] = rotation_function(im._data[i], theta[i], **kwargs) im = rot_im else: raise RuntimeError( From 597e4ae895ffb888b03bca4a10954245666329f6 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 11:54:36 -0500 Subject: [PATCH 79/91] initial add scipy.ndimage.rotate [skip ci] --- src/aspire/image/__init__.py | 9 ++++++++- src/aspire/image/image.py | 4 ++-- src/aspire/image/{fastrotate.py => rotation.py} | 17 +++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) rename src/aspire/image/{fastrotate.py => rotation.py} (90%) diff --git a/src/aspire/image/__init__.py b/src/aspire/image/__init__.py index ff01f2d553..442526fa54 100644 --- a/src/aspire/image/__init__.py +++ b/src/aspire/image/__init__.py @@ -1,4 +1,11 @@ -from .fastrotate import compute_fastrotate_interp_tables, fastrotate +# isort: off +from .rotation import ( + compute_fastrotate_interp_tables, + fastrotate, + sp_rotate, +) + +# isort: on from .image import ( BasisImage, BispecImage, diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index c9734b72e8..9ba54a66a4 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -10,7 +10,7 @@ import aspire.sinogram import aspire.volume -from aspire.image import fastrotate +from aspire.image import fastrotate, sp_rotate from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import ( @@ -160,7 +160,7 @@ class Image: ".tiff": load_tiff, } # Available image rotation functions - rotation_methods = {"fastrotate": fastrotate} + rotation_methods = {"fastrotate": fastrotate, "scipy": sp_rotate} def __init__(self, data, pixel_size=None, dtype=None): """ diff --git a/src/aspire/image/fastrotate.py b/src/aspire/image/rotation.py similarity index 90% rename from src/aspire/image/fastrotate.py rename to src/aspire/image/rotation.py index 37eaa4e1b8..268d73cc43 100644 --- 
a/src/aspire/image/fastrotate.py +++ b/src/aspire/image/rotation.py @@ -1,4 +1,5 @@ import numpy as np +from scipy import ndimage from aspire.numeric import fft, xp @@ -193,3 +194,19 @@ def fastrotate(images, theta, M=None): images = xp.asnumpy(images) return images + + +def sp_rotate(im, theta, **kwargs): + """Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. + + Converts `theta` from radian to degrees. + Defines image axes and reshape behavior. + + Additional kwargs will be passed through. + See scipy.ndimage.rotate + + :param im: Array of image data shape (L,L) or (n,L, L) + :param theta: Rotation in ccw radians. + :return: Array representing rotated `im`. + """ + return ndimage.rotate(im, np.rad2deg(theta), reshape=False, axes=(-1, -2), **kwargs) From 61ce3911a743f8aa4beb8c0a190e6aa4fbc0a1d7 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 Dec 2025 14:57:25 -0500 Subject: [PATCH 80/91] extend test, but still need to improve it --- src/aspire/image/rotation.py | 37 +++++++++++++++++++++++---- tests/test_image.py | 48 ++++++++++++++---------------------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index 268d73cc43..2649812f9c 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -196,17 +196,44 @@ def fastrotate(images, theta, M=None): return images -def sp_rotate(im, theta, **kwargs): +def sp_rotate(img, theta, **kwargs): """Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. Converts `theta` from radian to degrees. - Defines image axes and reshape behavior. + Defines stack/image axes and reshape behavior. + Image data is expected to be in last two axes in all cases. Additional kwargs will be passed through. See scipy.ndimage.rotate - :param im: Array of image data shape (L,L) or (n,L, L) + :param img: Array of image data shape (L,L) or (...,L, L) :param theta: Rotation in ccw radians. - :return: Array representing rotated `im`. + :return: Array representing rotated `img`. """ - return ndimage.rotate(im, np.rad2deg(theta), reshape=False, axes=(-1, -2), **kwargs) + + # Store original shape + original_shape = img.shape + # Image data is expected to be in last two axis in all cases + # Flatten, converts all inputs to consistent 3D shape (single stack axis). + img = img.reshape(-1, *img.shape[-2:]) + + # Scipy accepts a single scalar theta in degrees. 
+ # Handle array of thetas and scalar case by expanding to flat array of img.shape + # Flatten all inputs + theta = np.rad2deg(np.array(theta)).reshape(-1, 1) + # Expand scalar input + if theta.shape[0] == 1: + theta = np.full(img.shape[0], theta, img.dtype) + # Check we have an array matching `img` + if theta.shape != img.shape[:1]: + raise RuntimeError("Inconsistent `theta` and `img` shapes.") + + # Create result array and rotate images via loop + result = np.empty_like(img) + for i in range(img.shape[0]): + result[i] = ndimage.rotate( + img[i], theta[i], reshape=False, axes=(-2, -1), **kwargs + ) + + # Restore original shape + return result.reshape(*original_shape) diff --git a/tests/test_image.py b/tests/test_image.py index 0d4e9cbc7f..f178175e3c 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -567,40 +567,28 @@ def test_save_load_pixel_size(get_images, dtype): ) -def test_faasrotate(get_images, dtype): +@pytest.fixture( + params=Image.rotation_methods, ids=lambda x: f"method={x}", scope="module" +) +def rotation_method(request): + return request.param + + +# TODO, lets replace this with an analytic test +def test_image_rotate(get_images, dtype, rotation_method): im_np, im = get_images - mask = grid_2d(im_np.shape[-1])["r"] < 0.9 + # mask = grid_2d(im_np.shape[-1])["r"] < 0.9 for theta in np.linspace(0, 2 * np.pi, 100): # for theta in [np.pi/4]: - im_rot = im.rotate(theta) - - # reference to scipy - ref = rotate( - im_np, - np.rad2deg(theta), - reshape=False, - axes=(-1, -2), - ) - - # peek = np.empty((5, *im_np.shape[-2:])) - # peek[0] = im_np - # peek[1] = im_rot - # peek[2] = ref - # peek[3] = im_rot - ref - - # # print('origin', np.sum(np.abs(im_np*mask))) - # # print('im_rot', np.sum(np.abs(im_rot*mask))) - # # print('ref', np.sum(np.abs(ref*mask))) - - # mask off ears - masked_diff = (im_rot - ref) * mask - - # #masked_diff[:,mask] = masked_diff.asnumpy()[:,mask] / ref[:,mask] - # #peek[4] = np.nan_to_num(masked_diff) - # peek[4] = masked_diff - # Image(peek*mask).show() + im_rot = im.rotate(theta, method=rotation_method) + + # Use manual call to PIL as reference + ref = np.asarray(PILImage.fromarray(im_np[0]).rotate(np.rad2deg(theta))) + + # masked_diff = (im_rot - ref) * mask + diff = im_rot - ref # mean masked pixel value is ~0.5, so this is ~2% - np.testing.assert_allclose(masked_diff, 0, atol=1) + np.testing.assert_allclose(diff, 0, atol=1) From d7c8194ee5f1d5589c7c06f2af366814131965df Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 8 Dec 2025 09:56:11 -0500 Subject: [PATCH 81/91] rm unused import --- tests/test_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_image.py b/tests/test_image.py index f178175e3c..a8eea5b8c0 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -10,7 +10,6 @@ from PIL import Image as PILImage from pytest import raises from scipy.datasets import face -from scipy.ndimage import rotate from aspire.image import Image from aspire.utils import Rotation, grid_2d, powerset, utest_tolerance From a6bb28147e99d15ef6a247e5e1156c4251a5a8b2 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 Dec 2025 07:54:49 -0500 Subject: [PATCH 82/91] more analytic image rot test --- tests/test_image.py | 67 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/tests/test_image.py b/tests/test_image.py index a8eea5b8c0..793bf25245 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -12,7 +12,7 @@ from scipy.datasets import face from aspire.image import 
Image -from aspire.utils import Rotation, grid_2d, powerset, utest_tolerance +from aspire.utils import Rotation, gaussian_2d, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup from .test_utils import matplotlib_dry_run @@ -573,21 +573,60 @@ def rotation_method(request): return request.param -# TODO, lets replace this with an analytic test -def test_image_rotate(get_images, dtype, rotation_method): - im_np, im = get_images +def test_image_rotate(dtype, rotation_method): + """ + Compare image rotations against rotated gaussian blobs. + """ - # mask = grid_2d(im_np.shape[-1])["r"] < 0.9 + def _gen_image(angle, L, K=10): + """ + Generate a sequence of `K` gaussian blobs rotated by `angle`. - for theta in np.linspace(0, 2 * np.pi, 100): - # for theta in [np.pi/4]: - im_rot = im.rotate(theta, method=rotation_method) + Return tuple of unrotated and rotated image arrays. + + :param angle: rotation angle + :param L: size (L-by-L) in pixels + :param K: Number of blobs + :return: + - Array of unrotated data (float64) + - Array of rotated data (float64) + """ + + im = np.zeros((L, L), dtype=np.float64) + rotated_im = np.zeros_like(im) + + centers = np.random.randint(-L // 4, L // 4, size=(10, 2)) + sigmas = np.full((K, 2), L / 10, dtype=np.float64) + + # Rotate the gaussian specifications + R = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) - # Use manual call to PIL as reference - ref = np.asarray(PILImage.fromarray(im_np[0]).rotate(np.rad2deg(theta))) + rotated_centers = centers @ R + + for center, sigma in zip(centers, sigmas): + im[:] = im[:] + gaussian_2d(L, center, sigma, dtype=np.float64) + + for center, sigma in zip(rotated_centers, sigmas): + rotated_im[:] = rotated_im[:] + gaussian_2d( + L, center, sigma, dtype=np.float64 + ) + + return im, rotated_im + + L = 129 # Test image size in pixels + # Create mask, zeros edge artifacts + mask = grid_2d(L, normalized=True)["r"] < 0.9 + + # for theta in np.linspace(0, 2 * np.pi, 100): + for theta in [np.pi / 4]: + im, ref = _gen_image(theta, L) + im = Image(im.astype(dtype, copy=False)) + + # Rotate using `Image` method + im_rot = im.rotate(theta, method=rotation_method) - # masked_diff = (im_rot - ref) * mask - diff = im_rot - ref + masked_diff = (im_rot - ref) * mask - # mean masked pixel value is ~0.5, so this is ~2% - np.testing.assert_allclose(diff, 0, atol=1) + # Compute L1 error of masked diff + L1_error = np.mean(np.abs(masked_diff)) + np.testing.assert_array_less(L1_error, 1e-6) From 95b773b3bfc87a10a9dbd813dc01eb698b6a830f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 Dec 2025 07:55:00 -0500 Subject: [PATCH 83/91] comment cleanup --- src/aspire/image/rotation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index 2649812f9c..426c28b96c 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -2,6 +2,7 @@ from scipy import ndimage from aspire.numeric import fft, xp +from aspire.utils import complex_type def _pre_rotate(theta): @@ -158,12 +159,14 @@ def fastrotate(images, theta, M=None): M = compute_fastrotate_interp_tables(theta, px0, px1) Mx, My, Mrots = M - Mx, My = xp.asarray(Mx, dtype=images.dtype), xp.asarray(My, dtype=images.dtype) + # Cast interp tables to match precision of `images` + Mx = xp.asarray(Mx, complex_type(images.dtype)) + My = xp.asarray(My, complex_type(images.dtype)) - # Store if `images` data was provide on host (np.darray) + # Determine 
if `images` data was provided on host (np.darray) _host = isinstance(images, np.ndarray) - # If needed copy image array to device + # Copy image array to device if needed images = xp.asarray(images) # Pre rotate by multiples of 90 (pi/2) @@ -189,7 +192,7 @@ def fastrotate(images, theta, M=None): img_k = img_k * My images = fft.ifft(img_k, axis=-1).real - # Return to host if needed + # Return to host if input was provided on host if _host: images = xp.asnumpy(images) From 7b9700a487526b990463e033f5461d74cb3990c6 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 Dec 2025 09:42:49 -0500 Subject: [PATCH 84/91] cleanup --- tests/test_image.py | 58 ++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/tests/test_image.py b/tests/test_image.py index 793bf25245..7206998a4c 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -578,11 +578,18 @@ def test_image_rotate(dtype, rotation_method): Compare image rotations against rotated gaussian blobs. """ - def _gen_image(angle, L, K=10): + L = 129 # Test image size in pixels + num_test_angles = 42 + # Create mask, used to zero edge artifacts + mask = grid_2d(L, normalized=True)["r"] < 0.9 + + def _gen_image(angle, L, n=1, K=10): """ - Generate a sequence of `K` gaussian blobs rotated by `angle`. + Generate `n` `L-by-L` image arrays, + each constructed by a sequence of `K` gaussian blobs, + and reference images with the blob centers rotated by `angle`. - Return tuple of unrotated and rotated image arrays. + Return tuple of unrotated and rotated image arrays (n-by-L-by-L). :param angle: rotation angle :param L: size (L-by-L) in pixels @@ -592,41 +599,44 @@ def _gen_image(angle, L, K=10): - Array of rotated data (float64) """ - im = np.zeros((L, L), dtype=np.float64) + im = np.zeros((n, L, L), dtype=np.float64) rotated_im = np.zeros_like(im) - centers = np.random.randint(-L // 4, L // 4, size=(10, 2)) - sigmas = np.full((K, 2), L / 10, dtype=np.float64) + centers = np.random.randint(-L // 4, L // 4, size=(n, 10, 2)) + sigmas = np.full((n, K, 2), L / 10, dtype=np.float64) # Rotate the gaussian specifications R = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) - rotated_centers = centers @ R - for center, sigma in zip(centers, sigmas): - im[:] = im[:] + gaussian_2d(L, center, sigma, dtype=np.float64) + # Construct each image independently + for i in range(n): + for center, sigma in zip(centers[i], sigmas[i]): + im[i] = im[i] + gaussian_2d(L, center, sigma, dtype=np.float64) - for center, sigma in zip(rotated_centers, sigmas): - rotated_im[:] = rotated_im[:] + gaussian_2d( - L, center, sigma, dtype=np.float64 - ) + for center, sigma in zip(rotated_centers[i], sigmas[i]): + rotated_im[i] = rotated_im[i] + gaussian_2d( + L, center, sigma, dtype=np.float64 + ) return im, rotated_im - L = 129 # Test image size in pixels - # Create mask, zeros edge artifacts - mask = grid_2d(L, normalized=True)["r"] < 0.9 - - # for theta in np.linspace(0, 2 * np.pi, 100): - for theta in [np.pi / 4]: - im, ref = _gen_image(theta, L) + # Test over a variety of angles `theta` + for theta in np.linspace(0, 2 * np.pi, num_test_angles): + # Generate images and reference (`theta`) rotated images + im, ref = _gen_image(theta, L, n=3) im = Image(im.astype(dtype, copy=False)) - # Rotate using `Image` method + # Rotate using `Image`'s `rotation_method` im_rot = im.rotate(theta, method=rotation_method) + # Mask off boundary artifacts masked_diff = (im_rot - ref) * mask - # Compute L1 error of 
masked diff - L1_error = np.mean(np.abs(masked_diff)) - np.testing.assert_array_less(L1_error, 1e-6) + # Compute L1 error of masked diff, per image + L1_error = np.mean(np.abs(masked_diff), axis=(-1, -2)) + np.testing.assert_array_less( + L1_error, + 0.1, + err_msg=f"{L} pixels using {rotation_method} @ {theta} radians", + ) From d25a66abc1038ba12727b45e031af3f1730f113a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 Dec 2025 10:24:56 -0500 Subject: [PATCH 85/91] more cleanup --- src/aspire/image/image.py | 3 ++- src/aspire/image/rotation.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 9ba54a66a4..071437a1ef 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -638,7 +638,7 @@ def filter(self, filter): original_stack_shape ) - def rotate(self, theta, method="fastrotate", mask=1, **kwargs): + def rotate(self, theta, method="scipy", mask=1, **kwargs): """ Return `Image` rotated by `theta` radians using `method`. @@ -654,6 +654,7 @@ def rotate(self, theta, method="fastrotate", mask=1, **kwargs): Array mask will be applied via elementwise multiplication. `None` disables masking. :param method: Optionally specify a rotation method. + Defaults to `scipy`. :return: `Image` containing rotated image data. """ diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index 426c28b96c..043fbfde72 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -125,7 +125,7 @@ def _rot180(img): def _rot270(img): - """Rotate image array by 90 degrees.""" + """Rotate image array by 270 degrees.""" # stack broadcast of fliplr(img.T) return xp.flip(xp.swapaxes(img, -1, -2), axis=-1) @@ -200,7 +200,8 @@ def fastrotate(images, theta, M=None): def sp_rotate(img, theta, **kwargs): - """Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. + """ + Utility wrapper to form a ASPIRE compatible call to Scipy's image rotation. Converts `theta` from radian to degrees. Defines stack/image axes and reshape behavior. From e665023d671716096636ee5dd0ef6a1dab84a209 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 10 Dec 2025 10:13:24 -0500 Subject: [PATCH 86/91] attempt to fix theta bcast bug --- src/aspire/image/rotation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index 043fbfde72..acce553c14 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -223,13 +223,13 @@ def sp_rotate(img, theta, **kwargs): # Scipy accepts a single scalar theta in degrees. 
# Handle array of thetas and scalar case by expanding to flat array of img.shape - # Flatten all inputs + # Flatten all inputs, becomes 2d stack theta = np.rad2deg(np.array(theta)).reshape(-1, 1) - # Expand scalar input - if theta.shape[0] == 1: + # Expand single scalar input + if np.size(theta) == 1: theta = np.full(img.shape[0], theta, img.dtype) - # Check we have an array matching `img` - if theta.shape != img.shape[:1]: + # Check we have an array matching `img`, both should be (n,1) + if theta.shape[0] != img.shape[0]: raise RuntimeError("Inconsistent `theta` and `img` shapes.") # Create result array and rotate images via loop From 59f86664499dde706a8a257bd9c94b04bf87e961 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 11 Dec 2025 13:57:26 -0500 Subject: [PATCH 87/91] actually fix the bug this time --- src/aspire/image/rotation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index acce553c14..cd8211b98a 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -223,12 +223,12 @@ def sp_rotate(img, theta, **kwargs): # Scipy accepts a single scalar theta in degrees. # Handle array of thetas and scalar case by expanding to flat array of img.shape - # Flatten all inputs, becomes 2d stack - theta = np.rad2deg(np.array(theta)).reshape(-1, 1) + # Flatten all inputs + theta = np.rad2deg(np.array(theta)).reshape(-1) # Expand single scalar input if np.size(theta) == 1: theta = np.full(img.shape[0], theta, img.dtype) - # Check we have an array matching `img`, both should be (n,1) + # Check we have an array matching `img`, both should be len(n) if theta.shape[0] != img.shape[0]: raise RuntimeError("Inconsistent `theta` and `img` shapes.") From 549746bf99a891452894e5ffdbbc34c1fdebd8d0 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 11 Dec 2025 14:21:35 -0500 Subject: [PATCH 88/91] add tests for additional unsupported input cases --- tests/test_image.py | 53 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/test_image.py b/tests/test_image.py index 7206998a4c..ce60a946fe 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -11,7 +11,7 @@ from pytest import raises from scipy.datasets import face -from aspire.image import Image +from aspire.image import Image, fastrotate, sp_rotate from aspire.utils import Rotation, gaussian_2d, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup @@ -640,3 +640,54 @@ def _gen_image(angle, L, n=1, K=10): 0.1, err_msg=f"{L} pixels using {rotation_method} @ {theta} radians", ) + + +def test_sp_rotate_inputs(dtype): + """ + Smoke test various input combinations to the scipy rotation wrapper. + """ + + imgs = np.zeros((6, 8, 8), dtype=dtype) + thetas = np.arange(6, dtype=dtype) + theta = thetas[0] # scalar + + ## These are the only supported calls admitted by the function doc. + # singleton, scalar + _ = sp_rotate(imgs[0], theta) + # stack, scalar + _ = sp_rotate(imgs, theta) + + ## These happen to also work with the code, so were put under test. + ## We're not advertising them, as there really isn't a good use + ## case for this wrapper code outside of the internal wrapping + ## application. 
+ # singleton, single element array + _ = sp_rotate(imgs[0], thetas[0:1]) + # stack, single element array + _ = sp_rotate(imgs, thetas[0:1]) + # stack, stack + _ = sp_rotate(imgs, thetas) + # md-stack, md-stack + _ = sp_rotate(imgs.reshape(2, 3, 8, 8), thetas.reshape(2, 3, 1)) + _ = sp_rotate(imgs.reshape(2, 3, 8, 8), thetas.reshape(2, 3)) + + +def test_fastrotate_inputs(dtype): + """ + Smoke test various input combinations to `fastrotate`. + """ + + imgs = np.zeros((6, 8, 8), dtype=dtype) + theta = 42 + + ## These are the supported calls + # singleton, scalar + _ = fastrotate(imgs[0], theta) + # stack, scalar + _ = fastrotate(imgs, theta) + + ## These can also remain under test, but are not advertised. + # stack, single element array + _ = fastrotate(imgs, np.array(theta)) + # singleton, single element array + _ = fastrotate(imgs[0], np.array(theta)) From 4c95c4a0283800f7ed47705b8ecb9402407de23b Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 11 Dec 2025 14:30:34 -0500 Subject: [PATCH 89/91] tox doesn't like ## --- tests/test_image.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_image.py b/tests/test_image.py index ce60a946fe..3d981b6d09 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -651,16 +651,16 @@ def test_sp_rotate_inputs(dtype): thetas = np.arange(6, dtype=dtype) theta = thetas[0] # scalar - ## These are the only supported calls admitted by the function doc. + # # These are the only supported calls admitted by the function doc. # singleton, scalar _ = sp_rotate(imgs[0], theta) # stack, scalar _ = sp_rotate(imgs, theta) - ## These happen to also work with the code, so were put under test. - ## We're not advertising them, as there really isn't a good use - ## case for this wrapper code outside of the internal wrapping - ## application. + # # These happen to also work with the code, so were put under test. + # # We're not advertising them, as there really isn't a good use + # # case for this wrapper code outside of the internal wrapping + # # application. # singleton, single element array _ = sp_rotate(imgs[0], thetas[0:1]) # stack, single element array @@ -680,13 +680,13 @@ def test_fastrotate_inputs(dtype): imgs = np.zeros((6, 8, 8), dtype=dtype) theta = 42 - ## These are the supported calls + # # These are the supported calls # singleton, scalar _ = fastrotate(imgs[0], theta) # stack, scalar _ = fastrotate(imgs, theta) - ## These can also remain under test, but are not advertised. + # # These can also remain under test, but are not advertised. # stack, single element array _ = fastrotate(imgs, np.array(theta)) # singleton, single element array From 55d7d927a2d8e1f342b28f23fdf83967a90bc5fa Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 12 Dec 2025 09:47:43 -0500 Subject: [PATCH 90/91] Add M+theta arg error and test --- src/aspire/image/rotation.py | 12 ++++++++++-- tests/test_image.py | 24 +++++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/aspire/image/rotation.py b/src/aspire/image/rotation.py index cd8211b98a..706248110f 100644 --- a/src/aspire/image/rotation.py +++ b/src/aspire/image/rotation.py @@ -143,8 +143,11 @@ def fastrotate(images, theta, M=None): `https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/common/fastrotate/fastrotate.m` :param images: (n , px, px) array of image data - :param theta: rotation angle in radians - :param M: optional precomputed shearing table + :param theta: Rotation angle in radians. 
+ Note when `M` is supplied, `theta` must be `None`. + :param M: Optional precomputed shearing table. + Provided by `M=compute_fastrotate_interp_tables(theta, px, px)`. + Note when `M` is supplied, `theta` must be `None`. :return: (n, px, px) array of rotated image data """ @@ -157,6 +160,11 @@ def fastrotate(images, theta, M=None): if M is None: M = compute_fastrotate_interp_tables(theta, px0, px1) + elif theta is not None: + raise RuntimeError( + "`theta` must be `None` when supplying `M`." + " M is precomputed for a specific `theta`." + ) Mx, My, Mrots = M # Cast interp tables to match precision of `images` diff --git a/tests/test_image.py b/tests/test_image.py index 3d981b6d09..009be52364 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -11,7 +11,7 @@ from pytest import raises from scipy.datasets import face -from aspire.image import Image, fastrotate, sp_rotate +from aspire.image import Image, compute_fastrotate_interp_tables, fastrotate, sp_rotate from aspire.utils import Rotation, gaussian_2d, grid_2d, powerset, utest_tolerance from aspire.volume import CnSymmetryGroup @@ -691,3 +691,25 @@ def test_fastrotate_inputs(dtype): _ = fastrotate(imgs, np.array(theta)) # singleton, single element array _ = fastrotate(imgs[0], np.array(theta)) + + +def test_fastrotate_M_arg(dtype): + """ + Smoke test precomputed `M` input to `fastrotate`. + """ + + imgs = np.random.randn(6, 8, 8).astype(dtype) + theta = np.random.uniform(0, 2 * np.pi) + + # Precompute M + M = compute_fastrotate_interp_tables(theta, *imgs.shape[-2:]) + + # Call with theta None + im_rot_M = fastrotate(imgs, None, M=M) + # Compare to calling withou `M` + im_rot = fastrotate(imgs, theta) + np.testing.assert_allclose(im_rot_M, im_rot) + + # Call with theta, should raise + with raises(RuntimeError, match=r".*`theta` must be `None`.*"): + _ = fastrotate(imgs, theta, M=M) From 6efd7b677e81ebf57b4d176779715cd394cac7fe Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 12 Dec 2025 13:53:13 -0500 Subject: [PATCH 91/91] =?UTF-8?q?Bump=20version:=200.14.1=20=E2=86=92=200.?= =?UTF-8?q?14.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- README.md | 4 ++-- docs/source/conf.py | 2 +- docs/source/index.rst | 2 +- pyproject.toml | 2 +- src/aspire/__init__.py | 2 +- src/aspire/config_default.yaml | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index dc9339ea5b..9d6eb705dd 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.14.1 +current_version = 0.14.2 commit = True tag = True diff --git a/README.md b/README.md index 569fa69c62..a6f6ee4a80 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5657281.svg)](https://doi.org/10.5281/zenodo.5657281) [![Downloads](https://static.pepy.tech/badge/aspire/month)](https://pepy.tech/project/aspire) -# ASPIRE - Algorithms for Single Particle Reconstruction - v0.14.1 +# ASPIRE - Algorithms for Single Particle Reconstruction - v0.14.2 The ASPIRE-Python project supersedes [Matlab ASPIRE](https://github.com/PrincetonUniversity/aspire). @@ -20,7 +20,7 @@ For more information about the project, algorithms, and related publications ple Please cite using the following DOI. This DOI represents all versions, and will always resolve to the latest one. 
``` -ComputationalCryoEM/ASPIRE-Python: v0.14.1 https://doi.org/10.5281/zenodo.5657281 +ComputationalCryoEM/ASPIRE-Python: v0.14.2 https://doi.org/10.5281/zenodo.5657281 ``` diff --git a/docs/source/conf.py b/docs/source/conf.py index fd203b75de..5faee47e40 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -86,7 +86,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -release = version = "0.14.1" +release = version = "0.14.2" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/source/index.rst b/docs/source/index.rst index a7ce8b3783..6154ac327b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,4 +1,4 @@ -Aspire v0.14.1 +Aspire v0.14.2 ============== Algorithms for Single Particle Reconstruction diff --git a/pyproject.toml b/pyproject.toml index 2e132d8e1e..198a3bfb06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "aspire" -version = "0.14.1" +version = "0.14.2" description = "Algorithms for Single Particle Reconstruction" readme = "README.md" # Optional requires-python = ">=3.9" diff --git a/src/aspire/__init__.py b/src/aspire/__init__.py index 3b77b9cb67..d7c9359a0e 100644 --- a/src/aspire/__init__.py +++ b/src/aspire/__init__.py @@ -15,7 +15,7 @@ from aspire.exceptions import handle_exception # version in maj.min.bld format -__version__ = "0.14.1" +__version__ = "0.14.2" # Setup `confuse` config diff --git a/src/aspire/config_default.yaml b/src/aspire/config_default.yaml index 4f02923054..ee97c4b8ca 100644 --- a/src/aspire/config_default.yaml +++ b/src/aspire/config_default.yaml @@ -1,4 +1,4 @@ -version: 0.14.1 +version: 0.14.2 common: # numeric module to use - one of numpy/cupy numeric: numpy
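For quick reference, a minimal usage sketch of the rotation API introduced in this series (`Image.rotate`, `fastrotate`, and `compute_fastrotate_interp_tables`). This is an illustrative example only, assuming the `aspire` package built from these patches; the random 64x64 test stack, the pi/6 angle, and the 0.9 mask radius are arbitrary illustrative choices, not values prescribed by the patches.

```python
import numpy as np

from aspire.image import Image, compute_fastrotate_interp_tables, fastrotate

# Arbitrary stack of three square test images (fastrotate currently requires square images).
imgs = np.random.rand(3, 64, 64).astype(np.float64)
theta = np.pi / 6  # rotation angle in radians, ccw

# High-level API: rotate an Image stack, masking edge artifacts with a 0.9 radius disc.
rotated = Image(imgs).rotate(theta, method="fastrotate", mask=0.9)

# Low-level API: rotate the raw array directly.
rotated_np = fastrotate(imgs, theta)

# Reuse precomputed shearing tables for repeated rotations by the same angle;
# `theta` must then be passed as None.
M = compute_fastrotate_interp_tables(theta, 64, 64)
rotated_np_again = fastrotate(imgs, None, M=M)
```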