Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 92 additions & 41 deletions FastSurferCNN/data_loader/conform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
import nibabel as nib
import numpy as np
import numpy.typing as npt
from nibabel.freesurfer.mghformat import MGHHeader
from nibabel.freesurfer.mghformat import MGHHeader, MGHImage

if TYPE_CHECKING:
import torch

from FastSurferCNN.utils import ScalarType, logging, nibabelImage
from FastSurferCNN.utils import AffineMatrix4x4, ScalarType, logging, nibabelImage, nibabelHeader
from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, StrictOrientationType, VoxSizeOption
from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one_mm
from FastSurferCNN.utils.arg_types import img_size as __img_size
Expand All @@ -56,6 +56,8 @@
Modified by: David Kügler
Date: May-12-2025
"""
FIX_MGH_AFFINE_CALCULATION = False
FIX_CENTER_NOT_CENTER = True

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -344,26 +346,28 @@ def apply_orientation(arr: _TB | npt.ArrayLike, ornt: npt.NDArray[int]) -> _TB:
nibabel.orientations.apply_orientation
This function is an extension to `nibabel.orientations.apply_orientation`.
"""
from nibabel.orientations import OrientationError
from nibabel.orientations import apply_orientation as _apply_orientation
from torch import is_tensor as _is_tensor

if _is_tensor(arr):
ornt = np.asarray(ornt)
n = ornt.shape[0]
if arr.ndim < n:
raise OrientationError("Data array has fewer dimensions than orientation")
# apply ornt transformations
flip_dims = np.nonzero(ornt[:, 1] == -1)[0].tolist()
if len(flip_dims) > 0:
arr = arr.flip(flip_dims)
full_transpose = np.arange(arr.ndim)
# ornt indicates the transpose that has occurred - we reverse it
full_transpose[:n] = np.argsort(ornt[:, 0])
t_arr = arr.permute(*full_transpose)
return t_arr
else:
return _apply_orientation(arr, ornt)
from nibabel.orientations import apply_orientation as _apply_orientation, OrientationError

# only import torch, if it is likely we are dealing with a tensor
if hasattr(arr, "device"):
from torch import is_tensor as _is_tensor

if _is_tensor(arr):
ornt = np.asarray(ornt)
n = ornt.shape[0]
if arr.ndim < n:
raise OrientationError("Data array has fewer dimensions than orientation")
# apply ornt transformations
flip_dims = np.nonzero(ornt[:, 1] == -1)[0].tolist()
if len(flip_dims) > 0:
arr = arr.flip(flip_dims)
full_transpose = np.arange(arr.ndim)
# ornt indicates the transpose that has occurred - we reverse it
full_transpose[:n] = np.argsort(ornt[:, 0])
t_arr = arr.permute(*full_transpose)
return t_arr

return _apply_orientation(arr, ornt)


