From d8b8d42fea3ac6dbada3b201f5f338f3c4323b05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Thu, 21 Aug 2025 17:27:26 +0200 Subject: [PATCH] Add two potential fixes to conform --- FastSurferCNN/data_loader/conform.py | 133 ++++++++++++++++++--------- 1 file changed, 92 insertions(+), 41 deletions(-) diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index e49924f9..ebacfc6a 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -25,12 +25,12 @@ import nibabel as nib import numpy as np import numpy.typing as npt -from nibabel.freesurfer.mghformat import MGHHeader +from nibabel.freesurfer.mghformat import MGHHeader, MGHImage if TYPE_CHECKING: import torch -from FastSurferCNN.utils import ScalarType, logging, nibabelImage +from FastSurferCNN.utils import AffineMatrix4x4, ScalarType, logging, nibabelImage, nibabelHeader from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, StrictOrientationType, VoxSizeOption from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one_mm from FastSurferCNN.utils.arg_types import img_size as __img_size @@ -56,6 +56,8 @@ Modified by: David Kügler Date: May-12-2025 """ +FIX_MGH_AFFINE_CALCULATION = False +FIX_CENTER_NOT_CENTER = True LOGGER = logging.getLogger(__name__) @@ -344,26 +346,28 @@ def apply_orientation(arr: _TB | npt.ArrayLike, ornt: npt.NDArray[int]) -> _TB: nibabel.orientations.apply_orientation This function is an extension to `nibabel.orientations.apply_orientation`. """ - from nibabel.orientations import OrientationError - from nibabel.orientations import apply_orientation as _apply_orientation - from torch import is_tensor as _is_tensor - - if _is_tensor(arr): - ornt = np.asarray(ornt) - n = ornt.shape[0] - if arr.ndim < n: - raise OrientationError("Data array has fewer dimensions than orientation") - # apply ornt transformations - flip_dims = np.nonzero(ornt[:, 1] == -1)[0].tolist() - if len(flip_dims) > 0: - arr = arr.flip(flip_dims) - full_transpose = np.arange(arr.ndim) - # ornt indicates the transpose that has occurred - we reverse it - full_transpose[:n] = np.argsort(ornt[:, 0]) - t_arr = arr.permute(*full_transpose) - return t_arr - else: - return _apply_orientation(arr, ornt) + from nibabel.orientations import apply_orientation as _apply_orientation, OrientationError + + # only import torch, if it is likely we are dealing with a tensor + if hasattr(arr, "device"): + from torch import is_tensor as _is_tensor + + if _is_tensor(arr): + ornt = np.asarray(ornt) + n = ornt.shape[0] + if arr.ndim < n: + raise OrientationError("Data array has fewer dimensions than orientation") + # apply ornt transformations + flip_dims = np.nonzero(ornt[:, 1] == -1)[0].tolist() + if len(flip_dims) > 0: + arr = arr.flip(flip_dims) + full_transpose = np.arange(arr.ndim) + # ornt indicates the transpose that has occurred - we reverse it + full_transpose[:n] = np.argsort(ornt[:, 0]) + t_arr = arr.permute(*full_transpose) + return t_arr + + return _apply_orientation(arr, ornt) def map_image( @@ -407,7 +411,7 @@ def map_image( ras2ras = np.eye(4) # compute vox2vox from src to trg - vox2vox = np.linalg.inv(out_affine) @ ras2ras @ img.affine + vox2vox = np.linalg.inv(out_affine) @ ras2ras @ get_affine_from_any(img) # here we apply the inverse vox2vox (to pull back the src info to the target image) image_data = np.asarray(img.dataobj, dtype=dtype) @@ -447,11 +451,12 @@ def map_image( if np.allclose(vox2vox[:, 3], np.round(vox2vox[:, 3]), atol=1e-4): # reorder axes ornt = nib.orientations.io_orientation(vox2vox) + reordered = apply_orientation(image_data, ornt) + new_old_index = list(enumerate(map(int, ornt[:, 0]))) # if the direction is flipped (ornt[j, 1] == -1), offset has to start at the other end - offsets = [ornt[j, 1] * vox2vox[i, 3] + (ornt[j, 1] == -1) * img.shape[j] for i, j in new_old_index] + offsets = [-vox2vox[i, 3] + (ornt[j, 1] == -1) * (img.shape[j] - 1) for i, j in new_old_index] offsets = list(map(lambda x: int(x.astype(int)), offsets)) - reordered = apply_orientation(image_data, ornt) # pad=0 => pad with zeros return crop_transform(reordered, offsets=offsets, target_shape=out_shape, pad=0) @@ -628,6 +633,37 @@ def rescale( return data_new +def get_affine_from_any(image: nibabelImage | nibabelHeader) -> AffineMatrix4x4: + """ + Retrieve the affine matrix from an MGH header. + + This function also incorporates the `FIX_MGH_AFFINE_CALCULATION` hardcoded flag, which attempts to fix the nibabel + calculation of the affine matrix for MGH images (which incorrectly assumes Pxyz_c to be at the center of the image). + It is not, Pxyz_c is offset by half a voxel, see + https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems?action=AttachFile&do=get&target=fscoordinates.pdf . + + Parameters + ========== + image : nibabel.spatialimages.SpatialImage, nibabel.spatialimages.SpatialHeader + The image object or file image header object. + + Returns + ======= + AffineMatrix4x4 + The 4x4 affine transformation matrix for mapping voxel data to world coordinates. + """ + if FIX_MGH_AFFINE_CALCULATION and isinstance(image, (MGHImage, MGHHeader)): + mgh_header = image.header if isinstance(image, MGHImage) else image + # the function header.get_affine() is actually bugged, because it uses dims and not dims-1 for center :/ + MdcD = np.asarray(mgh_header["Mdc"]).T * mgh_header["delta"] + vol_center = MdcD.dot(np.asarray(mgh_header["dims"][:3]) - 1) / 2 + return nib.affines.from_matvec(MdcD, mgh_header["Pxyz_c"] - vol_center) + elif isinstance(image, nib.analyze.SpatialHeader): + return image.get_best_affine() + else: + return image.affine + + def conform( img: nibabelImage, order: int = 1, @@ -703,7 +739,9 @@ def conform( h1 = prepare_mgh_header(img, *vox_img, _orientation, vox_eps=vox_eps, rot_eps=rot_eps) # affine is the computed target affine for the output image - target_affine = h1.get_affine() + # BUGGED: target_affine = h1.get_affine() + target_affine = get_affine_from_any(h1) + if LOGGER.getEffectiveLevel() <= logging.DEBUG: with np.printoptions(precision=2, suppress=True): from re import sub @@ -751,7 +789,8 @@ def conform( # mapped data is still float here, clip to integers now if np.issubdtype(target_dtype, np.integer): mapped_data = np.rint(mapped_data) - new_img = nibabel.MGHImage(mapped_data.astype(target_dtype), target_affine, h1) + # using h1.get_affine() here to keep affine and header consistent within nibabel calculations + new_img = nibabel.MGHImage(mapped_data.astype(target_dtype), get_affine_from_any(h1), h1) # make sure we store uchar from nibabel.freesurfer import mghformat @@ -779,7 +818,7 @@ def prepare_mgh_header( rot_eps: float = 1e-6, ) -> MGHHeader: """ - Prepare the header with affine by target voxel size, target image size and criteria - initialized from img. + Prepare the header with affine by target voxel size, target image size, and criteria - initialized from img. This implicitly prepares the affine, which can be computed by `return_value.get_affine()`. @@ -792,7 +831,7 @@ def prepare_mgh_header( target_img_size : npt.NDArray[int], None, default=None The target image size, importantly still in native orientation (reordering after). orientation : "native", "soft-", "", default="native" - How the affine should look like. + What the affine should be oriented like. vox_eps : float, default=1e-4 The epsilon for the voxelsize check. rot_eps : float, default=1e-6 @@ -809,13 +848,14 @@ def prepare_mgh_header( source_img_shape = img.header.get_data_shape() source_vox_size = img.header.get_zooms() - source_mdc = img.affine[:3, :3] / np.linalg.norm(img.affine[:3, :3], axis=0, keepdims=True) + source_affine = get_affine_from_any(img) + source_mdc = source_affine[:3, :3] / np.linalg.norm(source_affine[:3, :3], axis=0, keepdims=True) # native if orientation == "native": re_order_axes = [0, 1, 2] mdc_affine = np.linalg.inv(source_mdc) else: - _ornt_transform, _ = orientation_to_ornts(img.affine, orientation[-3:]) + _ornt_transform, _ = orientation_to_ornts(source_affine, orientation[-3:]) re_order_axes = _ornt_transform[:, 0] if len(orientation) == 3: # lia, ras, etc # this is a 3x3 matrix @@ -837,11 +877,18 @@ def prepare_mgh_header( if _fov.min() == _fov.max(): # fov is not needed for MGHHeader.get_affine() h1["fov"] = _fov[0] - center = np.asarray(img.shape[:3], dtype=float) / 2.0 - h1["Pxyz_c"] = img.affine.dot(np.hstack((center, [1.0])))[:3] + center = (np.asarray(img.shape[:3], dtype=float) - (1 if FIX_MGH_AFFINE_CALCULATION else 0)) / 2.0 + if FIX_CENTER_NOT_CENTER: + # The center is not actually the center, but rather the position of the voxel at Ni/2 (counting at voxel 0) + # Therefore, the center changes, if we apply a vox2vox + # to get to the true center, move back half a voxel in all directions + true_center = center - 0.5 * np.ones((1, 3)) @ source_affine[:3, :3] + # new image center from true center go half a voxel in all direction of the new affine + center = 0.5 * np.ones((1, 3)) @ get_affine_from_any(h1)[:3, :3] + true_center + h1["Pxyz_c"] = source_affine.dot(np.hstack((center, [1.0])))[:3] # There is a special case here, where an interpolation is triggered, but it is not necessary, if the position of # the center could "fix this" condition: - vox2vox = np.linalg.inv(h1.get_affine()) @ img.affine + vox2vox = np.linalg.inv(get_affine_from_any(h1)) @ source_affine if does_vox2vox_rot_require_interpolation(vox2vox, vox_eps=vox_eps, rot_eps=rot_eps): # 1. has rotation, or vox-size resampling => requires resampling pass @@ -853,10 +900,10 @@ def prepare_mgh_header( # is it fixable? if not np.allclose(vec, np.round(vec), **tols) and np.allclose(vec * 2, np.round(vec * 2), **tols): new_center = (center + (1 - np.isclose(vec, np.round(vec), **tols)) / 2.0, [1.0]) - h1["Pxyz_c"] = img.affine.dot(np.hstack(new_center))[:3] + h1["Pxyz_c"] = source_affine.dot(np.hstack(new_center))[:3] # tr information is not copied when copying from non-mgh formats - if len(img.header.get('pixdim', [])) : + if len(img.header.get('pixdim', [])): h1['tr'] = img.header['pixdim'][4] * 1000 # The affine can be explicitly constructed by MGHHeader.get_affine() / h1.get_affine() @@ -989,15 +1036,17 @@ def is_conform( img_size_criteria = f"Dimensions {img_size}={'x'.join(map(str, _img_size[:3]))}" checks[img_size_criteria] = np.array_equal(np.asarray(img.shape[:3]), _img_size), img_size_text + img_affine = get_affine_from_any(img) + # check orientation LIA - affcode = "".join(nib.orientations.aff2axcodes(img.affine)) + affcode = "".join(nib.orientations.aff2axcodes(img_affine)) with np.printoptions(precision=2, suppress=True): - orientation_text = "affine=" + re.sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}" + orientation_text = "affine=" + re.sub("\\s+", " ", str(img_affine[:3, :3])) + f" => {affcode}" if orientation is None or orientation == "native": checks[f"Orientation {orientation}"] = "IGNORED", orientation_text else: is_soft = not orientation.startswith("soft") - is_correct_orientation = is_orientation(img.affine, orientation[-3:], is_soft, eps) + is_correct_orientation = is_orientation(img_affine, orientation[-3:], is_soft, eps) checks[f"Orientation {orientation.upper()}"] = is_correct_orientation, orientation_text # check dtype uchar @@ -1232,8 +1281,10 @@ def check_affine_in_nifti( # Exit otherwise vox_size_header = header.get_zooms() + img_affine = get_affine_from_any(img) + # voxel size in xyz direction from the affine - vox_size_affine = np.sqrt((img.affine[:3, :3] * img.affine[:3, :3]).sum(0)) + vox_size_affine = np.sqrt((img_affine[:3, :3] * img_affine[:3, :3]).sum(0)) if not np.allclose(vox_size_affine, vox_size_header, atol=1e-3): message = ( @@ -1241,7 +1292,7 @@ def check_affine_in_nifti( f"ERROR: Invalid Nifti-header! Affine matrix is inconsistent with " f"Voxel sizes. \nVoxel size (from header) vs. Voxel size in affine:\n" f"{tuple(vox_size_header[:3])}, {tuple(vox_size_affine)}\n" - f"Input Affine----------------\n{img.affine}\n" + f"Input Affine----------------\n{img_affine}\n" f"#############################################################" ) check = False