diff --git a/data/preprocess.py b/data/preprocess.py
index e1d2bf1..0146480 100644
--- a/data/preprocess.py
+++ b/data/preprocess.py
@@ -1,10 +1,11 @@
-import SimpleITK as sitk
-import matplotlib.pyplot as plt
-from util.util import mkdir, mkdirs
-import pandas as pd
-import numpy as np
-import torch
-import os
+import SimpleITK as sitk
+import matplotlib.pyplot as plt
+from util.util import mkdir, mkdirs
+import pandas as pd
+import numpy as np
+import torch
+import torch.nn.functional as F
+import os
 
 
 def save_nifti(volume, path):
@@ -110,27 +111,25 @@ def extract_volume_from_dicom(case_path, format, _min=0, _max=2048, clamp_en=Tru
     return img, interp_img, img_nda, interp_img_nda
 
 
-def extract_patches_with_overlap(volume, patch_size=64, overlap_ratio=0.125):
-    patches = []
-    depth, height, width = volume.shape
-    patch_depth = patch_height = patch_width = patch_size
-    overlap_depth = int(patch_depth * overlap_ratio)
-    overlap_height = int(patch_height * overlap_ratio)
-    overlap_width = int(patch_width * overlap_ratio)
-
-    d_step = patch_depth - overlap_depth
-    h_step = patch_height - overlap_height
-    w_step = patch_width - overlap_width
-
-    for d in range(0, depth - patch_depth + 1, d_step):
-        for h in range(0, height - patch_height + 1, h_step):
-            for w in range(0, width - patch_width + 1, w_step):
-                patch_id = dict()
-                patch = volume[d:d + patch_depth, h:h + patch_height, w:w + patch_width]
-                patch_id['patch'] = patch
-                patches.append(patch_id)
-
-    return patches
+def extract_patches_with_overlap(volume, patch_size=64, overlap_ratio=0.125):
+    if isinstance(volume, np.ndarray):
+        volume_t = torch.from_numpy(volume)
+    else:
+        volume_t = torch.as_tensor(volume)
+
+    if volume_t.ndim != 3:
+        raise ValueError(f"Expected a 3D volume, got tensor with shape {volume_t.shape}")
+
+    overlap = int(patch_size * overlap_ratio)
+    stride = patch_size - overlap
+    if stride <= 0:
+        raise ValueError("Overlap ratio is too large, resulting in a non-positive stride.")
+
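+    # Assumption: extract_patches_3d is the unfold-based helper defined further
+    # down in this module; it takes an (N, C, D, H, W) tensor -- hence the two
+    # unsqueeze calls -- and returns (num_patches, C, patch_size, patch_size, patch_size).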
+    patches = extract_patches_3d(volume_t.unsqueeze(0).unsqueeze(0),
+                                 kernel_size=patch_size,
+                                 stride=stride)
+
+    return patches[:, 0, :, :, :].contiguous()
 
 
 def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1, round_down=True):
     if round_down:
@@ -227,78 +226,104 @@ def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilati
     return x
 
 
-def reconstruct_volume(opt, patches_3d_list, output_shape):
-    overlap = int(opt.patch_size * opt.overlap_ratio)
-
-    patches_3d = torch.cat(patches_3d_list)
-    for i in range(overlap):
-        patches_3d[:, :, i, :, :] *= (i + 1) / (overlap + 1)
-        patches_3d[:, :, :, i, :] *= (i + 1) / (overlap + 1)
-        patches_3d[:, :, :, :, i] *= (i + 1) / (overlap + 1)
-        patches_3d[:, :, -(i + 1), :, :] *= (i + 1) / (overlap + 1)
-        patches_3d[:, :, :, -(i + 1), :] *= (i + 1) / (overlap + 1)
-        patches_3d[:, :, :, :, -(i + 1)] *= (i + 1) / (overlap + 1)
-
-    output_vol = combine_patches_3d(patches_3d, opt.patch_size, output_shape,
-                                    stride=int(opt.patch_size * (1 - opt.overlap_ratio)))
-
-    ones = torch.ones_like(patches_3d).cpu()
-    for i in range(overlap):
-        ones[:, :, i, :, :] *= (i + 1) / (overlap + 1)
-        ones[:, :, :, i, :] *= (i + 1) / (overlap + 1)
-        ones[:, :, :, :, i] *= (i + 1) / (overlap + 1)
-        ones[:, :, -(i + 1), :, :] *= (i + 1) / (overlap + 1)
-        ones[:, :, :, -(i + 1), :] *= (i + 1) / (overlap + 1)
-        ones[:, :, :, :, -(i + 1)] *= (i + 1) / (overlap + 1)
-    ones_vol = combine_patches_3d(ones, opt.patch_size, output_shape, stride=int(opt.patch_size * (1 - opt.overlap_ratio)))
-    recon_vol = output_vol.cpu() / ones_vol
-
-    return recon_vol
+def _create_patch_weights(patch_size, overlap, dtype, device):
+    weights = torch.ones((patch_size, patch_size, patch_size), dtype=dtype, device=device)
+
+    if overlap <= 0:
+        return weights.unsqueeze(0).unsqueeze(0)
+
+    ramp = torch.arange(1, overlap + 1, dtype=dtype, device=device) / (overlap + 1)
+    inv_ramp = ramp.flip(0)
+
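+    # Linear cross-fade over the overlap band: on each axis the leading face
+    # ramps up 1/(overlap+1) ... overlap/(overlap+1) and the trailing face
+    # mirrors it, so neighbouring patches blend smoothly once the accumulated
+    # weights are divided out in reconstruct_volume below.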
+    weights[:overlap] *= ramp.view(-1, 1, 1)
+    weights[-overlap:] *= inv_ramp.view(-1, 1, 1)
+    weights[:, :overlap] *= ramp.view(1, -1, 1)
+    weights[:, -overlap:] *= inv_ramp.view(1, -1, 1)
+    weights[:, :, :overlap] *= ramp.view(1, 1, -1)
+    weights[:, :, -overlap:] *= inv_ramp.view(1, 1, -1)
+
+    return weights.unsqueeze(0).unsqueeze(0)
+
+
+def reconstruct_volume(opt, patches_3d_list, output_shape):
+    overlap = int(opt.patch_size * opt.overlap_ratio)
+    stride = int(opt.patch_size * (1 - opt.overlap_ratio))
+    if stride <= 0:
+        raise ValueError("Overlap ratio is too large, resulting in a non-positive stride.")
+
+    if len(patches_3d_list) == 0:
+        raise ValueError("The patch list is empty; unable to reconstruct volume.")
+
+    patches_3d = torch.cat([patch.detach() for patch in patches_3d_list])
+    weight_patch = _create_patch_weights(opt.patch_size, overlap, patches_3d.dtype, patches_3d.device)
+
+    weighted_patches = patches_3d * weight_patch
+    output_vol = combine_patches_3d(weighted_patches, opt.patch_size, output_shape, stride=stride).cpu()
+
+    weight_tiles = weight_patch.repeat(patches_3d.shape[0], 1, 1, 1, 1)
+    ones_vol = combine_patches_3d(weight_tiles, opt.patch_size, output_shape, stride=stride).cpu()
+
+    if ones_vol.is_floating_point():
+        eps = torch.finfo(ones_vol.dtype).eps
+        ones_vol = torch.clamp(ones_vol, min=eps)
+    else:
+        ones_vol = ones_vol.masked_fill_(ones_vol == 0, 1)
+
+    recon_vol = output_vol / ones_vol
+
+    return recon_vol
 
 
 def pad_volume(vol, dim):
     vol = change_dim(vol, dim)
-    slices_num = vol.shape[0]
-    if slices_num > dim:
-        start = int((slices_num - dim)/2)
-        end = start + dim
-        res_vol = vol[start:end, :, :]
-    else:
-        pad_vol = torch.zeros(dim, dim, dim) - 1
-        start = int((dim - slices_num)/2)
-        end = start + slices_num
-        pad_vol[start:end, :, :] = vol
-        res_vol = pad_vol
-
-    return res_vol
+    slices_num = vol.shape[0]
+    if slices_num > dim:
+        start = int((slices_num - dim) / 2)
+        end = start + dim
+        return vol[start:end, :, :]
+
+    pad_vol = torch.full((dim, dim, dim), -1, dtype=vol.dtype, device=vol.device)
+    start = int((dim - slices_num) / 2)
+    end = start + slices_num
+    pad_vol[start:end, :, :] = vol
+
+    return pad_vol
 
 
 def change_dim(image, target_dim):
-    if len(image.shape) == 2:
-        target_size = (target_dim, target_dim)
-    elif len(image.shape) == 3:
-        target_size = (image.shape[0], target_dim, target_dim)
-    else:
-        target_size = None
-        print(f'number of image dimensions is {len(image.shape)} != 2 or 3')
-
-    current_size = image.shape
-
-    # Initialize the output array with zeros
-    if 'torch' in str(image.dtype):
-        output_image = torch.zeros(target_size, dtype=image.dtype) - 1
-    else:
-        output_image = np.zeros(target_size, dtype=image.dtype) - 1
-
-    # Calculate the amount to pad or crop for each dimension
-    pad_crop = [((t - c) // 2, (t - c + 1) // 2) for t, c in zip(target_size, current_size)]
-
-    input_slices = tuple(slice(max(0, -p[0]), c - max(0, -p[1])) for p, c in zip(pad_crop, current_size))
-    output_slices = tuple(slice(max(0, p[0]), t - max(0, p[1])) for p, t in zip(pad_crop, target_size))
-
-    # Copy the data from the input to the output
-    output_image[output_slices] = image[input_slices]
-
-    return output_image
+    if len(image.shape) == 2:
+        target_size = (target_dim, target_dim)
+    elif len(image.shape) == 3:
+        target_size = (image.shape[0], target_dim, target_dim)
+    else:
+        raise ValueError(f'number of image dimensions is {len(image.shape)} != 2 or 3')
+
+    current_size = image.shape
+    slices = []
+    pads = []
+
+    for current, target in zip(current_size, target_size):
+        if current <= target:
+            slices.append(slice(None))
+            total = target - current
+            pad_before = total // 2
+            pad_after = total - pad_before
+            pads.append((pad_before, pad_after))
+        else:
+            start = (current - target) // 2
+            slices.append(slice(start, start + target))
+            pads.append((0, 0))
+
+    cropped = image[tuple(slices)]
+
+    if isinstance(image, torch.Tensor):
+        if all(before == 0 and after == 0 for before, after in pads):
+            return cropped
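+        # F.pad expects pad widths for the *last* dimension first, so the
+        # per-axis (before, after) pairs are flattened in reverse order.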
+        pad_values = []
+        for before, after in reversed(pads):
+            pad_values.extend([before, after])
+        return F.pad(cropped, tuple(pad_values), mode='constant', value=-1)
+    else:
+        return np.pad(cropped, pads, mode='constant', constant_values=-1)
 
 
 def calc_dims(case, opt):
     new_dim = np.zeros(3, dtype=int)
@@ -353,38 +378,38 @@ def simple_train_preprocess(opt):
         interp_vol = torch.from_numpy(interp_vol_nda).to(torch.float32).cpu().detach()
         interp_vol = interp_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
 
-        interp_patches = extract_patches_with_overlap(interp_vol, opt.patch_size, opt.overlap_ratio)
-        if 'coronal' in opt.planes:
-            cor_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_cor_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
-            cor_atme_vol = cor_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
-            cor_atme_patches = extract_patches_with_overlap(cor_atme_vol, opt.patch_size, opt.overlap_ratio)
-            assert (len(interp_patches) == len(cor_atme_patches))
-        if 'axial' in opt.planes:
-            ax_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_ax_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
-            ax_atme_vol = ax_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
-            ax_atme_patches = extract_patches_with_overlap(ax_atme_vol, opt.patch_size, opt.overlap_ratio)
-            assert (len(interp_patches) == len(ax_atme_patches))
-        if 'sagittal' in opt.planes:
-            sag_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_sag_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
-            sag_atme_vol = sag_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
-            sag_atme_patches = extract_patches_with_overlap(sag_atme_vol, opt.patch_size, opt.overlap_ratio)
-            assert (len(interp_patches) == len(sag_atme_patches))
-
-
-        for i in range(len(interp_patches)):
-            interp_patch = interp_patches[i]['patch'].unsqueeze(0).clone()
-            data = {'interp_patch': interp_patch}
-            if 'coronal' in opt.planes:
-                cor_atme_patch = cor_atme_patches[i]['patch'].unsqueeze(0).clone()
-                data['cor_atme_patch'] = cor_atme_patch
-            if 'axial' in opt.planes:
-                ax_atme_patch = ax_atme_patches[i]['patch'].unsqueeze(0).clone()
-                data['ax_atme_patch'] = ax_atme_patch
-            if 'sagittal' in opt.planes:
-                sag_atme_patch = sag_atme_patches[i]['patch'].unsqueeze(0).clone()
-                data['sag_atme_patch'] = sag_atme_patch
-            torch.save(data, os.path.join(save_train_dir, f'data_{save_idx}.pt'))
-            save_idx += 1
+        interp_patches = extract_patches_with_overlap(interp_vol, opt.patch_size, opt.overlap_ratio)
+        if 'coronal' in opt.planes:
+            cor_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_cor_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
+            cor_atme_vol = cor_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
+            cor_atme_patches = extract_patches_with_overlap(cor_atme_vol, opt.patch_size, opt.overlap_ratio)
+            assert (interp_patches.shape[0] == cor_atme_patches.shape[0])
+        if 'axial' in opt.planes:
+            ax_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_ax_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
+            ax_atme_vol = ax_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
+            ax_atme_patches = extract_patches_with_overlap(ax_atme_vol, opt.patch_size, opt.overlap_ratio)
+            assert (interp_patches.shape[0] == ax_atme_patches.shape[0])
+        if 'sagittal' in opt.planes:
+            sag_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_sag_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
+            sag_atme_vol = sag_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
+            sag_atme_patches = extract_patches_with_overlap(sag_atme_vol, opt.patch_size, opt.overlap_ratio)
+            assert (interp_patches.shape[0] == sag_atme_patches.shape[0])
+
+
+        for i in range(interp_patches.shape[0]):
+            interp_patch = interp_patches[i].unsqueeze(0).clone()
+            data = {'interp_patch': interp_patch}
+            if 'coronal' in opt.planes:
+                cor_atme_patch = cor_atme_patches[i].unsqueeze(0).clone()
+                data['cor_atme_patch'] = cor_atme_patch
+            if 'axial' in opt.planes:
+                ax_atme_patch = ax_atme_patches[i].unsqueeze(0).clone()
+                data['ax_atme_patch'] = ax_atme_patch
+            if 'sagittal' in opt.planes:
+                sag_atme_patch = sag_atme_patches[i].unsqueeze(0).clone()
+                data['sag_atme_patch'] = sag_atme_patch
+            torch.save(data, os.path.join(save_train_dir, f'data_{save_idx}.pt'))
+            save_idx += 1
 
 
 def read_MRI_case(case_path, format):
     if format == 'dicom':