def map_image(
Expand Down Expand Up @@ -407,7 +411,7 @@ def map_image(
ras2ras = np.eye(4)

# compute vox2vox from src to trg
vox2vox = np.linalg.inv(out_affine) @ ras2ras @ img.affine
vox2vox = np.linalg.inv(out_affine) @ ras2ras @ get_affine_from_any(img)

# here we apply the inverse vox2vox (to pull back the src info to the target image)
image_data = np.asarray(img.dataobj, dtype=dtype)
Expand Down Expand Up @@ -447,11 +451,12 @@ def map_image(
if np.allclose(vox2vox[:, 3], np.round(vox2vox[:, 3]), atol=1e-4):
# reorder axes
ornt = nib.orientations.io_orientation(vox2vox)
reordered = apply_orientation(image_data, ornt)

new_old_index = list(enumerate(map(int, ornt[:, 0])))
# if the direction is flipped (ornt[j, 1] == -1), offset has to start at the other end
offsets = [ornt[j, 1] * vox2vox[i, 3] + (ornt[j, 1] == -1) * img.shape[j] for i, j in new_old_index]
offsets = [-vox2vox[i, 3] + (ornt[j, 1] == -1) * (img.shape[j] - 1) for i, j in new_old_index]
offsets = list(map(lambda x: int(x.astype(int)), offsets))
reordered = apply_orientation(image_data, ornt)
# pad=0 => pad with zeros
return crop_transform(reordered, offsets=offsets, target_shape=out_shape, pad=0)

Expand Down Expand Up @@ -628,6 +633,37 @@ def rescale(
return data_new


def get_affine_from_any(image: nibabelImage | nibabelHeader) -> AffineMatrix4x4:
"""
Retrieve the affine matrix from an MGH header.

This function also incorporates the `FIX_MGH_AFFINE_CALCULATION` hardcoded flag, which attempts to fix the nibabel
calculation of the affine matrix for MGH images (which incorrectly assumes Pxyz_c to be at the center of the image).
It is not, Pxyz_c is offset by half a voxel, see
https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems?action=AttachFile&do=get&target=fscoordinates.pdf .

Parameters
==========
image : nibabel.spatialimages.SpatialImage, nibabel.spatialimages.SpatialHeader
The image object or file image header object.

Returns
=======
AffineMatrix4x4
The 4x4 affine transformation matrix for mapping voxel data to world coordinates.
"""
if FIX_MGH_AFFINE_CALCULATION and isinstance(image, (MGHImage, MGHHeader)):
mgh_header = image.header if isinstance(image, MGHImage) else image
# the function header.get_affine() is actually bugged, because it uses dims and not dims-1 for center :/
MdcD = np.asarray(mgh_header["Mdc"]).T * mgh_header["delta"]
vol_center = MdcD.dot(np.asarray(mgh_header["dims"][:3]) - 1) / 2
return nib.affines.from_matvec(MdcD, mgh_header["Pxyz_c"] - vol_center)
elif isinstance(image, nib.analyze.SpatialHeader):
return image.get_best_affine()
else:
return image.affine


def conform(
img: nibabelImage,
order: int = 1,
Expand Down Expand Up @@ -703,7 +739,9 @@ def conform(
h1 = prepare_mgh_header(img, *vox_img, _orientation, vox_eps=vox_eps, rot_eps=rot_eps)

# affine is the computed target affine for the output image
target_affine = h1.get_affine()
# BUGGED: target_affine = h1.get_affine()
target_affine = get_affine_from_any(h1)

if LOGGER.getEffectiveLevel() <= logging.DEBUG:
with np.printoptions(precision=2, suppress=True):
from re import sub
Expand Down Expand Up @@ -751,7 +789,8 @@ def conform(
# mapped data is still float here, clip to integers now
if np.issubdtype(target_dtype, np.integer):
mapped_data = np.rint(mapped_data)
new_img = nibabel.MGHImage(mapped_data.astype(target_dtype), target_affine, h1)
# using h1.get_affine() here to keep affine and header consistent within nibabel calculations
new_img = nibabel.MGHImage(mapped_data.astype(target_dtype), get_affine_from_any(h1), h1)

# make sure we store uchar
from nibabel.freesurfer import mghformat
Expand Down Expand Up @@ -779,7 +818,7 @@ def prepare_mgh_header(
rot_eps: float = 1e-6,
) -> MGHHeader:
"""
Prepare the header with affine by target voxel size, target image size and criteria - initialized from img.
Prepare the header with affine by target voxel size, target image size, and criteria - initialized from img.

This implicitly prepares the affine, which can be computed by `return_value.get_affine()`.

Expand All @@ -792,7 +831,7 @@ def prepare_mgh_header(
target_img_size : npt.NDArray[int], None, default=None
The target image size, importantly still in native orientation (reordering after).
orientation : "native", "soft-<orientation>", "<orientation>", default="native"
How the affine should look like.
What the affine should be oriented like.
vox_eps : float, default=1e-4
The epsilon for the voxelsize check.
rot_eps : float, default=1e-6
Expand All @@ -809,13 +848,14 @@ def prepare_mgh_header(
source_img_shape = img.header.get_data_shape()
source_vox_size = img.header.get_zooms()

source_mdc = img.affine[:3, :3] / np.linalg.norm(img.affine[:3, :3], axis=0, keepdims=True)
source_affine = get_affine_from_any(img)
source_mdc = source_affine[:3, :3] / np.linalg.norm(source_affine[:3, :3], axis=0, keepdims=True)
# native
if orientation == "native":
re_order_axes = [0, 1, 2]
mdc_affine = np.linalg.inv(source_mdc)
else:
_ornt_transform, _ = orientation_to_ornts(img.affine, orientation[-3:])
_ornt_transform, _ = orientation_to_ornts(source_affine, orientation[-3:])
re_order_axes = _ornt_transform[:, 0]
if len(orientation) == 3: # lia, ras, etc
# this is a 3x3 matrix
Expand All @@ -837,11 +877,18 @@ def prepare_mgh_header(
if _fov.min() == _fov.max():
# fov is not needed for MGHHeader.get_affine()
h1["fov"] = _fov[0]
center = np.asarray(img.shape[:3], dtype=float) / 2.0
h1["Pxyz_c"] = img.affine.dot(np.hstack((center, [1.0])))[:3]
center = (np.asarray(img.shape[:3], dtype=float) - (1 if FIX_MGH_AFFINE_CALCULATION else 0)) / 2.0
if FIX_CENTER_NOT_CENTER:
# The center is not actually the center, but rather the position of the voxel at Ni/2 (counting at voxel 0)
# Therefore, the center changes, if we apply a vox2vox
# to get to the true center, move back half a voxel in all directions
true_center = center - 0.5 * np.ones((1, 3)) @ source_affine[:3, :3]
# new image center from true center go half a voxel in all direction of the new affine
center = 0.5 * np.ones((1, 3)) @ get_affine_from_any(h1)[:3, :3] + true_center
h1["Pxyz_c"] = source_affine.dot(np.hstack((center, [1.0])))[:3]
# There is a special case here, where an interpolation is triggered, but it is not necessary, if the position of
# the center could "fix this" condition:
vox2vox = np.linalg.inv(h1.get_affine()) @ img.affine
vox2vox = np.linalg.inv(get_affine_from_any(h1)) @ source_affine
if does_vox2vox_rot_require_interpolation(vox2vox, vox_eps=vox_eps, rot_eps=rot_eps):
# 1. has rotation, or vox-size resampling => requires resampling
pass
Expand All @@ -853,10 +900,10 @@ def prepare_mgh_header(
# is it fixable?
if not np.allclose(vec, np.round(vec), **tols) and np.allclose(vec * 2, np.round(vec * 2), **tols):
new_center = (center + (1 - np.isclose(vec, np.round(vec), **tols)) / 2.0, [1.0])
h1["Pxyz_c"] = img.affine.dot(np.hstack(new_center))[:3]
h1["Pxyz_c"] = source_affine.dot(np.hstack(new_center))[:3]

# tr information is not copied when copying from non-mgh formats
if len(img.header.get('pixdim', [])) :
if len(img.header.get('pixdim', [])):
h1['tr'] = img.header['pixdim'][4] * 1000

# The affine can be explicitly constructed by MGHHeader.get_affine() / h1.get_affine()
Expand Down Expand Up @@ -989,15 +1036,17 @@ def is_conform(
img_size_criteria = f"Dimensions {img_size}={'x'.join(map(str, _img_size[:3]))}"
checks[img_size_criteria] = np.array_equal(np.asarray(img.shape[:3]), _img_size), img_size_text

img_affine = get_affine_from_any(img)

# check orientation LIA
affcode = "".join(nib.orientations.aff2axcodes(img.affine))
affcode = "".join(nib.orientations.aff2axcodes(img_affine))
with np.printoptions(precision=2, suppress=True):
orientation_text = "affine=" + re.sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}"
orientation_text = "affine=" + re.sub("\\s+", " ", str(img_affine[:3, :3])) + f" => {affcode}"
if orientation is None or orientation == "native":
checks[f"Orientation {orientation}"] = "IGNORED", orientation_text
else:
is_soft = not orientation.startswith("soft")
is_correct_orientation = is_orientation(img.affine, orientation[-3:], is_soft, eps)
is_correct_orientation = is_orientation(img_affine, orientation[-3:], is_soft, eps)
checks[f"Orientation {orientation.upper()}"] = is_correct_orientation, orientation_text

# check dtype uchar
Expand Down Expand Up @@ -1232,16 +1281,18 @@ def check_affine_in_nifti(
# Exit otherwise
vox_size_header = header.get_zooms()

img_affine = get_affine_from_any(img)

# voxel size in xyz direction from the affine
vox_size_affine = np.sqrt((img.affine[:3, :3] * img.affine[:3, :3]).sum(0))
vox_size_affine = np.sqrt((img_affine[:3, :3] * img_affine[:3, :3]).sum(0))

if not np.allclose(vox_size_affine, vox_size_header, atol=1e-3):
message = (
f"#############################################################\n"
f"ERROR: Invalid Nifti-header! Affine matrix is inconsistent with "
f"Voxel sizes. \nVoxel size (from header) vs. Voxel size in affine:\n"
f"{tuple(vox_size_header[:3])}, {tuple(vox_size_affine)}\n"
f"Input Affine----------------\n{img.affine}\n"
f"Input Affine----------------\n{img_affine}\n"
f"#############################################################"
)
check = False
Expand Down
Loading