data/preprocess.py (277 changes: 151 additions & 126 deletions)
@@ -1,10 +1,11 @@
import SimpleITK as sitk
import matplotlib.pyplot as plt
from util.util import mkdir, mkdirs
import pandas as pd
import numpy as np
import torch
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
from util.util import mkdir, mkdirs
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import os


def save_nifti(volume, path):
@@ -110,27 +111,25 @@ def extract_volume_from_dicom(case_path, format, _min=0, _max=2048, clamp_en=Tru

return img, interp_img, img_nda, interp_img_nda

def extract_patches_with_overlap(volume, patch_size=64, overlap_ratio=0.125):
patches = []
depth, height, width = volume.shape
patch_depth = patch_height = patch_width = patch_size
overlap_depth = int(patch_depth * overlap_ratio)
overlap_height = int(patch_height * overlap_ratio)
overlap_width = int(patch_width * overlap_ratio)

d_step = patch_depth - overlap_depth
h_step = patch_height - overlap_height
w_step = patch_width - overlap_width

for d in range(0, depth - patch_depth + 1, d_step):
for h in range(0, height - patch_height + 1, h_step):
for w in range(0, width - patch_width + 1, w_step):
patch_id = dict()
patch = volume[d:d + patch_depth, h:h + patch_height, w:w + patch_width]
patch_id['patch'] = patch
patches.append(patch_id)

return patches
def extract_patches_with_overlap(volume, patch_size=64, overlap_ratio=0.125):
    # Accept either a numpy array or a tensor; avoid a copy where possible.
    if isinstance(volume, np.ndarray):
        volume_t = torch.from_numpy(volume)
    else:
        volume_t = torch.as_tensor(volume)

if volume_t.ndim != 3:
raise ValueError(f"Expected a 3D volume, got tensor with shape {volume_t.shape}")

overlap = int(patch_size * overlap_ratio)
stride = patch_size - overlap
if stride <= 0:
raise ValueError("Overlap ratio is too large, resulting in a non-positive stride.")

    # extract_patches_3d is fed a (B, C, D, H, W) input, so add singleton
    # batch and channel dimensions before unfolding.
    patches = extract_patches_3d(volume_t.unsqueeze(0).unsqueeze(0),
                                 kernel_size=patch_size,
                                 stride=stride)

    # Drop the channel dimension -> (num_patches, patch_size, patch_size, patch_size).
    return patches[:, 0, :, :, :].contiguous()
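
# Usage sketch (illustrative; assumes extract_patches_3d, defined below,
# returns a (num_patches, channels, k, k, k) tensor with no padding):
#
#     vol = torch.rand(128, 128, 128)
#     patches = extract_patches_with_overlap(vol, patch_size=64, overlap_ratio=0.125)
#     # overlap = 8, stride = 56 -> 2 start positions per axis,
#     # so patches.shape == (8, 64, 64, 64)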

def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1, round_down=True):
if round_down:
@@ -227,78 +226,104 @@ def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilati

return x

def reconstruct_volume(opt, patches_3d_list, output_shape):
overlap = int(opt.patch_size * opt.overlap_ratio)

patches_3d = torch.cat(patches_3d_list)
for i in range(overlap):
patches_3d[:, :, i, :, :] *= (i + 1) / (overlap + 1)
patches_3d[:, :, :, i, :] *= (i + 1) / (overlap + 1)
patches_3d[:, :, :, :, i] *= (i + 1) / (overlap + 1)
patches_3d[:, :, -(i + 1), :, :] *= (i + 1) / (overlap + 1)
patches_3d[:, :, :, -(i + 1), :] *= (i + 1) / (overlap + 1)
patches_3d[:, :, :, :, -(i + 1)] *= (i + 1) / (overlap + 1)

output_vol = combine_patches_3d(patches_3d, opt.patch_size, output_shape,
stride=int(opt.patch_size * (1 - opt.overlap_ratio)))

ones = torch.ones_like(patches_3d).cpu()
for i in range(overlap):
ones[:, :, i, :, :] *= (i + 1) / (overlap + 1)
ones[:, :, :, i, :] *= (i + 1) / (overlap + 1)
ones[:, :, :, :, i] *= (i + 1) / (overlap + 1)
ones[:, :, -(i + 1), :, :] *= (i + 1) / (overlap + 1)
ones[:, :, :, -(i + 1), :] *= (i + 1) / (overlap + 1)
ones[:, :, :, :, -(i + 1)] *= (i + 1) / (overlap + 1)
ones_vol = combine_patches_3d(ones, opt.patch_size, output_shape, stride=int(opt.patch_size * (1 - opt.overlap_ratio)))
recon_vol = output_vol.cpu() / ones_vol

return recon_vol
def _create_patch_weights(patch_size, overlap, dtype, device):
    # Build per-voxel blending weights: a linear ramp over the first and last
    # `overlap` voxels of every axis cross-fades overlapping patches instead
    # of leaving hard seams at patch borders.
    weights = torch.ones((patch_size, patch_size, patch_size), dtype=dtype, device=device)

if overlap <= 0:
return weights.unsqueeze(0).unsqueeze(0)

ramp = torch.arange(1, overlap + 1, dtype=dtype, device=device) / (overlap + 1)
inv_ramp = ramp.flip(0)

weights[:overlap] *= ramp.view(-1, 1, 1)
weights[-overlap:] *= inv_ramp.view(-1, 1, 1)
weights[:, :overlap] *= ramp.view(1, -1, 1)
weights[:, -overlap:] *= inv_ramp.view(1, -1, 1)
weights[:, :, :overlap] *= ramp.view(1, 1, -1)
weights[:, :, -overlap:] *= inv_ramp.view(1, 1, -1)

return weights.unsqueeze(0).unsqueeze(0)
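
# Worked example: for patch_size=6 and overlap=2 the per-axis weight profile
# is [1/3, 2/3, 1, 1, 2/3, 1/3], and a voxel's weight is the product of its
# three axis profiles:
#
#     w = _create_patch_weights(6, 2, torch.float32, torch.device('cpu'))
#     # w[0, 0, :, 2, 2] == tensor([0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333])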


def reconstruct_volume(opt, patches_3d_list, output_shape):
    overlap = int(opt.patch_size * opt.overlap_ratio)
    # Derive the stride from the rounded overlap so it always matches the one
    # used by extract_patches_with_overlap; int(p * (1 - r)) can round
    # differently when p * r is not an integer.
    stride = opt.patch_size - overlap
    if stride <= 0:
        raise ValueError("Overlap ratio is too large, resulting in a non-positive stride.")

if len(patches_3d_list) == 0:
raise ValueError("The patch list is empty; unable to reconstruct volume.")

    # Concatenate all patch batches -> (N, C, p, p, p), then weight each patch.
    patches_3d = torch.cat([patch.detach() for patch in patches_3d_list])
    weight_patch = _create_patch_weights(opt.patch_size, overlap, patches_3d.dtype, patches_3d.device)

weighted_patches = patches_3d * weight_patch
output_vol = combine_patches_3d(weighted_patches, opt.patch_size, output_shape, stride=stride).cpu()

weight_tiles = weight_patch.repeat(patches_3d.shape[0], 1, 1, 1, 1)
ones_vol = combine_patches_3d(weight_tiles, opt.patch_size, output_shape, stride=stride).cpu()

    # Normalize by the accumulated blending weights, guarding against
    # division by zero in regions that no patch covers.
    if ones_vol.is_floating_point():
        eps = torch.finfo(ones_vol.dtype).eps
        ones_vol = torch.clamp(ones_vol, min=eps)
    else:
        ones_vol = ones_vol.masked_fill_(ones_vol == 0, 1)

recon_vol = output_vol / ones_vol

return recon_vol
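
# Round-trip sketch (illustrative; SimpleNamespace stands in for the real
# options object, and it assumes combine_patches_3d takes (N, C, k, k, k)
# patches with a (B, C, D, H, W) output_shape, and that the volume is exactly
# covered by the stride):
#
#     from types import SimpleNamespace
#     opt = SimpleNamespace(patch_size=64, overlap_ratio=0.125)  # stride = 56
#     vol = torch.rand(120, 120, 120)
#     patches = extract_patches_with_overlap(vol, opt.patch_size, opt.overlap_ratio)
#     recon = reconstruct_volume(opt, [patches.unsqueeze(1)], (1, 1, 120, 120, 120))
#     # The weighting and the normalization cancel, so recon matches vol up
#     # to floating-point error wherever the patches cover the volume.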

def pad_volume(vol, dim):
vol = change_dim(vol, dim)

slices_num = vol.shape[0]
if slices_num > dim:
start = int((slices_num - dim)/2)
end = start + dim
res_vol = vol[start:end, :, :]
else:
pad_vol = torch.zeros(dim, dim, dim) - 1
start = int((dim - slices_num)/2)
end = start + slices_num
pad_vol[start:end, :, :] = vol
res_vol = pad_vol

return res_vol
slices_num = vol.shape[0]
if slices_num > dim:
start = int((slices_num - dim) / 2)
end = start + dim
return vol[start:end, :, :]

pad_vol = torch.full((dim, dim, dim), -1, dtype=vol.dtype, device=vol.device)
start = int((dim - slices_num) / 2)
end = start + slices_num
pad_vol[start:end, :, :] = vol

return pad_vol
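
# Sketch: center-crop or pad a volume to a cube, e.g.
#
#     vol = torch.rand(100, 128, 128)
#     cube = pad_volume(vol, 128)
#     # change_dim leaves H/W untouched here; the depth axis is padded from
#     # 100 to 128 with the -1 background value -> cube.shape == (128, 128, 128)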

def change_dim(image, target_dim):
if len(image.shape) == 2:
target_size = (target_dim, target_dim)
elif len(image.shape) == 3:
target_size = (image.shape[0], target_dim, target_dim)
else:
target_size = None
print(f'number of image dimensions is {len(image.shape)} != 2 or 3')

current_size = image.shape

# Initialize the output array with zeros
if 'torch' in str(image.dtype):
output_image = torch.zeros(target_size, dtype=image.dtype) - 1
else:
output_image = np.zeros(target_size, dtype=image.dtype) - 1

# Calculate the amount to pad or crop for each dimension
pad_crop = [((t - c) // 2, (t - c + 1) // 2) for t, c in zip(target_size, current_size)]

input_slices = tuple(slice(max(0, -p[0]), c - max(0, -p[1])) for p, c in zip(pad_crop, current_size))
output_slices = tuple(slice(max(0, p[0]), t - max(0, p[1])) for p, t in zip(pad_crop, target_size))

# Copy the data from the input to the output
output_image[output_slices] = image[input_slices]

return output_image
if len(image.shape) == 2:
target_size = (target_dim, target_dim)
elif len(image.shape) == 3:
target_size = (image.shape[0], target_dim, target_dim)
else:
raise ValueError(f'number of image dimensions is {len(image.shape)} != 2 or 3')

current_size = image.shape
slices = []
pads = []

for current, target in zip(current_size, target_size):
if current <= target:
slices.append(slice(None))
total = target - current
pad_before = total // 2
pad_after = total - pad_before
pads.append((pad_before, pad_after))
else:
start = (current - target) // 2
slices.append(slice(start, start + target))
pads.append((0, 0))

cropped = image[tuple(slices)]

if isinstance(image, torch.Tensor):
if all(before == 0 and after == 0 for before, after in pads):
return cropped
pad_values = []
for before, after in reversed(pads):
pad_values.extend([before, after])
return F.pad(cropped, tuple(pad_values), mode='constant', value=-1)
else:
return np.pad(cropped, pads, mode='constant', constant_values=-1)
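
# Sketch of the crop/pad behavior on a 2D image:
#
#     img = np.zeros((300, 200), dtype=np.float32)
#     out = change_dim(img, 256)
#     # rows are center-cropped 300 -> 256, columns are padded 200 -> 256
#     # with the constant -1; out.shape == (256, 256)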

def calc_dims(case, opt):
new_dim = np.zeros(3, dtype=int)
@@ -353,38 +378,38 @@ def simple_train_preprocess(opt):
interp_vol = torch.from_numpy(interp_vol_nda).to(torch.float32).cpu().detach()
interp_vol = interp_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]

interp_patches = extract_patches_with_overlap(interp_vol, opt.patch_size, opt.overlap_ratio)
if 'coronal' in opt.planes:
cor_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_cor_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
cor_atme_vol = cor_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
cor_atme_patches = extract_patches_with_overlap(cor_atme_vol, opt.patch_size, opt.overlap_ratio)
assert (len(interp_patches) == len(cor_atme_patches))
if 'axial' in opt.planes:
ax_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_ax_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
ax_atme_vol = ax_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
ax_atme_patches = extract_patches_with_overlap(ax_atme_vol, opt.patch_size, opt.overlap_ratio)
assert (len(interp_patches) == len(ax_atme_patches))
if 'sagittal' in opt.planes:
sag_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_sag_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
sag_atme_vol = sag_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
sag_atme_patches = extract_patches_with_overlap(sag_atme_vol, opt.patch_size, opt.overlap_ratio)
assert (len(interp_patches) == len(sag_atme_patches))
for i in range(len(interp_patches)):
interp_patch = interp_patches[i]['patch'].unsqueeze(0).clone()
data = {'interp_patch': interp_patch}
if 'coronal' in opt.planes:
cor_atme_patch = cor_atme_patches[i]['patch'].unsqueeze(0).clone()
data['cor_atme_patch'] = cor_atme_patch
if 'axial' in opt.planes:
ax_atme_patch = ax_atme_patches[i]['patch'].unsqueeze(0).clone()
data['ax_atme_patch'] = ax_atme_patch
if 'sagittal' in opt.planes:
sag_atme_patch = sag_atme_patches[i]['patch'].unsqueeze(0).clone()
data['sag_atme_patch'] = sag_atme_patch
torch.save(data, os.path.join(save_train_dir, f'data_{save_idx}.pt'))
save_idx += 1
interp_patches = extract_patches_with_overlap(interp_vol, opt.patch_size, opt.overlap_ratio)
if 'coronal' in opt.planes:
cor_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_cor_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
cor_atme_vol = cor_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
cor_atme_patches = extract_patches_with_overlap(cor_atme_vol, opt.patch_size, opt.overlap_ratio)
assert interp_patches.shape[0] == cor_atme_patches.shape[0]
if 'axial' in opt.planes:
ax_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_ax_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
ax_atme_vol = ax_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
ax_atme_patches = extract_patches_with_overlap(ax_atme_vol, opt.patch_size, opt.overlap_ratio)
assert interp_patches.shape[0] == ax_atme_patches.shape[0]
if 'sagittal' in opt.planes:
sag_atme_vol = torch.load(os.path.join(opt.main_root, opt.atme_sag_root, 'data', 'generation', f'case_{case_idx}', 'atme_vol.pt')).cpu().detach()
sag_atme_vol = sag_atme_vol[half_d : -half_d, half_d : -half_d, half_d : -half_d]
sag_atme_patches = extract_patches_with_overlap(sag_atme_vol, opt.patch_size, opt.overlap_ratio)
assert interp_patches.shape[0] == sag_atme_patches.shape[0]

    # Save one training sample per patch location, pairing the interpolated
    # patch with its spatially aligned ATME patch(es).
    for i in range(interp_patches.shape[0]):
        interp_patch = interp_patches[i].unsqueeze(0).clone()
        data = {'interp_patch': interp_patch}
if 'coronal' in opt.planes:
cor_atme_patch = cor_atme_patches[i].unsqueeze(0).clone()
data['cor_atme_patch'] = cor_atme_patch
if 'axial' in opt.planes:
ax_atme_patch = ax_atme_patches[i].unsqueeze(0).clone()
data['ax_atme_patch'] = ax_atme_patch
if 'sagittal' in opt.planes:
sag_atme_patch = sag_atme_patches[i].unsqueeze(0).clone()
data['sag_atme_patch'] = sag_atme_patch
torch.save(data, os.path.join(save_train_dir, f'data_{save_idx}.pt'))
save_idx += 1
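
# Each saved file is a dict of aligned (1, patch_size, patch_size, patch_size)
# tensors, e.g. (illustrative, for patch_size=64):
#
#     sample = torch.load(os.path.join(save_train_dir, 'data_0.pt'))
#     sample['interp_patch'].shape  # torch.Size([1, 64, 64, 64])
#     # plus 'cor_atme_patch' / 'ax_atme_patch' / 'sag_atme_patch', per opt.planes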

def read_MRI_case(case_path, format):
if format == 'dicom':