diff --git a/networks/__pycache__/generators.cpython-310.pyc b/networks/__pycache__/generators.cpython-310.pyc new file mode 100644 index 0000000..01a5bc0 Binary files /dev/null and b/networks/__pycache__/generators.cpython-310.pyc differ diff --git a/networks/generators.py b/networks/generators.py index 9b1625b..8a8c510 100644 --- a/networks/generators.py +++ b/networks/generators.py @@ -139,140 +139,229 @@ def forward(self, x_thisBranch, x_otherBranch): return z -class dualAtt_24(nn.Module): - - def __init__(self): - super(dualAtt_24, self).__init__() - - self.relu = nn.ReLU(inplace=True) - self.conv3d_7 = nn.Conv3d(in_channels=64, out_channels=16, kernel_size=1, stride=(1, 1, 1), padding=0) - self.pathC_bn1 = nn.BatchNorm3d(64) - - - self.conv3d_8 = nn.Conv3d(in_channels=16, out_channels=4, kernel_size=3, stride=(1, 1, 1), padding=1) - - self.conv3d_9 = nn.Conv3d(in_channels=4, out_channels=1, kernel_size=3, stride=(1, 1, 1), padding=1) - self.pathC_bn2 = nn.BatchNorm3d(1) - - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 32) - self.fc3 = nn.Linear(32, 6) - - """layers for path global""" - self.path1_block1_conv = nn.Conv3d( - in_channels=1, - out_channels=32, - kernel_size=3, - stride=(1, 1, 1), - padding=1, - bias=False) - self.path1_block1_bn = nn.BatchNorm3d(32) - self.maxpool_downsample_pathGlobal11 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) - self.path1_block2_conv = nn.Conv3d( - in_channels=32, - out_channels=32, - kernel_size=3, - stride=(1, 1, 1), - padding=1, - bias=False) - self.path1_block2_bn = nn.BatchNorm3d(32) - self.maxpool_downsample_pathGlobal12 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1,2,2), padding=1) - self.path1_block3_NLCross = NLBlockND_cross(32) - - self.path2_block1_conv = nn.Conv3d( - in_channels=1, - out_channels=32, - kernel_size=3, - stride=(1, 1, 1), - padding=1, - bias=False) - self.path2_block1_bn = nn.BatchNorm3d(32) - self.maxpool_downsample_pathGlobal21 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) - self.path2_block2_conv = nn.Conv3d( - in_channels=32, - out_channels=32, - kernel_size=3, - stride=(1, 1, 1), - padding=1, - bias=False) - self.path2_block2_bn = nn.BatchNorm3d(32) - self.maxpool_downsample_pathGlobal22 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1,2,2), padding=1) - self.path2_block3_NLCross = NLBlockND_cross(32) - - - for m in self.modules(): - if isinstance(m, nn.Conv3d): - m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') - elif isinstance(m, nn.BatchNorm3d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def forward(self, x): - # total_start_time = time.time() - x_path1 = torch.unsqueeze(x[:, 0, :, :, :], 1) - x_path2 = torch.unsqueeze(x[:, 1, :, :, :], 1) - - """path global (attention)""" - x_path1 = self.path1_block1_conv(x_path1) - x_path1 = self.path1_block1_bn(x_path1) - x_path1 = self.relu(x_path1) - x_path1 = self.maxpool_downsample_pathGlobal11(x_path1) - # print(x_path1.shape) - - x_path1 = self.path1_block2_conv(x_path1) - x_path1 = self.path1_block2_bn(x_path1) - x_path1 = self.relu(x_path1) - x_path1_0 = self.maxpool_downsample_pathGlobal12(x_path1) - # print(x_path1.shape) - - - x_path2 = self.path2_block1_conv(x_path2) - x_path2 = self.path2_block1_bn(x_path2) - x_path2 = self.relu(x_path2) - x_path2 = self.maxpool_downsample_pathGlobal21(x_path2) - # print(x_path2.shape) - - x_path2 = self.path2_block2_conv(x_path2) - x_path2 = self.path2_block2_bn(x_path2) - x_path2 = self.relu(x_path2) - x_path2_0 = self.maxpool_downsample_pathGlobal22(x_path2) - # 
print(x_path2.shape) - - x_path1 = self.path1_block3_NLCross(x_path1_0, x_path2_0) - x_path1 = self.relu(x_path1) - - x_path2 = self.path2_block3_NLCross(x_path2_0, x_path1_0) - x_path2 = self.relu(x_path2) - - x_pathC = torch.cat((x_path1, x_path2), 1) - - """path combined""" - x = x_pathC - x = self.pathC_bn1(x) - - x = self.conv3d_7(x) - x = self.relu(x) +# class dualAtt_24(nn.Module): + +# def __init__(self): +# super(dualAtt_24, self).__init__() + +# self.relu = nn.ReLU(inplace=True) +# self.conv3d_7 = nn.Conv3d(in_channels=64, out_channels=16, kernel_size=1, stride=(1, 1, 1), padding=0) +# self.pathC_bn1 = nn.BatchNorm3d(64) + + +# self.conv3d_8 = nn.Conv3d(in_channels=16, out_channels=4, kernel_size=3, stride=(1, 1, 1), padding=1) + +# self.conv3d_9 = nn.Conv3d(in_channels=4, out_channels=1, kernel_size=3, stride=(1, 1, 1), padding=1) +# self.pathC_bn2 = nn.BatchNorm3d(1) + +# self.fc1 = nn.Linear(9216, 128) +# self.fc2 = nn.Linear(128, 32) +# self.fc3 = nn.Linear(32, 6) + +# """layers for path global""" +# self.path1_block1_conv = nn.Conv3d( +# in_channels=1, +# out_channels=32, +# kernel_size=3, +# stride=(1, 1, 1), +# padding=1, +# bias=False) +# self.path1_block1_bn = nn.BatchNorm3d(32) +# self.maxpool_downsample_pathGlobal11 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) +# self.path1_block2_conv = nn.Conv3d( +# in_channels=32, +# out_channels=32, +# kernel_size=3, +# stride=(1, 1, 1), +# padding=1, +# bias=False) +# self.path1_block2_bn = nn.BatchNorm3d(32) +# self.maxpool_downsample_pathGlobal12 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1,2,2), padding=1) +# self.path1_block3_NLCross = NLBlockND_cross(32) + +# self.path2_block1_conv = nn.Conv3d( +# in_channels=1, +# out_channels=32, +# kernel_size=3, +# stride=(1, 1, 1), +# padding=1, +# bias=False) +# self.path2_block1_bn = nn.BatchNorm3d(32) +# self.maxpool_downsample_pathGlobal21 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) +# self.path2_block2_conv = nn.Conv3d( +# in_channels=32, +# out_channels=32, +# kernel_size=3, +# stride=(1, 1, 1), +# padding=1, +# bias=False) +# self.path2_block2_bn = nn.BatchNorm3d(32) +# self.maxpool_downsample_pathGlobal22 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1,2,2), padding=1) +# self.path2_block3_NLCross = NLBlockND_cross(32) + + +# for m in self.modules(): +# if isinstance(m, nn.Conv3d): +# m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') +# elif isinstance(m, nn.BatchNorm3d): +# m.weight.data.fill_(1) +# m.bias.data.zero_() + +# def forward(self, x): +# # total_start_time = time.time() +# x_path1 = torch.unsqueeze(x[:, 0, :, :, :], 1) +# x_path2 = torch.unsqueeze(x[:, 1, :, :, :], 1) + +# """path global (attention)""" +# x_path1 = self.path1_block1_conv(x_path1) +# x_path1 = self.path1_block1_bn(x_path1) +# x_path1 = self.relu(x_path1) +# x_path1 = self.maxpool_downsample_pathGlobal11(x_path1) +# # print(x_path1.shape) + +# x_path1 = self.path1_block2_conv(x_path1) +# x_path1 = self.path1_block2_bn(x_path1) +# x_path1 = self.relu(x_path1) +# x_path1_0 = self.maxpool_downsample_pathGlobal12(x_path1) +# # print(x_path1.shape) + + +# x_path2 = self.path2_block1_conv(x_path2) +# x_path2 = self.path2_block1_bn(x_path2) +# x_path2 = self.relu(x_path2) +# x_path2 = self.maxpool_downsample_pathGlobal21(x_path2) +# # print(x_path2.shape) + +# x_path2 = self.path2_block2_conv(x_path2) +# x_path2 = self.path2_block2_bn(x_path2) +# x_path2 = self.relu(x_path2) +# x_path2_0 = self.maxpool_downsample_pathGlobal22(x_path2) +# # print(x_path2.shape) + +# 
x_path1 = self.path1_block3_NLCross(x_path1_0, x_path2_0) +# x_path1 = self.relu(x_path1) + +# x_path2 = self.path2_block3_NLCross(x_path2_0, x_path1_0) +# x_path2 = self.relu(x_path2) + +# x_pathC = torch.cat((x_path1, x_path2), 1) + +# """path combined""" +# x = x_pathC +# x = self.pathC_bn1(x) + +# x = self.conv3d_7(x) +# x = self.relu(x) + +# x = self.conv3d_8(x) +# x = self.relu(x) + +# x = self.conv3d_9(x) +# x = self.pathC_bn2(x) + +# x = x.view(x.size()[0], -1) +# x = self.relu(x) + +# x = self.fc1(x) +# x = self.relu(x) + +# x = self.fc2(x) +# x = self.relu(x) + +# x = self.fc3(x) +# # time_cost = time.time() - total_start_time +# # print('1 whole cycle time cost {}s'.format(time_cost)) +# # time.sleep(30) +# return x - x = self.conv3d_8(x) - x = self.relu(x) - - x = self.conv3d_9(x) - x = self.pathC_bn2(x) - - x = x.view(x.size()[0], -1) - x = self.relu(x) - - x = self.fc1(x) - x = self.relu(x) - - x = self.fc2(x) - x = self.relu(x) +class dualAtt_24(nn.Module): + def __init__(self): + super(dualAtt_24, self).__init__() + + self.relu = nn.ReLU(inplace=True) + self.conv3d_7 = nn.Conv3d(in_channels=64, out_channels=16, kernel_size=1, stride=(1, 1, 1), padding=0) + self.pathC_bn1 = nn.BatchNorm3d(64) + + self.conv3d_8 = nn.Conv3d(in_channels=16, out_channels=4, kernel_size=3, stride=(1, 1, 1), padding=1) + + self.conv3d_9 = nn.Conv3d(in_channels=4, out_channels=1, kernel_size=3, stride=(1, 1, 1), padding=1) + self.pathC_bn2 = nn.BatchNorm3d(1) + + self.fc1 = None # Defined dynamically below + self.fc2 = nn.Linear(128, 32) + self.fc3 = nn.Linear(32, 6) + + """layers for path global""" + self.path1_block1_conv = nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False) + self.path1_block1_bn = nn.BatchNorm3d(32) + self.maxpool_downsample_pathGlobal11 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) + self.path1_block2_conv = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False) + self.path1_block2_bn = nn.BatchNorm3d(32) + self.maxpool_downsample_pathGlobal12 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=1) + self.path1_block3_NLCross = NLBlockND_cross(32) + + self.path2_block1_conv = nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False) + self.path2_block1_bn = nn.BatchNorm3d(32) + self.maxpool_downsample_pathGlobal21 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) + self.path2_block2_conv = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False) + self.path2_block2_bn = nn.BatchNorm3d(32) + self.maxpool_downsample_pathGlobal22 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=1) + self.path2_block3_NLCross = NLBlockND_cross(32) + + def forward(self, x): + x_path1 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_path2 = torch.unsqueeze(x[:, 1, :, :, :], 1) + + x_path1 = self.path1_block1_conv(x_path1) + x_path1 = self.path1_block1_bn(x_path1) + x_path1 = self.relu(x_path1) + x_path1 = self.maxpool_downsample_pathGlobal11(x_path1) + + x_path1 = self.path1_block2_conv(x_path1) + x_path1 = self.path1_block2_bn(x_path1) + x_path1 = self.relu(x_path1) + x_path1_0 = self.maxpool_downsample_pathGlobal12(x_path1) + + x_path2 = self.path2_block1_conv(x_path2) + x_path2 = self.path2_block1_bn(x_path2) + x_path2 = self.relu(x_path2) + x_path2 = self.maxpool_downsample_pathGlobal21(x_path2) + + x_path2 = self.path2_block2_conv(x_path2) + x_path2 = self.path2_block2_bn(x_path2) + 
x_path2 = self.relu(x_path2)
+        x_path2_0 = self.maxpool_downsample_pathGlobal22(x_path2)
+
+        x_path1 = self.path1_block3_NLCross(x_path1_0, x_path2_0)
+        x_path1 = self.relu(x_path1)
+
+        x_path2 = self.path2_block3_NLCross(x_path2_0, x_path1_0)
+        x_path2 = self.relu(x_path2)
+
+        x_pathC = torch.cat((x_path1, x_path2), 1)
+
+        x = x_pathC
+        x = self.pathC_bn1(x)
+        x = self.conv3d_7(x)
+        x = self.relu(x)
+        x = self.conv3d_8(x)
+        x = self.relu(x)
+        x = self.conv3d_9(x)
+        x = self.pathC_bn2(x)
+        x = x.view(x.size()[0], -1)
+
+        # Dynamically set the input size of fc1 on the first forward pass. NOTE: fc1 does not exist at construction time, so an optimizer built before the first batch never sees its parameters and an early model.to(device) cannot move it (see the nn.LazyLinear sketch after this diff).
+        if self.fc1 is None:
+            self.fc1 = nn.Linear(x.size(1), 128).to(x.device)
+
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+
+        return x
 
-        x = self.fc3(x)
-        # time_cost = time.time() - total_start_time
-        # print('1 whole cycle time cost {}s'.format(time_cost))
-        # time.sleep(30)
-        return x
 
 class dualAtt_25(nn.Module):
     def __init__(self):
diff --git a/results/Gen_AttentionReg_0821-160757_load_model.pth b/results/Gen_AttentionReg_0821-160757_load_model.pth
new file mode 100644
index 0000000..35d1f30
Binary files /dev/null and b/results/Gen_AttentionReg_0821-160757_load_model.pth differ
diff --git a/results/Gen_AttentionReg_0821-160910_load_model.pth b/results/Gen_AttentionReg_0821-160910_load_model.pth
new file mode 100644
index 0000000..a7ba509
Binary files /dev/null and b/results/Gen_AttentionReg_0821-160910_load_model.pth differ
diff --git a/train_network.py b/train_network.py
index 5654588..d1828f5 100644
--- a/train_network.py
+++ b/train_network.py
@@ -148,24 +148,24 @@ def __getitem__(self, idx):
         # files, so we will use random validation samples in this demo.
         # if status == 'val':
         #     base_mat = np.loadtxt('mats_forVal/Case{:04}_mat{}.txt'.format(index,init_no))
-        elif self.initialization == 'random_uniform':
-            #generate samples with random SRE in a certain range (e.g. [0-20] or [0-8])
-            # if you are provided with ground truth segmentation, calculate
-            # the randomized base_TRE (Target Registration Error):
-            # base_TRE = evaluator.evaluate_transform(base_mat)
+        # elif self.initialization == 'random_uniform':
+        #     #generate samples with random SRE in a certain range (e.g. [0-20] or [0-8])
+        #     # if you are provided with ground truth segmentation, calculate
+        #     # the randomized base_TRE (Target Registration Error):
+        #     # base_TRE = evaluator.evaluate_transform(base_mat)
 
-            base_mat, params_rand = generate_random_transform(gt_mat)
-            base_TRE = evaluator.evaluate_transform(base_mat)
-            uniform_target_TRE = np.random.uniform(0, 20, 1)[0]
-            scale_ratio = uniform_target_TRE / base_TRE
-            params_rand = params_rand * scale_ratio
-            base_mat = load_func.construct_matrix_degree(params=params_rand,
-                                                         initial_transform=gt_mat)
+        # base_mat, params_rand = generate_random_transform(gt_mat)
+        # base_TRE = evaluator.evaluate_transform(base_mat)
+        # uniform_target_TRE = np.random.uniform(0, 20, 1)[0]
+        # scale_ratio = uniform_target_TRE / base_TRE
+        # params_rand = params_rand * scale_ratio
+        # base_mat = load_func.construct_matrix_degree(params=params_rand,
+        #                                              initial_transform=gt_mat)
 
-        else:
-            print('!' * 10 + ' Initialization mode <{}> not supported!'.format(self.initialization))
-            return
+        # else:
+        #     print('!' * 10 + ' Initialization mode <{}> not supported!'.format(self.initialization))
+        #     return
 
         """loading MR and US images. 
In our experiments, we read images from mhd files and resample them with MR segmentation."""
 
         sample4D = np.zeros((2, 32, 96, 96), dtype=np.ubyte)
@@ -346,8 +346,10 @@ def train_model(model, criterion, optimizer, scheduler, fn_save, num_epochs=25):
 
 if __name__ == '__main__':
 
     data_dir = 'sample'
-    results_dir = 'results'
+    results_dir = 'results'
+
+    if not os.path.exists(results_dir):
+        os.makedirs(results_dir)
 
     init_mode = args.init_mode
     network_type = args.network_type
diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000..b72af88
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/utils/__pycache__/data_loading_funcs.cpython-310.pyc b/utils/__pycache__/data_loading_funcs.cpython-310.pyc
new file mode 100644
index 0000000..a0636d8
Binary files /dev/null and b/utils/__pycache__/data_loading_funcs.cpython-310.pyc differ
diff --git a/utils/__pycache__/transformations.cpython-310.pyc b/utils/__pycache__/transformations.cpython-310.pyc
new file mode 100644
index 0000000..1fcf92c
Binary files /dev/null and b/utils/__pycache__/transformations.cpython-310.pyc differ
diff --git a/utils/data_loading_funcs.py b/utils/data_loading_funcs.py
index ce1afbc..a61d59c 100644
--- a/utils/data_loading_funcs.py
+++ b/utils/data_loading_funcs.py
@@ -24,7 +24,7 @@ def array_normalize(input_array):
 def fuse_images(img_ref, img_folat, alpha=0.4):
     """
     """
-    mask = (img_folat > 5).astype(np.float32)
+    mask = (img_folat > 5).astype(float)
     # print(alpha)
     mask[mask > 0.5] = alpha
     mask_comp = 1.0 - mask
@@ -76,7 +76,7 @@ def construct_matrix(params, initial_transform=None):
 # Angles in degree version
 def decompose_matrix_degree(trans_matrix):
     eus = tfms.euler_from_matrix(trans_matrix[:3, :3])
-    eus = np.asarray(eus, dtype=np.float) / np.pi * 180.0
+    eus = np.asarray(eus, dtype=float) / np.pi * 180.0
     params = np.asarray([trans_matrix[0, 3],
                          trans_matrix[1, 3],
                          trans_matrix[2, 3],
@@ -85,7 +85,7 @@ def decompose_matrix_degree(trans_matrix):
 
 def construct_matrix_degree(params, initial_transform=None):
     if not params is np.array:
-        params = np.asarray(params, dtype=np.float)
+        params = np.asarray(params, dtype=float)
     radians = params[3:] / 180.0 * np.pi
     mat = tfms.euler_matrix(radians[0], radians[1], radians[2], 'sxyz')
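
Review note on the rewritten dualAtt_24: because fc1 is materialized inside forward(), the layer does not exist when the model is constructed, so an optimizer built before the first batch never receives fc1's parameters and an early model.to(device) cannot move them. torch.nn.LazyLinear (available since PyTorch 1.8) gives the same deferred input sizing while registering the layer up front. A minimal sketch, assuming PyTorch >= 1.8; the class name LazyHead and the Adam settings are illustrative, not from this repo:

    import torch
    import torch.nn as nn

    class LazyHead(nn.Module):
        """Illustrative stand-in for the fc1/fc2/fc3 tail of dualAtt_24."""
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU(inplace=True)
            self.fc1 = nn.LazyLinear(128)  # in_features inferred on the first call
            self.fc2 = nn.Linear(128, 32)
            self.fc3 = nn.Linear(32, 6)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            return self.fc3(x)

    model = LazyHead()
    # Lazy modules still need one dry run before the optimizer is created so
    # the weights are materialized; 9216 is the size the original fc1 assumed.
    _ = model(torch.zeros(2, 9216))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)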
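
On the results-directory block added to train_network.py: os.makedirs accepts exist_ok=True (Python >= 3.2), which folds the exists()/makedirs() pair into a single race-free call. A sketch, assuming no other code relies on the explicit existence check:

    import os

    results_dir = 'results'
    os.makedirs(results_dir, exist_ok=True)  # no-op if the directory already exists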
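
On the dtype edits in utils/data_loading_funcs.py: they track NumPy's removal of the np.float alias (deprecated in 1.20, removed in 1.24), for which the builtin float is the documented replacement. Note, however, that np.float32 was never deprecated, so the fuse_images hunk silently widens the mask from float32 to float64; keeping .astype(np.float32) there would preserve the original precision and memory footprint. A quick check that the np.float replacements are value-identical:

    import numpy as np

    # As a dtype argument, the builtin float means float64 -- exactly what the
    # removed np.float alias pointed to.
    eus = np.asarray((0.1, 0.2, 0.3), dtype=float) / np.pi * 180.0
    assert eus.dtype == np.float64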