From 116b382479ef9ebddcfb9826648c6842554bed6b Mon Sep 17 00:00:00 2001
From: vxbrandon
Date: Thu, 20 Jul 2023 02:28:42 +0100
Subject: [PATCH 01/23] Implement diag pruning eoc

---
 custom_utils/pruning/diag_pruning.py | 153 ++++++++++++++++++++++-----
 eoc/exp_trainability.py              |  28 ++++-
 2 files changed, 151 insertions(+), 30 deletions(-)

diff --git a/custom_utils/pruning/diag_pruning.py b/custom_utils/pruning/diag_pruning.py
index bfabe49..ab1514c 100644
--- a/custom_utils/pruning/diag_pruning.py
+++ b/custom_utils/pruning/diag_pruning.py
@@ -1,29 +1,36 @@
+from numpy.ma.extras import mask_cols
+"""
+Diagonal Block Pruning
+"""
+# !pip install betterspy
+
 import torch
 import torch.nn as nn
 import torch.nn.utils.prune as prune
-import scipy.sparse as sparse
 import numpy as np
-import matplotlib.pyplot as plt
+import numba as nb
 import betterspy
+from scipy import sparse
 from typing import Iterable, Callable
 
+def diag_pruning(weight: torch.tensor, mask:torch.tensor, block_size: int = 4):
+    num_rows, num_cols = mask.shape
 
-from custom_utils.constants import LINEAR, CONV2D
+    # Compute the number of complete blocks and remaining diagonal size
+    num_blocks = min(num_rows // block_size, num_cols // block_size)
+    if num_blocks == 0:
+        mask[:, :] = torch.ones_like(mask)
+        return torch.ones_like(mask)
 
-def diag_pruning_linear(
-    module: nn.Linear,
-    block_size: int = 4,
-    perm_type: str = None,
-):
-    mask = torch.zeros_like(module.weight)
-    assert (
-        mask.size()[0] == mask.size()[1]
-    ), "Diagonal pruning isn't implemented for rectangular matrix"
+    if num_cols > num_rows:
+        residual_diagonal = num_rows % block_size
+    else:
+        residual_diagonal = num_cols % block_size
 
-    num_rows = num_cols = mask.size()[0]
-    num_blocks = min(num_rows // block_size, num_cols // block_size)
-    residual_diagonal = min(num_rows % block_size, num_cols % block_size)
+    for i in range(num_blocks):
+        block_dim = block_size if i < num_blocks - 1 else residual_diagonal
 
+    # Create a mask to zero out non-block diagonal elements
     for i in range(num_blocks):
         start_row = i * block_size
         end_row = start_row + block_size
@@ -39,27 +46,123 @@ def diag_pruning_linear(
         end_col = start_col + residual_diagonal
         mask[start_row:end_row, start_col:end_col] = 1
 
-    if perm_type is not None:
+    return mask
+
+def diag_pruning_linear(
+    module: nn.Linear,
+    block_size: int = 4,
+    perm_type: str = None,
+    col_perm: torch.tensor = None,
+    row_perm: torch.tensor = None,
+):
+    # mask = torch.zeros_like(module.weight)
+
+    # num_rows, num_cols = mask.size()
+    # num_blocks = min(num_rows // block_size, num_cols // block_size)
+    # residual_diagonal = min(num_rows % block_size, num_cols % block_size)
+
+    # for i in range(num_blocks):
+    #     start_row = i * block_size
+    #     end_row = start_row + block_size
+    #     start_col = i * block_size
+    #     end_col = start_col + block_size
+    #     mask[start_row:end_row, start_col:end_col] = 1
+
+    # # If there is a residual diagonal, use a smaller block size to fit it
+    # if residual_diagonal > 0:
+    #     start_row = num_blocks * block_size
+    #     end_row = start_row + residual_diagonal
+    #     start_col = num_blocks * block_size
+    #     end_col = start_col + residual_diagonal
+    #     mask[start_row:end_row, start_col:end_col] = 1
+    assert isinstance(module, nn.Linear)
+    mask = torch.zeros_like(module.weight)
+    num_rows, num_cols = module.weight.shape
+    max_size = max(num_rows, num_cols)
+    min_size = min(num_rows, num_cols)
+
+    num_reps = max_size // min_size + 1
+    for j in range(num_reps):
+        if num_rows > num_cols:
+            diag_pruning(module.weight, mask[j * min_size:, :], block_size=block_size)
+        else:
+            diag_pruning(module.weight, mask[:, j * min_size:], block_size=block_size)
+    if perm_type == "RANDOM":
         col_perm = torch.randperm(num_cols)
         row_perm = torch.randperm(num_rows)
         mask = mask[:, col_perm]
         mask = mask[row_perm, :]
+    if perm_type == "CUSTOM":
+        mask = mask[:, col_perm]
+        mask = mask[row_perm, :]
+    print(mask.shape, module.weight.shape)
+
     prune.custom_from_mask(module, "weight", mask)
 
+def diag_pruning_conv2d(
+    module: nn.Conv2d,
+    block_size: int = 4,
+    perm_type: str = None,
+):
+    assert isinstance(module, nn.Conv2d)
+    in_out_channels = module.weight[:, :, 0, 0]
+
+    mask = torch.zeros_like(in_out_channels)
+    num_rows, num_cols = in_out_channels.shape
+    max_size = max(num_rows, num_cols)
+    min_size = min(num_rows, num_cols)
+
+    num_reps = max_size // min_size + 1
+    mask = torch.zeros_like(in_out_channels)
+    for j in range(num_reps):
+        if num_rows > num_cols:
+            diag_pruning(in_out_channels, mask[j * min_size:, :], block_size=block_size)
+        else:
+            diag_pruning(in_out_channels, mask[:, j * min_size:], block_size=block_size)
+    if perm_type == "RANDOM":
+        col_perm = torch.randperm(num_cols)
+        row_perm = torch.randperm(num_rows)
+        mask = mask[:, col_perm]
+        mask = mask[row_perm, :]
+    conv2d_mask = torch.zeros_like(module.weight)  # 4-dimensional
+    non_zero_idcs = torch.nonzero(mask, as_tuple=True)
+    conv2d_mask[non_zero_idcs] = 1
+    # sparse_matrix = sparse.csr_matrix(mask.cpu().detach().numpy())
+    # betterspy.show(sparse_matrix)
+
+
+    prune.custom_from_mask(module, "weight", conv2d_mask)
+
+
+
+
 def exp():
-    layer1 = nn.Linear(100, 100)
-    diag_pruning_linear(layer1, perm_type="random")
-    print(layer1.weight)
+    layer1 = nn.Linear(37, 91)
+    diag_pruning_linear(layer1, block_size=10, perm_type="RANDOM")
+
 
-    sparse_weight = sparse.csr_matrix(layer1.weight.detach())
-    print(sparse_weight)
+    # Show and save the sparsity pattern
+    sparse_matrix = sparse.csr_matrix(layer1.weight.detach().numpy())
+    betterspy.show(sparse_matrix)
 
-    import matplotlib
-    matplotlib.use('TkAgg')
-    plt.spy(sparse_weight)
+    layer1 = nn.Linear(37, 91)
+    diag_pruning_linear(layer1, block_size=10, perm_type="")
 
+    # Show and save the sparsity pattern
+    sparse_matrix = sparse.csr_matrix(layer1.weight.detach().numpy())
+    betterspy.show(sparse_matrix)
 
 if __name__ == "__main__":
-    exp()
+    a = torch.rand((2,2,2))
+    b = torch.rand((2,2))
+    idcs = torch.nonzero(b<0.7, as_tuple=True)
+    # a[idcs] = 0
+    a[idcs] = 0
+    # exp()
+    a = torch.rand((10, 100, 100))
+    c = nn.Conv2d(10, 10, kernel_size=3)
+    d = c.weight.cpu().detach()
+    b = c(a)
+    diag_pruning_conv2d(c, block_size=15)

diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py
index 0d96c00..dbdd51c 100644
--- a/eoc/exp_trainability.py
+++ b/eoc/exp_trainability.py
@@ -10,6 +10,7 @@
 from meanfield import MeanField
 from models.fcn import FCN
 from custom_utils import utils
+from custom_utils.pruning.diag_pruning import diag_pruning_linear
 
 
 def d_tanh(x):
@@ -23,6 +24,18 @@ def init_weights(m, sw, sb):
         nn.init.normal_(m.bias, mean=0.0, std=np.sqrt(sb))
 
 
+def init_weights_pruned(m, sw, sb, prune_amount, num_classes):
+    if type(m) == nn.Linear:
+        scaling_factor = m.out_features
+        if "weight_mask" in m.named_buffers():
+            mask = m.weight_mask
+            prune_amount = float(torch.sum(mask==0) / torch.numel(mask))
+            scaling_factor *= (1-prune_amount)
+        print(f"scaling factor: {scaling_factor}")
+        nn.init.normal_(m.weight, mean=0.0, std=(np.sqrt(sw / scaling_factor)))
+        nn.init.normal_(m.bias, mean=0.0, std=np.sqrt(sb))
+
+
 def 
exp_trainability(args=None) -> None: """ Explore the trainability of FCN on MNISt after being initialized far from EOC curve. @@ -30,8 +43,8 @@ def exp_trainability(args=None) -> None: # Parameters for experiments @Jay: Fix hard-coded data_type = "MNIST" - depth = 50 - width = 300 + depth = 5 + width = 2000 num_experiments = 2 num_epochs = 20 act = np.tanh @@ -64,11 +77,16 @@ def exp_trainability(args=None) -> None: num_classes=num_classes, input_dims=input_dims, ) + for module in fcn.modules(): + if isinstance(module, nn.Linear) and module.out_features != num_classes: + diag_pruning_linear(module, 10, "RANDOM") + model_name = "FCN_diag_block_10x10" + for q_star in [0.2, 0.5, 1.0, 1.5]: print(f"Calculating eoc curve for qstar {q_star}...") meanfield = MeanField(np.tanh, d_act) sw, sb = meanfield.sw_sb(q_star, 1) - group_name = f"depth-{depth}_width-{width}_q-{q_star}" # For logging purpose + group_name = f"{model_name}_depth-{depth}_width-{width}_q-{q_star}" # For logging purpose config = { "depth": depth, "width": width, @@ -83,12 +101,12 @@ def exp_trainability(args=None) -> None: "q_star": q_star, "tau": 0, "patience": patience, - "model": "FCN" + "model": model_name, } for tau_per in [0.1, 0.5, 0.8, 1, 1.2, 2.0, 10.0]: log_dir = os.path.join("logs", group_name) new_sw = sw * tau_per - fcn.apply(lambda m: init_weights(m, new_sw, sb)) + fcn.apply(lambda m: init_weights_pruned(m, new_sw, sb, 0.005, num_classes)) fcn, log_dict = utils.train_model( model=fcn, train_loader=train_loader, From e6cdcf6f5dc97063e48245543ec375458683bc62 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 02:46:03 +0100 Subject: [PATCH 02/23] Debug scaling factor for initializing weights for pruned fcn model --- eoc/exp_trainability.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index dbdd51c..761315a 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -27,7 +27,7 @@ def init_weights(m, sw, sb): def init_weights_pruned(m, sw, sb, prune_amount, num_classes): if type(m) == nn.Linear: scaling_factor = m.out_features - if "weight_mask" in m.named_buffers(): + if hasattr(m, "weight_mask"): mask = m.weight_mask prune_amount = float(torch.sum(mask==0) / torch.numel(mask)) scaling_factor *= (1-prune_amount) From f94b6d406a2e39ad35fbd51047d5ca7347afde5d Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 10:20:23 +0100 Subject: [PATCH 03/23] Change exp settings --- eoc/exp_trainability.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 761315a..de3c845 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -30,7 +30,7 @@ def init_weights_pruned(m, sw, sb, prune_amount, num_classes): if hasattr(m, "weight_mask"): mask = m.weight_mask prune_amount = float(torch.sum(mask==0) / torch.numel(mask)) - scaling_factor *= (1-prune_amount) + scaling_factor *= (1 - prune_amount) print(f"scaling factor: {scaling_factor}") nn.init.normal_(m.weight, mean=0.0, std=(np.sqrt(sw / scaling_factor))) nn.init.normal_(m.bias, mean=0.0, std=np.sqrt(sb)) @@ -79,8 +79,8 @@ def exp_trainability(args=None) -> None: ) for module in fcn.modules(): if isinstance(module, nn.Linear) and module.out_features != num_classes: - diag_pruning_linear(module, 10, "RANDOM") - model_name = "FCN_diag_block_10x10" + diag_pruning_linear(module, 100, "RANDOM") + model_name = "FCN_diag_block_100x100" for q_star in [0.2, 0.5, 1.0, 1.5]: print(f"Calculating eoc curve 
for qstar {q_star}...") From 757f527276b24f0e2bdf122fb2196c95bee9c36d Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 10:23:36 +0100 Subject: [PATCH 04/23] Change exp settings --- eoc/exp_trainability.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index de3c845..95d8e3a 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -46,7 +46,7 @@ def exp_trainability(args=None) -> None: depth = 5 width = 2000 num_experiments = 2 - num_epochs = 20 + num_epochs = 5 act = np.tanh d_act = d_tanh tau_1 = tau_2 = 1.0 @@ -103,7 +103,7 @@ def exp_trainability(args=None) -> None: "patience": patience, "model": model_name, } - for tau_per in [0.1, 0.5, 0.8, 1, 1.2, 2.0, 10.0]: + for tau_per in [0.5, 0.8, 1, 1.2, 2.0]: log_dir = os.path.join("logs", group_name) new_sw = sw * tau_per fcn.apply(lambda m: init_weights_pruned(m, new_sw, sb, 0.005, num_classes)) From df0d021895a2a2b115cb7313eacc1f6411bf3912 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 10:29:46 +0100 Subject: [PATCH 05/23] Change exp settings --- eoc/exp_trainability.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 95d8e3a..c2f2a6d 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -32,8 +32,8 @@ def init_weights_pruned(m, sw, sb, prune_amount, num_classes): prune_amount = float(torch.sum(mask==0) / torch.numel(mask)) scaling_factor *= (1 - prune_amount) print(f"scaling factor: {scaling_factor}") - nn.init.normal_(m.weight, mean=0.0, std=(np.sqrt(sw / scaling_factor))) - nn.init.normal_(m.bias, mean=0.0, std=np.sqrt(sb)) + nn.init.normal_(m.weight_orig, mean=0.0, std=(np.sqrt(sw / scaling_factor))) + nn.init.normal_(m.bias_orig, mean=0.0, std=np.sqrt(sb)) def exp_trainability(args=None) -> None: From 817f36b268d751e8e5c2e3970e0ebcabaccdc63b Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 10:31:03 +0100 Subject: [PATCH 06/23] Change exp settings --- eoc/exp_trainability.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index c2f2a6d..6fff4d1 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -3,6 +3,8 @@ import numpy as np import os import wandb +import betterspy +import scipy.sparse as sparse from torch.utils.tensorboard import SummaryWriter from collections import defaultdict @@ -107,6 +109,7 @@ def exp_trainability(args=None) -> None: log_dir = os.path.join("logs", group_name) new_sw = sw * tau_per fcn.apply(lambda m: init_weights_pruned(m, new_sw, sb, 0.005, num_classes)) + betterspy.show(sparse.csr_matrix(fcn.fc0.weight.cpu().detach())) fcn, log_dict = utils.train_model( model=fcn, train_loader=train_loader, From 06b0a81314ee1d394c45344deea7fb4fa003c7c1 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 11:19:47 +0100 Subject: [PATCH 07/23] Initialize weight_orig not weight for pruned model --- eoc/exp_trainability.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 6fff4d1..1aba668 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -34,8 +34,8 @@ def init_weights_pruned(m, sw, sb, prune_amount, num_classes): prune_amount = float(torch.sum(mask==0) / torch.numel(mask)) scaling_factor *= (1 - prune_amount) print(f"scaling factor: {scaling_factor}") - nn.init.normal_(m.weight_orig, mean=0.0, std=(np.sqrt(sw / scaling_factor))) - 
nn.init.normal_(m.bias_orig, mean=0.0, std=np.sqrt(sb)) + nn.init.normal_(m.weight_orig, mean=0.0, std=(np.sqrt(sw / scaling_factor))) + nn.init.normal_(m.bias, mean=0.0, std=np.sqrt(sb)) def exp_trainability(args=None) -> None: From 88ce1a4d2984327aa20bcbfa5c6cdfe8003b7c8a Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 11:51:12 +0100 Subject: [PATCH 08/23] Initialize weight_orig not weight for pruned model --- eoc/exp_trainability.py | 2 +- models/fcn.py | 49 +++++++++++++++++++++++++++-------------- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 1aba668..918ea21 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -109,7 +109,7 @@ def exp_trainability(args=None) -> None: log_dir = os.path.join("logs", group_name) new_sw = sw * tau_per fcn.apply(lambda m: init_weights_pruned(m, new_sw, sb, 0.005, num_classes)) - betterspy.show(sparse.csr_matrix(fcn.fc0.weight.cpu().detach())) + betterspy.show(sparse.csr_matrix(fcn.fcn.fc0.weight.cpu().detach())) fcn, log_dict = utils.train_model( model=fcn, train_loader=train_loader, diff --git a/models/fcn.py b/models/fcn.py index 6c05258..97664bb 100644 --- a/models/fcn.py +++ b/models/fcn.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn -from typing import Callable, Union, Iterable +from typing import Callable, Union, List +from collections import OrderedDict class FCN(nn.Module): @@ -10,7 +11,7 @@ def __init__( input_dims: int = 784, num_classes: int = 10, num_layers: int = 5, - hidden_dims: Union[int, Iterable] = 784, + hidden_dims: Union[int, List] = 784, is_softmax: bool = False, act_func: Callable = nn.ReLU, ): @@ -20,23 +21,37 @@ def __init__( self.num_layers = num_layers self.act_func = act_func if isinstance(hidden_dims, int): - hidden_dims = [hidden_dims for _ in range(num_layers-1)] + hidden_dims = [hidden_dims for _ in range(num_layers - 1)] + layer_dims = [input_dims] + hidden_dims + [num_classes] + else: + assert len(hidden_dims) == num_layers - 1 + layer_dims = [input_dims] + hidden_dims + [num_classes] + print(layer_dims) - modules = [] + modules = OrderedDict() + for idx in range(0, num_layers - 1): + modules[f"fc{idx}"] = nn.Linear(layer_dims[idx], layer_dims[idx + 1]) + modules[f"act{idx}"] = self.act_func() + modules[f"fc{num_layers - 1}"] = nn.Linear(layer_dims[num_layers - 1], num_classes) + modules["final_softmax"] = nn.Softmax() + self.fcn = nn.Sequential(modules) - # Initial Layer - modules.append(nn.Linear(input_dims, hidden_dims[0])) - modules.append(self.act_func()) - - # Intermediate Layers - for idx in range(1, num_layers - 1): - modules.append(nn.Linear(hidden_dims[idx], hidden_dims[idx])) - modules.append(self.act_func()) - # Final Layer - modules.append(nn.Linear(hidden_dims[num_layers-2], num_classes)) - modules.append(nn.Softmax()) - - self.fcn = nn.Sequential(*modules) + # + # modules = [] + # + # # Initial Layer + # modules.append(nn.Linear(input_dims, hidden_dims[0])) + # modules.append(self.act_func()) + # + # # Intermediate Layers + # for idx in range(1, num_layers - 1): + # modules.append(nn.Linear(hidden_dims[idx], hidden_dims[idx])) + # modules.append(self.act_func()) + # # Final Layer + # modules.append(nn.Linear(hidden_dims[num_layers-2], num_classes)) + # modules.append(nn.Softmax()) + # + # self.fcn = nn.Sequential(*modules) def forward(self, x): return self.fcn(x) From 06b761c0e9273480d7d4a1c3436b20a1bb35eaa4 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 12:03:08 +0100 Subject: 
[PATCH 09/23] Initialize weight_orig not weight for pruned model --- eoc/exp_trainability.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 918ea21..98a54be 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -45,7 +45,7 @@ def exp_trainability(args=None) -> None: # Parameters for experiments @Jay: Fix hard-coded data_type = "MNIST" - depth = 5 + depth = 15 width = 2000 num_experiments = 2 num_epochs = 5 From 7cd1bd7183e38be94101d23f6d41687106a86109 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 12:03:21 +0100 Subject: [PATCH 10/23] Initialize weight_orig not weight for pruned model --- eoc/exp_trainability.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 98a54be..df88a94 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -45,10 +45,10 @@ def exp_trainability(args=None) -> None: # Parameters for experiments @Jay: Fix hard-coded data_type = "MNIST" - depth = 15 + depth = 20 width = 2000 num_experiments = 2 - num_epochs = 5 + num_epochs = 20 act = np.tanh d_act = d_tanh tau_1 = tau_2 = 1.0 From da306c6f24c705c8a86a1d4b39a2927e81065171 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Thu, 20 Jul 2023 12:04:27 +0100 Subject: [PATCH 11/23] Remove print statements --- custom_utils/pruning/diag_pruning.py | 1 - eoc/exp_trainability.py | 1 - 2 files changed, 2 deletions(-) diff --git a/custom_utils/pruning/diag_pruning.py b/custom_utils/pruning/diag_pruning.py index ab1514c..7c921e1 100644 --- a/custom_utils/pruning/diag_pruning.py +++ b/custom_utils/pruning/diag_pruning.py @@ -95,7 +95,6 @@ def diag_pruning_linear( if perm_type == "CUSTOM": mask = mask[:, col_perm] mask = mask[row_perm, :] - print(mask.shape, module.weight.shape) prune.custom_from_mask(module, "weight", mask) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index df88a94..1a4efb9 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -33,7 +33,6 @@ def init_weights_pruned(m, sw, sb, prune_amount, num_classes): mask = m.weight_mask prune_amount = float(torch.sum(mask==0) / torch.numel(mask)) scaling_factor *= (1 - prune_amount) - print(f"scaling factor: {scaling_factor}") nn.init.normal_(m.weight_orig, mean=0.0, std=(np.sqrt(sw / scaling_factor))) nn.init.normal_(m.bias, mean=0.0, std=np.sqrt(sb)) From 8130620e96ad44039c18687900bba4971f162fb5 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Fri, 21 Jul 2023 19:43:06 +0100 Subject: [PATCH 12/23] Implemented 3d plotting --- eoc/exp_trainability.py | 462 +++++++++++++++++++++++++++++----------- 1 file changed, 342 insertions(+), 120 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 1a4efb9..2ae953e 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -3,10 +3,15 @@ import numpy as np import os import wandb +import random +import json import betterspy import scipy.sparse as sparse +import argparse +import matplotlib.pyplot as plt +from matplotlib import cm -from torch.utils.tensorboard import SummaryWriter +from typing import Tuple from collections import defaultdict from meanfield import MeanField @@ -15,156 +20,220 @@ from custom_utils.pruning.diag_pruning import diag_pruning_linear +def tanh(x): + return np.tanh(x) + + def d_tanh(x): """Derivative of tanh.""" return 1.0 / np.cosh(x) ** 2 +def relu(x): + return np.maximum(0, x) + + +def d_relu(x): + d = np.zeros_like(x) + d[np.nonzero(x > 0)] = 1 + return 
d + + def init_weights(m, sw, sb): if type(m) == nn.Linear: - nn.init.normal_(m.weight, mean=0.0, std=(np.sqrt(sw / m.out_features))) + nn.init.normal_(m.weight, mean=0.0, std=(np.sqrt(sw / m.in_features))) nn.init.normal_(m.bias, mean=0.0, std=np.sqrt(sb)) -def init_weights_pruned(m, sw, sb, prune_amount, num_classes): +def init_weights_new(m, sw, sb): if type(m) == nn.Linear: - scaling_factor = m.out_features + scaling_factor = m.in_features if hasattr(m, "weight_mask"): mask = m.weight_mask - prune_amount = float(torch.sum(mask==0) / torch.numel(mask)) - scaling_factor *= (1 - prune_amount) + prune_amount = float(torch.sum(mask == 0) / torch.numel(mask)) + scaling_factor *= 1 - prune_amount nn.init.normal_(m.weight_orig, mean=0.0, std=(np.sqrt(sw / scaling_factor))) nn.init.normal_(m.bias, mean=0.0, std=np.sqrt(sb)) -def exp_trainability(args=None) -> None: +def get_act_func(act_func: str = "TANH") -> Tuple[nn.Module, callable, callable]: + if act_func == "TANH": + return nn.Tanh, tanh, d_tanh + elif act_func == "RELU": + return nn.ReLU, relu, d_tanh + else: + raise NotImplementedError(f"{act_func} is not implemented.") + + +def get_model( + model: str = "FCN", + input_dims: int = 784, + num_classes: int = 10, + depth: int = 5, + width: int = 300, + act_func: nn.Module = nn.Tanh, + block_size: int = None, +) -> nn.Module: + if model == "FCN": + return FCN( + input_dims=input_dims, + hidden_dims=width, + num_layers=depth, + act_func=act_func, + num_classes=num_classes, + ) + elif model == "FCN_DIAG_PERM": + fcn = FCN( + input_dims=input_dims, + hidden_dims=width, + num_layers=depth, + act_func=act_func, + num_classes=num_classes, + ) + assert ( + block_size is not None + ), f"block_size should be defined to apply block diag pruning" + for module in fcn.modules(): + if isinstance(module, nn.Linear) and module.out_features != num_classes: + diag_pruning_linear(module, block_size=block_size, perm_type="RANDOM") + return fcn + else: + raise NotImplementedError(f"model {model} is not implemented.") + + +def exp_trainability(args: argparse.Namespace = None) -> None: """ Explore the trainability of FCN on MNISt after being initialized far from EOC curve. 
""" - - # Parameters for experiments @Jay: Fix hard-coded - data_type = "MNIST" - depth = 20 - width = 2000 - num_experiments = 2 - num_epochs = 20 - act = np.tanh - d_act = d_tanh - tau_1 = tau_2 = 1.0 - q_star = 0.5 - lr_rate = 1e-3 - optimizer = "SGD" - weight_decay = 0 - patience = 20 - # Load dataset train_loader, test_loader, classes = utils.prepare_dataloader( - data_type=data_type, is_flatten=True + data_type=args.data_type, is_flatten=True ) num_classes = len(classes) input_dims = next(iter(train_loader))[0].size()[-1] # Pre-configuration - seed = 2 - utils.set_random_seeds(seed) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + utils.set_random_seeds(args.seed) + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cpu") + nn_act_func, act_func, d_act_func = get_act_func(args.act_func) + is_wandb = False + is_plot = True # Define Model - act_func = nn.Tanh - fcn = FCN( - num_layers=depth, - hidden_dims=width, - act_func=act_func, - num_classes=num_classes, - input_dims=input_dims, - ) - for module in fcn.modules(): - if isinstance(module, nn.Linear) and module.out_features != num_classes: - diag_pruning_linear(module, 100, "RANDOM") - model_name = "FCN_diag_block_100x100" - - for q_star in [0.2, 0.5, 1.0, 1.5]: + fcn = get_model(model=args.model, + input_dims=input_dims, + num_classes=num_classes, + depth=args.depth, + width=args.width, + act_func=nn_act_func, + block_size=args.block_size + ) + # Finding the list of tau percentages on which we do experiments + min_tau_per, max_tau_per = args.tau_range + tau_pers = np.logspace(min_tau_per, max_tau_per, args.num_taus).tolist() + + # Finding the list of q_stars on which we do experiments + q_min = args.qstar_range[0] + q_max = args.qstar_range[1] + q_stars = np.linspace(q_min, q_max, args.num_qstars).tolist() + + # Lists for 3d graphs + sw_grid = np.empty((len(tau_pers), len(q_stars))) + sb_grid = np.empty((len(tau_pers), len(q_stars))) + train_acc_grid = np.empty((len(tau_pers), len(q_stars))) + eval_acc_grid = np.empty((len(tau_pers), len(q_stars))) + + for q_idx in range(len(q_stars)): + q_star = q_stars[q_idx] + + # Calculating eoc curve print(f"Calculating eoc curve for qstar {q_star}...") - meanfield = MeanField(np.tanh, d_act) + meanfield = MeanField(act_func, d_act_func) sw, sb = meanfield.sw_sb(q_star, 1) - group_name = f"{model_name}_depth-{depth}_width-{width}_q-{q_star}" # For logging purpose - config = { - "depth": depth, - "width": width, - "epochs": num_epochs, - "seed": seed, - "activation": "tanh", - "learning_rate": lr_rate, - "optimizer": optimizer, - "weight decay": weight_decay, - "weight variance": sw, - "bias variance": sb, - "q_star": q_star, - "tau": 0, - "patience": patience, - "model": model_name, - } - for tau_per in [0.5, 0.8, 1, 1.2, 2.0]: - log_dir = os.path.join("logs", group_name) + group_name = f"{args.model}_depth-{args.depth}_width-{args.width}_q-{q_star}" # For logging purpose + + # Logging hyperparameters to wandb config + config = vars(args) + config["weight_var"] = sw + config["bias_var"] = sb + + for tau_idx in range(len(tau_pers)): + tau_per = tau_pers[tau_idx] + eval_acc = 0 + train_acc = 0 new_sw = sw * tau_per - fcn.apply(lambda m: init_weights_pruned(m, new_sw, sb, 0.005, num_classes)) - betterspy.show(sparse.csr_matrix(fcn.fcn.fc0.weight.cpu().detach())) - fcn, log_dict = utils.train_model( - model=fcn, - train_loader=train_loader, - test_loader=test_loader, - patience=patience, - num_epochs=num_epochs, - verbose=True, 
- device=device, - is_log_dict=True, - learning_rate=lr_rate, - optimizer=optimizer, - ) - train_accs = log_dict["train_acc"] - eval_accs = log_dict["eval_acc"] - epochs = log_dict["epochs"] - - train_accs_data = [[x, y] for (x, y) in zip(epochs, train_accs)] - eval_accs_data = [[x, y] for (x, y) in zip(epochs, eval_accs)] - train_table = wandb.Table( - data=train_accs_data, columns=["epochs", "train accuracy"] - ) - eval_table = wandb.Table( - data=eval_accs_data, columns=["epochs", "eval accuracy"] - ) - config["tau_per"] = tau_per - wandb.init( - project="Edge of Chaos", - tags=["EOC preliminary", "EOC trainability"], - group=group_name, - name=f"tau_per={tau_per}", - config=config, - ) - wandb.log( - { - group_name - + "_train": wandb.plot.line( - train_table, - "epochs", - "train accuracy", - title="EOC curve trainability", - ) - } - ) - wandb.log( - { - group_name - + "_eval": wandb.plot.line( - eval_table, - "epochs", - "eval accuracy", - title="EOC curve trainability", - ) - } - ) - wandb.run.summary["best_accuracy"] = eval_accs[-1] - wandb.finish() + sw_grid[tau_idx][q_idx] = new_sw + sb_grid[tau_idx][q_idx] = sb + for num_exp in range(args.num_exps): + fcn.apply(lambda m: init_weights(m, new_sw, sb)) + fcn, log_dict = utils.train_model( + model=fcn, + train_loader=train_loader, + test_loader=test_loader, + patience=args.patience, + num_epochs=args.epochs, + verbose=True, + device=device, + is_log_dict=True, + learning_rate=args.lr, + optimizer=args.optimizer, + l2_regularization_strength=args.weight_decay, + ) + train_accs = log_dict["train_acc"] + eval_accs = log_dict["eval_acc"] + epochs = log_dict["epochs"] + train_acc += train_accs[-1] + eval_acc += eval_accs[-1] + train_acc /= args.num_exps + eval_acc /= args.num_exps + train_acc_grid[tau_idx, q_idx] = train_acc + eval_acc_grid[tau_idx, q_idx] = eval_acc + + if is_wandb: + train_accs_data = [[x, y] for (x, y) in zip(epochs, train_accs)] + eval_accs_data = [[x, y] for (x, y) in zip(epochs, eval_accs)] + train_table = wandb.Table( + data=train_accs_data, columns=["epochs", "train accuracy"] + ) + eval_table = wandb.Table( + data=eval_accs_data, columns=["epochs", "eval accuracy"] + ) + config["tau_per"] = tau_per + wandb.init( + project="Edge of Chaos", + tags=["EOC preliminary", "EOC trainability"], + group=group_name, + name=f"tau_per={tau_per}", + config=config, + ) + wandb.log( + { + group_name + + "_train": wandb.plot.line( + train_table, + "epochs", + "train accuracy", + title="EOC curve trainability", + ) + } + ) + wandb.log( + { + group_name + + "_eval": wandb.plot.line( + eval_table, + "epochs", + "eval accuracy", + title="EOC curve trainability", + ) + } + ) + wandb.run.summary["best_accuracy"] = eval_accs[-1] + wandb.finish() + + # Logging to csv files + log_dir = os.path.join("logs", group_name) if not os.path.exists(log_dir): os.makedirs(log_dir) filename = f"tau_per-{tau_per}" @@ -172,6 +241,159 @@ def exp_trainability(args=None) -> None: print(log_dict) utils.log_data(log_dict, filepath) + # Logging 3d results + orig_log_dir = os.path.join("logs_3d", "run_") + log_dir = orig_log_dir + idx = 0 + while os.path.exists(log_dir): + log_dir = orig_log_dir + str(idx) + idx += 1 + os.makedirs(log_dir) + params_dict = vars(args) + graph_log_dict = { + "sw_grid": sw_grid.tolist(), + "sb_grid": sb_grid.tolist(), + "train_acc_grid": train_acc_grid.tolist(), + "eval_acc_grid": eval_acc_grid.tolist() + } + + params_path = os.path.join(log_dir, "params.json") + graph_log_path = os.path.join(log_dir, "3d_graph_log.json") + with 
open(params_path, 'w+') as f: + json.dump(params_dict, f) + with open(graph_log_path, 'w+') as f: + json.dump(graph_log_dict, f) + + if is_plot: + fig = plt.figure(figsize=plt.figaspect(1.0)) + ax = plt.axes(projection='3d') + + surf = ax.plot_surface(sw_grid, sb_grid, eval_acc_grid,) + # rstride=1, cstride=1, cmap=cm.coolwarm, + # linewidth=0, antialiased=False) + fig.colorbar(surf, shrink=0.5, aspect=10) + ax.set_xlabel("sw") + ax.set_ylabel("sb") + ax.set_zlabel("eval acc") + plt.show() + +def plot_3d(filepath:str): + with open(filepath, 'r') as f: + loaded_dict = json.load(f) + sw_grid = np.array(loaded_dict["sw_grid"]) + sb_grid = np.array(loaded_dict["sb_grid"]) + train_acc_grid = np.array(loaded_dict["train_acc_grid"]) + fig = plt.figure(figsize=(10, 8)) + ax = plt.axes(projection='3d') + + surf = ax.plot_surface(sw_grid, sb_grid, train_acc_grid, rstride=1, cstride=1, cmap=cm.coolwarm, + linewidth=0, antialiased=False) + fig.colorbar(surf, shrink=0.5, aspect=10) + ax.set_xlabel("sw") + ax.set_ylabel("sb") + ax.set_zlabel("train acc") + plt.show() if __name__ == "__main__": - exp_trainability() + parser = argparse.ArgumentParser() + + # the following arguments are only for trainability experiment. + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--num_exps", + default=2, + type=int, + help="Number of experiments to get average accuracy of trained model", + ) + parser.add_argument( + "--seed", + default=0, + type=int, + help="Random seed", + ) + parser.add_argument( + "--model", + default="FCN", + type=str, + help="Model type", + ) + parser.add_argument( + "--block_size", + default=None, + type=int, + help="Block size for pruning models", + ) + parser.add_argument( + "--data_type", + default="MNIST", + type=str, + choices=["MNIST", "CIFAR_10"], + help="Model type", + ) + parser.add_argument( + "--act_func", + default="TANH", + type=str, + help="Activation function for each layer in FCN", + ) + parser.add_argument( + "--num_taus", + default=15, + type=int, + help="number of taus(multiplicative constant for variance of weight matrix to check the thickness of" + "edge of chaos", + ) + parser.add_argument( + "--tau_range", + default=(-1, 1), + type=tuple, + help="Range of taus(multiplicative constant) for variance of weight matrix (log10 scale)", + ) + parser.add_argument( + "--qstar_range", + default=(0.1, 10.0), + type=tuple, + help="Range of taus(multiplicative constant) for variance of weight matrix", + ) + parser.add_argument( + "--num_qstars", + default=10, + type=int, + help="Number of qstars with which we do experiments", + ) + parser.add_argument( + "--depth", default=20, type=int, help="Depth of FCN" + ) + parser.add_argument( + "--width", default=300, type=int, help="width of fully-connected layer" + ) + parser.add_argument( + "--batch-size", default=128, type=int, help="batch size for SGD" + ) + parser.add_argument( + "--epochs", default=20, type=int, help="number of epochs to train FCN" + ) + parser.add_argument("--optimizer", default="SGD", type=str, help="OPTIMIZER TYPE") + parser.add_argument( + "--lr", default=1e-3, type=float, help="learning rate for training" + ) + parser.add_argument( + "--weight_decay", default=0, type=float, help="Weight decay for training" + ) + parser.add_argument( + "--patience", + default=20, + type=int, + help="Number of epochs for which the model's train acc " + "is allowed to be not decreasing.", + ) + parser.add_argument( + "--debug", action="store_true", 
default=False, help="debug the main experiment" + ) + + args = parser.parse_args() + exp_trainability(args) + # filepath = os.path.join("logs_3d", "run_", "3d_graph_log.json") + # plot_3d(filepath) \ No newline at end of file From 6fbb11dbb302ad29c1471ec626d5b7754f924f16 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Fri, 21 Jul 2023 19:52:06 +0100 Subject: [PATCH 13/23] Implemented 3d plotting --- eoc/exp_trainability.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 2ae953e..04c4d10 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -275,6 +275,14 @@ def exp_trainability(args: argparse.Namespace = None) -> None: ax.set_xlabel("sw") ax.set_ylabel("sb") ax.set_zlabel("eval acc") + + # EOC curve + num_taus = sw_grid.shape[0] + eoc_idx = (num_taus - 1) / 2 + eoc_sw_list = sw_grid[eoc_idx, :] + eoc_sb_list = sb_grid[eoc_idx, :] + eoc_eval_acc_list = eval_acc_grid[eoc_idx, :] + ax.plot(eoc_sw_list, eoc_sb_list, eoc_eval_acc_list, label='EOC') plt.show() def plot_3d(filepath:str): @@ -292,6 +300,14 @@ def plot_3d(filepath:str): ax.set_xlabel("sw") ax.set_ylabel("sb") ax.set_zlabel("train acc") + + # EOC curve + num_taus = sw_grid.shape[0] + eoc_idx = int((num_taus - 1) / 2) + eoc_sw_list = sw_grid[eoc_idx, :] + eoc_sb_list = sb_grid[eoc_idx, :] + eoc_eval_acc_list = train_acc_grid[eoc_idx, :] + ax.plot(eoc_sw_list, eoc_sb_list, eoc_eval_acc_list, label='EOC') plt.show() if __name__ == "__main__": @@ -392,8 +408,8 @@ def plot_3d(filepath:str): parser.add_argument( "--debug", action="store_true", default=False, help="debug the main experiment" ) - args = parser.parse_args() - exp_trainability(args) - # filepath = os.path.join("logs_3d", "run_", "3d_graph_log.json") - # plot_3d(filepath) \ No newline at end of file + assert args.num_taus % 2 != 0, f"{args.num_taus} should be odd to include tau_per = 1 which is EOC case" + # exp_trainability(args) + filepath = os.path.join("logs_3d", "run_", "3d_graph_log.json") + plot_3d(filepath) \ No newline at end of file From 900cd60ec14c4ac2635a2a2093d589bb7c2ba6e4 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Fri, 21 Jul 2023 19:59:16 +0100 Subject: [PATCH 14/23] Implemented 3d plotting --- eoc/exp_trainability.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 04c4d10..15d98b0 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -410,6 +410,6 @@ def plot_3d(filepath:str): ) args = parser.parse_args() assert args.num_taus % 2 != 0, f"{args.num_taus} should be odd to include tau_per = 1 which is EOC case" - # exp_trainability(args) - filepath = os.path.join("logs_3d", "run_", "3d_graph_log.json") - plot_3d(filepath) \ No newline at end of file + exp_trainability(args) + # filepath = os.path.join("logs_3d", "run_", "3d_graph_log.json") + # plot_3d(filepath) \ No newline at end of file From de1dd2f819f89b7e647b51fb48dd62824f5b1418 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Sat, 22 Jul 2023 02:12:06 +0100 Subject: [PATCH 15/23] Implemented 3d plotting --- eoc/exp_trainability.py | 53 ++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 15d98b0..76275b1 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -241,6 +241,36 @@ def exp_trainability(args: argparse.Namespace = None) -> None: print(log_dict) 
utils.log_data(log_dict, filepath) + if is_plot: + fig = plt.figure(figsize=plt.figaspect(1.0)) + ax = plt.axes(projection='3d') + + surf = ax.plot_surface(sw_grid, sb_grid, eval_acc_grid,) + # rstride=1, cstride=1, cmap=cm.coolwarm, + # linewidth=0, antialiased=False) + fig.colorbar(surf, shrink=0.5, aspect=10) + ax.set_xlabel("sw") + ax.set_ylabel("sb") + ax.set_zlabel("eval acc") + + # EOC curve + num_taus = sw_grid.shape[0] + eoc_idx = (num_taus - 1) / 2 + eoc_sw_list = sw_grid[eoc_idx, :] + eoc_sb_list = sb_grid[eoc_idx, :] + eoc_eval_acc_list = eval_acc_grid[eoc_idx, :] + ax.plot(eoc_sw_list, eoc_sb_list, eoc_eval_acc_list, label='EOC') + plt.show() + + # logging in Wandb + wandb.init( + project="Edge of Chaos", + tags=["EOC preliminary", "EOC trainability"], + group="3d_graph", + config=vars(args), + ) + wandb.log({"3d plot": wandb.Image(fig)}) + wandb.finish() # Logging 3d results orig_log_dir = os.path.join("logs_3d", "run_") log_dir = orig_log_dir @@ -264,27 +294,6 @@ def exp_trainability(args: argparse.Namespace = None) -> None: with open(graph_log_path, 'w+') as f: json.dump(graph_log_dict, f) - if is_plot: - fig = plt.figure(figsize=plt.figaspect(1.0)) - ax = plt.axes(projection='3d') - - surf = ax.plot_surface(sw_grid, sb_grid, eval_acc_grid,) - # rstride=1, cstride=1, cmap=cm.coolwarm, - # linewidth=0, antialiased=False) - fig.colorbar(surf, shrink=0.5, aspect=10) - ax.set_xlabel("sw") - ax.set_ylabel("sb") - ax.set_zlabel("eval acc") - - # EOC curve - num_taus = sw_grid.shape[0] - eoc_idx = (num_taus - 1) / 2 - eoc_sw_list = sw_grid[eoc_idx, :] - eoc_sb_list = sb_grid[eoc_idx, :] - eoc_eval_acc_list = eval_acc_grid[eoc_idx, :] - ax.plot(eoc_sw_list, eoc_sb_list, eoc_eval_acc_list, label='EOC') - plt.show() - def plot_3d(filepath:str): with open(filepath, 'r') as f: loaded_dict = json.load(f) @@ -356,7 +365,7 @@ def plot_3d(filepath:str): ) parser.add_argument( "--num_taus", - default=15, + default=21, type=int, help="number of taus(multiplicative constant for variance of weight matrix to check the thickness of" "edge of chaos", From 0e6148edea101ab5dbef3ea046ed0a61b06857b0 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Sat, 22 Jul 2023 02:14:28 +0100 Subject: [PATCH 16/23] Implemented 3d plotting --- eoc/exp_trainability.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 76275b1..dc3e77c 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -271,6 +271,13 @@ def exp_trainability(args: argparse.Namespace = None) -> None: ) wandb.log({"3d plot": wandb.Image(fig)}) wandb.finish() + graph_log_dict = { + "sw_grid": sw_grid.tolist(), + "sb_grid": sb_grid.tolist(), + "train_acc_grid": train_acc_grid.tolist(), + "eval_acc_grid": eval_acc_grid.tolist() + } + print(graph_log_dict) # Logging 3d results orig_log_dir = os.path.join("logs_3d", "run_") log_dir = orig_log_dir @@ -278,7 +285,7 @@ def exp_trainability(args: argparse.Namespace = None) -> None: while os.path.exists(log_dir): log_dir = orig_log_dir + str(idx) idx += 1 - os.makedirs(log_dir) + os.makedirs(log_dir, exist_ok=True) params_dict = vars(args) graph_log_dict = { "sw_grid": sw_grid.tolist(), From aff89afbe180407978446a891e3296a6754f3d73 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Sat, 22 Jul 2023 23:03:37 +0100 Subject: [PATCH 17/23] Implemented 3d plotting --- eoc/exp_trainability.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/eoc/exp_trainability.py 
b/eoc/exp_trainability.py index dc3e77c..7c02d71 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -144,6 +144,15 @@ def exp_trainability(args: argparse.Namespace = None) -> None: train_acc_grid = np.empty((len(tau_pers), len(q_stars))) eval_acc_grid = np.empty((len(tau_pers), len(q_stars))) + # Logging the results + orig_log_dir = os.path.join("logs_3d", "run_") + log_dir = orig_log_dir + idx = 0 + while os.path.exists(log_dir): + log_dir = orig_log_dir + str(idx) + idx += 1 + os.makedirs(log_dir, exist_ok=True) + for q_idx in range(len(q_stars)): q_star = q_stars[q_idx] @@ -279,13 +288,13 @@ def exp_trainability(args: argparse.Namespace = None) -> None: } print(graph_log_dict) # Logging 3d results - orig_log_dir = os.path.join("logs_3d", "run_") - log_dir = orig_log_dir - idx = 0 - while os.path.exists(log_dir): - log_dir = orig_log_dir + str(idx) - idx += 1 - os.makedirs(log_dir, exist_ok=True) + # orig_log_dir = os.path.join("logs_3d", "run_") + # log_dir = orig_log_dir + # idx = 0 + # while os.path.exists(log_dir): + # log_dir = orig_log_dir + str(idx) + # idx += 1 + # os.makedirs(log_dir, exist_ok=True) params_dict = vars(args) graph_log_dict = { "sw_grid": sw_grid.tolist(), From f1a4510d8f565f3f224a1079c056eb048162c5c4 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Sat, 22 Jul 2023 23:07:29 +0100 Subject: [PATCH 18/23] For debugging purpose --- eoc/exp_trainability.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 7c02d71..217492a 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -381,7 +381,7 @@ def plot_3d(filepath:str): ) parser.add_argument( "--num_taus", - default=21, + default=1, type=int, help="number of taus(multiplicative constant for variance of weight matrix to check the thickness of" "edge of chaos", @@ -400,7 +400,7 @@ def plot_3d(filepath:str): ) parser.add_argument( "--num_qstars", - default=10, + default=1, type=int, help="Number of qstars with which we do experiments", ) @@ -414,7 +414,7 @@ def plot_3d(filepath:str): "--batch-size", default=128, type=int, help="batch size for SGD" ) parser.add_argument( - "--epochs", default=20, type=int, help="number of epochs to train FCN" + "--epochs", default=1, type=int, help="number of epochs to train FCN" ) parser.add_argument("--optimizer", default="SGD", type=str, help="OPTIMIZER TYPE") parser.add_argument( From f8cc5e33cf15195ca6ed9fb1024e822e78a02355 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Sat, 22 Jul 2023 23:10:55 +0100 Subject: [PATCH 19/23] For debugging purpose --- eoc/exp_trainability.py | 1 + 1 file changed, 1 insertion(+) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 217492a..2fba3b9 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -265,6 +265,7 @@ def exp_trainability(args: argparse.Namespace = None) -> None: # EOC curve num_taus = sw_grid.shape[0] eoc_idx = (num_taus - 1) / 2 + eoc_idx = int(eoc_idx) eoc_sw_list = sw_grid[eoc_idx, :] eoc_sb_list = sb_grid[eoc_idx, :] eoc_eval_acc_list = eval_acc_grid[eoc_idx, :] From 3c8ac4f3ee0327dc12be0d2fb6329b5887da6608 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Sat, 22 Jul 2023 23:16:44 +0100 Subject: [PATCH 20/23] Update exp configs --- eoc/exp_trainability.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 2fba3b9..b8345d7 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -382,7 +382,7 @@ def 
plot_3d(filepath:str): ) parser.add_argument( "--num_taus", - default=1, + default=21, type=int, help="number of taus(multiplicative constant for variance of weight matrix to check the thickness of" "edge of chaos", @@ -401,7 +401,7 @@ def plot_3d(filepath:str): ) parser.add_argument( "--num_qstars", - default=1, + default=15, type=int, help="Number of qstars with which we do experiments", ) @@ -415,7 +415,7 @@ def plot_3d(filepath:str): "--batch-size", default=128, type=int, help="batch size for SGD" ) parser.add_argument( - "--epochs", default=1, type=int, help="number of epochs to train FCN" + "--epochs", default=20, type=int, help="number of epochs to train FCN" ) parser.add_argument("--optimizer", default="SGD", type=str, help="OPTIMIZER TYPE") parser.add_argument( From 9cebe957d5006acd3d839f4b79f3a67addf70d8c Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Mon, 24 Jul 2023 23:36:51 +0100 Subject: [PATCH 21/23] Update exp configs --- custom_utils/exp_results.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 custom_utils/exp_results.py diff --git a/custom_utils/exp_results.py b/custom_utils/exp_results.py new file mode 100644 index 0000000..3c8501a --- /dev/null +++ b/custom_utils/exp_results.py @@ -0,0 +1,10 @@ +import numpy as np +q_stars = np.linspace(0.1, 10.0, 15) +tau_pers = np.logspace(-1, 1, 21) +q_stars = q_stars[:2] + +accuracy = [[0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.5165, 0.8742, 0.96355, 0.9647, 0.9657, + 0.9599, 0.94795, 0.9345, 0.9183, 0.9201, 0.9205, 0.9158, 0.89535, 0.87705, ], + [0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.57705, 0.89775, 0.9637, 0.9633, 0.9589, 0.9506, + 0.93465, 0.9174, 0.87675, 0.90865, 0.91655, 0.8647, 0.81115, 0.87085, 0.6902], + ] From 11bcf71d9d46f27ea42c2899d6beeee6acc89754 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Tue, 1 Aug 2023 15:11:52 +0100 Subject: [PATCH 22/23] Update exp configs --- eoc/exp_trainability.py | 10 ++-- exp_custom_pruning.py | 100 ++++++++++++++++++++++++++++++++++++++++ models/fcn.py | 2 +- 3 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 exp_custom_pruning.py diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index b8345d7..7f2c3f2 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -345,7 +345,7 @@ def plot_3d(filepath:str): ) parser.add_argument( "--num_exps", - default=2, + default=1, type=int, help="Number of experiments to get average accuracy of trained model", ) @@ -382,14 +382,14 @@ def plot_3d(filepath:str): ) parser.add_argument( "--num_taus", - default=21, + default=11, type=int, help="number of taus(multiplicative constant for variance of weight matrix to check the thickness of" "edge of chaos", ) parser.add_argument( "--tau_range", - default=(-1, 1), + default=(-0.5, 0.5), type=tuple, help="Range of taus(multiplicative constant) for variance of weight matrix (log10 scale)", ) @@ -401,7 +401,7 @@ def plot_3d(filepath:str): ) parser.add_argument( "--num_qstars", - default=15, + default=10, type=int, help="Number of qstars with which we do experiments", ) @@ -415,7 +415,7 @@ def plot_3d(filepath:str): "--batch-size", default=128, type=int, help="batch size for SGD" ) parser.add_argument( - "--epochs", default=20, type=int, help="number of epochs to train FCN" + "--epochs", default=7, type=int, help="number of epochs to train FCN" ) parser.add_argument("--optimizer", default="SGD", type=str, help="OPTIMIZER TYPE") parser.add_argument( diff --git a/exp_custom_pruning.py b/exp_custom_pruning.py new file mode 
100644 index 0000000..3c57d44 --- /dev/null +++ b/exp_custom_pruning.py @@ -0,0 +1,100 @@ +import torch +import os + +from custom_utils.pruning.diag_pruning import diag_pruning_linear +from custom_utils.utils import prepare_dataloader, evaluate_model, train_model, load_model, set_random_seeds +from models.fcn import FCN + +def inverse_permutation(perm): + inv = torch.empty_like(perm) + inv[perm] = torch.arange(perm.size(0), device=perm.device) + return inv + +def generate_perm_tensor(image_size: int = 32, block_size: int = 4): + assert image_size % block_size == 0 + perm_list = [] + for row in range(image_size): + for col in range(image_size): + i = int(row/block_size) + j = int(col/block_size) + previous_block_nums = i * image_size * block_size + j * block_size * block_size + + i = row % block_size + j = col % block_size + index = previous_block_nums + i * block_size + j + perm_list.append(index) + return inverse_permutation(torch.tensor(perm_list)) + +import torch +import torch.nn as nn +import scipy.sparse as sparse +import betterspy + +def main(): + seed = 999 + set_random_seeds(seed) + cuda_device = torch.device("cuda:0") + data_type = "CIFAR_10" + + # Experimental setting + lr_rate = 1e-2 + num_epochs=300 + patience=10 + verbose = True + + # Load dataset + train_loader, test_loader, classes = prepare_dataloader( + num_workers=8, + train_batch_size=128, + eval_batch_size=256, + data_type=data_type, + is_flatten=True,) + input_dims = next(iter(train_loader))[0].size()[-1] + num_layers = 5 + fcn = FCN(num_layers=num_layers, input_dims=input_dims) + + model_filename = f"FCN_{num_layers}_{data_type}.pt" + model_dir = "saved_models" + model_filepath = os.path.join(model_dir, model_filename) + if not (os.path.exists(model_dir)): + os.makedirs(model_dir) + + # fcn = utils.load_model(fcn, model_filepath, cuda_device) + # _, eval_accuracy, _ = utils.evaluate_model(model=fcn, + # test_loader=test_loader, + # device=cuda_device, + # criterion=None) + # print(f"Before permutation: Test Accuracy {eval_accuracy}") + + # fcn = utils.FCN(num_layers=num_layers, input_dims=input_dims) + counts = 0 + for module in fcn.modules(): + if isinstance(module, nn.Linear): + a = generate_perm_tensor(image_size=32, block_size=4) + perm_tensor = torch.cat((a, a + 32*32, a + 32*32*2)) + print(perm_tensor) + diag_pruning_linear(module, block_size=16, perm_type="CUSTOM", row_perm=perm_tensor, col_perm=perm_tensor) + counts += 1 + if counts >=2: + break + sparse_matrix = sparse.csr_matrix(fcn.fcn.fc0.weight.cpu().detach().numpy()) + + # Show and save the sparsity pattern + betterspy.show(sparse_matrix) + train_model(model=fcn, + train_loader=train_loader, + test_loader=test_loader, + device=cuda_device, + learning_rate=lr_rate, + num_epochs=num_epochs, + T_max=num_epochs, + patience=patience, + verbose=verbose, + optimizer="SGD") + + _, eval_accuracy, _ = evaluate_model(model=fcn, + test_loader=test_loader, + device=cuda_device, + criterion=None) + print(f"After permutation: Test Accuracy {eval_accuracy}") +main() \ No newline at end of file diff --git a/models/fcn.py b/models/fcn.py index 97664bb..b60b2b0 100644 --- a/models/fcn.py +++ b/models/fcn.py @@ -33,7 +33,7 @@ def __init__( modules[f"fc{idx}"] = nn.Linear(layer_dims[idx], layer_dims[idx + 1]) modules[f"act{idx}"] = self.act_func() modules[f"fc{num_layers - 1}"] = nn.Linear(layer_dims[num_layers - 1], num_classes) - modules["final_softmax"] = nn.Softmax() + # modules["final_softmax"] = nn.Softmax() self.fcn = nn.Sequential(modules) # From 
15692110c06b013fd138efc7ff2a1a08206e9048 Mon Sep 17 00:00:00 2001 From: vxbrandon Date: Fri, 11 Aug 2023 20:28:18 +0900 Subject: [PATCH 23/23] Implement pre-training models on various datasets --- custom_utils/pruning/diag_pruning.py | 69 ++-- custom_utils/train_models.py | 137 ++++++++ custom_utils/utils.py | 473 +++++++++++++++++++++------ eoc/exp_trainability.py | 53 ++- exp_block_pruning.py | 0 exps/exp_block_pruning.py | 88 +++++ logs_3d/run_0/3d_graph_log.json | 1 + logs_3d/run_0/params.json | 1 + models/fcn.py | 21 +- models/resnet.py | 2 + models/vggnet.py | 112 +++++++ 11 files changed, 768 insertions(+), 189 deletions(-) create mode 100644 custom_utils/train_models.py delete mode 100644 exp_block_pruning.py create mode 100644 exps/exp_block_pruning.py create mode 100644 logs_3d/run_0/3d_graph_log.json create mode 100644 logs_3d/run_0/params.json create mode 100644 models/vggnet.py diff --git a/custom_utils/pruning/diag_pruning.py b/custom_utils/pruning/diag_pruning.py index 7c921e1..18ad951 100644 --- a/custom_utils/pruning/diag_pruning.py +++ b/custom_utils/pruning/diag_pruning.py @@ -1,4 +1,5 @@ from numpy.ma.extras import mask_cols + """ Diagonal Block Pruning """ @@ -7,13 +8,11 @@ import torch import torch.nn as nn import torch.nn.utils.prune as prune -import numpy as np -import numba as nb import betterspy from scipy import sparse -from typing import Iterable, Callable -def diag_pruning(weight: torch.tensor, mask:torch.tensor, block_size: int = 4): + +def diag_pruning(weight: torch.tensor, mask: torch.tensor, block_size: int = 4): num_rows, num_cols = mask.shape # Compute the number of complete blocks and remaining diagonal size @@ -48,6 +47,7 @@ def diag_pruning(weight: torch.tensor, mask:torch.tensor, block_size: int = 4): return mask + def diag_pruning_linear( module: nn.Linear, block_size: int = 4, @@ -55,26 +55,6 @@ def diag_pruning_linear( col_perm: torch.tensor = None, row_perm: torch.tensor = None, ): - # mask = torch.zeros_like(module.weight) - - # num_rows, num_cols = mask.size() - # num_blocks = min(num_rows // block_size, num_cols // block_size) - # residual_diagonal = min(num_rows % block_size, num_cols % block_size) - - # for i in range(num_blocks): - # start_row = i * block_size - # end_row = start_row + block_size - # start_col = i * block_size - # end_col = start_col + block_size - # mask[start_row:end_row, start_col:end_col] = 1 - - # # If there is a residual diagonal, use a smaller block size to fit it - # if residual_diagonal > 0: - # start_row = num_blocks * block_size - # end_row = start_row + residual_diagonal - # start_col = num_blocks * block_size - # end_col = start_col + residual_diagonal - # mask[start_row:end_row, start_col:end_col] = 1 assert isinstance(module, nn.Linear) mask = torch.zeros_like(module.weight) num_rows, num_cols = module.weight.shape @@ -83,16 +63,16 @@ def diag_pruning_linear( num_reps = max_size // min_size + 1 for j in range(num_reps): - if num_rows > num_cols: - diag_pruning(module.weight, mask[j * min_size:, :], block_size=block_size) - else: - diag_pruning(module.weight, mask[:, j * min_size:], block_size=block_size) + if num_rows > num_cols: + diag_pruning(module.weight, mask[j * min_size :, :], block_size=block_size) + else: + diag_pruning(module.weight, mask[:, j * min_size :], block_size=block_size) if perm_type == "RANDOM": col_perm = torch.randperm(num_cols) row_perm = torch.randperm(num_rows) mask = mask[:, col_perm] mask = mask[row_perm, :] - if perm_type == "CUSTOM": + elif perm_type == "CUSTOM": mask = mask[:, 
col_perm] mask = mask[row_perm, :] @@ -106,41 +86,36 @@ def diag_pruning_conv2d( ): assert isinstance(module, nn.Conv2d) in_out_channels = module.weight[:, :, 0, 0] - - mask = torch.zeros_like(in_out_channels) num_rows, num_cols = in_out_channels.shape max_size = max(num_rows, num_cols) min_size = min(num_rows, num_cols) - num_reps = max_size // min_size + 1 mask = torch.zeros_like(in_out_channels) for j in range(num_reps): - if num_rows > num_cols: - diag_pruning(in_out_channels, mask[j * min_size:, :], block_size=block_size) - else: - diag_pruning(in_out_channels, mask[:, j * min_size:], block_size=block_size) + if num_rows > num_cols: + diag_pruning( + in_out_channels, mask[j * min_size :, :], block_size=block_size + ) + else: + diag_pruning( + in_out_channels, mask[:, j * min_size :], block_size=block_size + ) if perm_type == "RANDOM": col_perm = torch.randperm(num_cols) row_perm = torch.randperm(num_rows) mask = mask[:, col_perm] mask = mask[row_perm, :] - conv2d_mask = torch.zeros_like(module.weight) # 4-dimensional + conv2d_mask = torch.zeros_like(module.weight) # 4-dimensional non_zero_idcs = torch.nonzero(mask, as_tuple=True) conv2d_mask[non_zero_idcs] = 1 - # sparse_matrix = sparse.csr_matrix(mask.cpu().detach().numpy()) - # betterspy.show(sparse_matrix) - prune.custom_from_mask(module, "weight", conv2d_mask) - - def exp(): layer1 = nn.Linear(37, 91) diag_pruning_linear(layer1, block_size=10, perm_type="RANDOM") - # Show and save the sparsity pattern sparse_matrix = sparse.csr_matrix(layer1.weight.detach().numpy()) betterspy.show(sparse_matrix) @@ -154,12 +129,10 @@ def exp(): if __name__ == "__main__": - a = torch.rand((2,2,2)) - b = torch.rand((2,2)) - idcs = torch.nonzero(b<0.7, as_tuple=True) - # a[idcs] = 0 + a = torch.rand((2, 2, 2)) + b = torch.rand((2, 2)) + idcs = torch.nonzero(b < 0.7, as_tuple=True) a[idcs] = 0 - # exp() a = torch.rand((10, 100, 100)) c = nn.Conv2d(10, 10, kernel_size=3) d = c.weight.cpu().detach() diff --git a/custom_utils/train_models.py b/custom_utils/train_models.py new file mode 100644 index 0000000..d6be24a --- /dev/null +++ b/custom_utils/train_models.py @@ -0,0 +1,137 @@ +""" +Train Models +""" + +import torch +import torch.nn as nn +import os +from typing import List + +import custom_utils.utils as utils +from models.fcn import FCN +from models.resnet import ResNet18 +from models.vggnet import vgg19_bn + + +def train( + data_type: str = "MNIST", + model_type: str = "FCN", + num_layers: int = 5, # For FCN model only + seed: int = 0, + lr_rate: float = 1e-3, + num_epochs: int = 100, + epoch_rewind: int = 3, + weight_decay: float = 5e-4, + patience: int = 30, + verbose: bool = True, + model_dir: str = None, +): + utils.set_random_seeds(seed) + cuda_device = torch.device("cuda:0") + + # The CNN models assume that the dataset is of size (32, 32, 3), so we need to adapt the greyscale dataset to fit + # this size. + if (model_type == "FCN") or (data_type in ["CIFAR-10", "SVHN"]): + is_resize_greyscale = False + else: + is_resize_greyscale = True + + # Flatten dataset if model_type is "FCN". 
+ is_flatten = True if model_type == "FCN" else False + + # Load dataset + train_loader, val_loader, test_loader, classes = utils.prepare_dataloader( + num_workers=8, + train_batch_size=128, + test_batch_size=256, + data_type=data_type, + is_flatten=is_flatten, + is_resize_greyscale=is_resize_greyscale, + train_eval_prop=[0.9, 0.1], + seed=seed, + ) + input_dims = next(iter(train_loader))[0].size()[-1] + + # Initialize model + if model_type == "FCN": + model = FCN( + num_layers=num_layers, input_dims=input_dims, num_classes=len(classes) + ) + elif model_type == "VGG-19": + model = vgg19_bn() + elif model_type == "RESNET-18": + model = ResNet18() + else: + raise NotImplementedError(f"{model_type} is not implemented.") + + # Logging directory and filename + model_filename = f"FCN_{num_layers}_{data_type}.pt" + if model_dir is None: + model_dir = "saved_models" + model_filepath = os.path.join(model_dir, model_filename) + + if not (os.path.exists(model_dir)): + os.makedirs(model_dir) + + if os.path.exists(model_filepath): + print( + "FCN is already trained. To create new pre-trained model, delete the existing model file." + ) + return + + filepath_rewind = os.path.join( + model_dir, f"FCN_{num_layers}_{data_type}_rewind_{epoch_rewind}.pt" + ) + + # Train model + utils.train_model( + model=model, + train_loader=train_loader, + test_loader=val_loader, + device=cuda_device, + optimizer="SGD", + l2_regularization_strength=weight_decay, + learning_rate=lr_rate, + num_epochs=num_epochs, + T_max=num_epochs, + verbose=verbose, + epoch_rewind=epoch_rewind, + filepath_rewind=filepath_rewind, + patience=patience, + ) + + utils.save_model(model=model, model_dir=model_dir, model_filename=model_filename) + # utils.load_model(model, os.path.join(model_dir, model_filename), cuda_device) + + _, eval_accuracy, _ = utils.evaluate_model( + model=model, test_loader=test_loader, device=cuda_device, criterion=None + ) + print(f"Number of layers: {num_layers}/ Test Accuracy: {eval_accuracy}") + + +if __name__ == "__main__": + dataset_types = ["MNIST", "CIFAR_10", "SVHN", "FASHION_MNIST"] + model_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "saved_models") + print(model_dir) + print("Pretraining Models...") + + # Pre-train FCN + model_type = "FCN" + for dataset_type in dataset_types: + print(f"\n==Training {model_type} on {dataset_type}==") + num_epochs = 300 if dataset_type == "CIFAR_10" else 100 + train(data_type=dataset_type, model_type=model_type, num_epochs=num_epochs) + + # Pre-train VGG-19bn + model_type = "VGG-19" + num_epochs = 300 + lr_rate = 5e-3 + for dataset_type in dataset_types: + print(f"Training {model_type} on {dataset_type}..") + num_epochs = 300 if dataset_type == "CIFAR_10" else 100 + train( + data_type=dataset_type, + model_type=model_type, + num_epochs=num_epochs, + lr_rate=lr_rate, + ) diff --git a/custom_utils/utils.py b/custom_utils/utils.py index 2153e81..0a5ca72 100644 --- a/custom_utils/utils.py +++ b/custom_utils/utils.py @@ -6,12 +6,13 @@ import torch import torch.nn as nn import torch.optim as optim +import torch.nn.utils.prune as prune import torchvision import sklearn.metrics import pandas as pd from torchvision import datasets, transforms from torch.utils.tensorboard import SummaryWriter -from collections import defaultdict +from typing import List from models.resnet import ResNet18 @@ -27,24 +28,14 @@ def set_random_seeds(random_seed=0): def prepare_dataloader( num_workers=8, train_batch_size=128, - eval_batch_size=256, - data_type: str = 
"CIFAR_10", + test_batch_size=256, + data_type: str = "MNIST", is_flatten: bool = False, + train_eval_prop: List[float] = [1.0, 0.0], + seed: int = 0, + is_resize_greyscale: bool = False, ): - # train_transform = transforms.Compose([ - # transforms.RandomCrop(32, padding=4), - # transforms.RandomHorizontalFlip(), - # transforms.ToTensor(), - # transforms.Normalize(mean=(0.485, 0.456, 0.406), - # std=(0.229, 0.224, 0.225)) - # ]) - - # test_transform = transforms.Compose([ - # transforms.ToTensor(), - # transforms.Normalize(mean=(0.485, 0.456, 0.406), - # std=(0.229, 0.224, 0.225)) - # ]) - + eval_batch_size = test_batch_size if data_type == "CIFAR_10": train_transform = transforms.Compose( [ @@ -77,28 +68,9 @@ def prepare_dataloader( train_set = torchvision.datasets.CIFAR10( root="data", train=True, download=True, transform=train_transform ) - test_set = torchvision.datasets.CIFAR10( root="data", train=False, download=True, transform=test_transform ) - - train_sampler = torch.utils.data.RandomSampler(train_set) - test_sampler = torch.utils.data.SequentialSampler(test_set) - - train_loader = torch.utils.data.DataLoader( - dataset=train_set, - batch_size=train_batch_size, - sampler=train_sampler, - num_workers=num_workers, - ) - - test_loader = torch.utils.data.DataLoader( - dataset=test_set, - batch_size=eval_batch_size, - sampler=test_sampler, - num_workers=num_workers, - ) - classes = train_set.classes elif data_type == "MNIST": train_transform = transforms.Compose( @@ -109,6 +81,22 @@ def prepare_dataloader( [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))] ) + if is_resize_greyscale: + train_transform = transforms.Compose( + [ + train_transform, + transforms.Resize((32, 32), antialias=True), + transforms.Lambda(lambda x: torch.cat([x] * 3, axis=0)), + ], + ) + test_transform = transforms.Compose( + [ + test_transform, + transforms.Resize((32, 32), antialias=True), + transforms.Lambda(lambda x: torch.cat([x] * 3, axis=0)), + ], + ) + if is_flatten: train_transform = transforms.Compose( [train_transform, transforms.Lambda(torch.flatten)] @@ -123,29 +111,123 @@ def prepare_dataloader( test_set = datasets.MNIST( root="data", train=False, download=True, transform=test_transform ) + classes = train_set.classes + elif data_type == "FASHION_MNIST": + # Define the transformation + transform = transforms.Compose( + [ + transforms.ToTensor(), # Convert PIL image to PyTorch tensor + transforms.Normalize( + (0.5,), (0.5,) + ), # Normalize the tensor with mean and standard deviation + ] + ) + train_transform = test_transform = transform - train_sampler = torch.utils.data.RandomSampler(train_set) - test_sampler = torch.utils.data.SequentialSampler(test_set) + if is_resize_greyscale: + train_transform = transforms.Compose( + [ + train_transform, + transforms.Resize((32, 32), antialias=True), + transforms.Lambda(lambda x: torch.cat([x] * 3, axis=0)), + ], + ) + test_transform = transforms.Compose( + [ + test_transform, + transforms.Resize((32, 32), antialias=True), + transforms.Lambda(lambda x: torch.cat([x] * 3, axis=0)), + ], + ) + if is_flatten: + train_transform = transforms.Compose( + [train_transform, transforms.Lambda(torch.flatten)] + ) + test_transform = transforms.Compose( + [test_transform, transforms.Lambda(torch.flatten)] + ) - train_loader = torch.utils.data.DataLoader( - dataset=train_set, - batch_size=train_batch_size, - sampler=train_sampler, - num_workers=num_workers, + # Load the Fashion MNIST dataset + train_set = torchvision.datasets.FashionMNIST( + 
root="./data", + train=True, + download=True, + transform=train_transform, ) - test_loader = torch.utils.data.DataLoader( - dataset=test_set, - batch_size=eval_batch_size, - sampler=test_sampler, - num_workers=num_workers, + test_set = torchvision.datasets.FashionMNIST( + root="./data", + train=False, + download=True, + transform=test_transform, ) classes = train_set.classes + elif data_type == "SVHN": + + # Define the transformation + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + train_transform = test_transform = transform + + if is_flatten: + train_transform = transforms.Compose( + [train_transform, transforms.Lambda(torch.flatten)] + ) + test_transform = transforms.Compose( + [test_transform, transforms.Lambda(torch.flatten)] + ) + + # Load the SVHN dataset + train_set = torchvision.datasets.SVHN( + root="./data", split="train", download=True, transform=train_transform + ) + + test_set = torchvision.datasets.SVHN( + root="./data", split="test", download=True, transform=test_transform + ) + + # Get the unique labels from the training set + classes = sorted(list(set(train_set.labels.tolist()))) + else: raise NotImplementedError(f"data_type {data_type} is not implemented.") - return train_loader, test_loader, classes + generator = torch.Generator().manual_seed(seed) + train_set, eval_set = torch.utils.data.random_split( + train_set, train_eval_prop, generator + ) + train_sampler = torch.utils.data.RandomSampler(train_set) + eval_sampler = torch.utils.data.SequentialSampler(eval_set) + test_sampler = torch.utils.data.SequentialSampler(test_set) + + train_loader = torch.utils.data.DataLoader( + dataset=train_set, + batch_size=train_batch_size, + sampler=train_sampler, + num_workers=num_workers, + ) + eval_loader = torch.utils.data.DataLoader( + dataset=eval_set, + batch_size=train_batch_size, + sampler=eval_sampler, + num_workers=num_workers, + ) + test_loader = torch.utils.data.DataLoader( + dataset=test_set, + batch_size=eval_batch_size, + sampler=test_sampler, + num_workers=num_workers, + ) + + if train_eval_prop[0] == 1: + return train_loader, test_loader, classes + else: + return train_loader, eval_loader, test_loader, classes def evaluate_model(model, test_loader, device, criterion=None): @@ -207,28 +289,33 @@ def train_model( train_loader, test_loader, device, - optimizer: str = "ADAM", l1_regularization_strength=0, l2_regularization_strength=1e-4, - learning_rate=1e-3, + optimizer="ADAM", + learning_rate=1e-2, num_epochs=200, patience: int = 5, T_max: int = 200, verbose: bool = False, - is_log_dict: bool = False, + epoch_rewind: int = 0, + filepath_rewind: str = "", ): + # The training configurations were not carefully selected. + criterion = nn.CrossEntropyLoss() model.to(device) best_eval_loss = float("inf") - best_train_loss = float("inf") # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10. 
- # optimizer = optim.SGD(model.parameters(), - # lr=learning_rate, - # momentum=0.9, - # weight_decay=l2_regularization_strength) - if optimizer == "ADAM": + if optimizer == "SGD": + optimizer = optim.SGD( + model.parameters(), + lr=learning_rate, + momentum=0.9, + weight_decay=l2_regularization_strength, + ) + elif optimizer == "ADAM": optimizer = optim.Adam( model.parameters(), lr=learning_rate, @@ -237,15 +324,8 @@ def train_model( weight_decay=l2_regularization_strength, amsgrad=False, ) - elif optimizer == "SGD": - optimizer = optim.SGD( - model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=l2_regularization_strength, - ) - # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500) - # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, - # milestones=[100, 150], - # gamma=0.1, - # last_epoch=-1) + else: + raise NotImplementedError(f"Optimizer {optimizer} is not implemented.") scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max) # Evaluation @@ -260,14 +340,16 @@ def train_model( ) ) - if is_log_dict: - log_dict = defaultdict(list) - # Initialize best model best_model_state_dict = model.state_dict() patience_stack = 0 for epoch in range(num_epochs): + # Save parameters for rewinding to inital parameters + if epoch == epoch_rewind: + if filepath_rewind != "": + torch.save(model.state_dict(), filepath_rewind) + # Training model.train() @@ -318,14 +400,10 @@ def train_model( eval_loss, eval_accuracy, _ = evaluate_model( model=model, test_loader=test_loader, device=device, criterion=criterion ) - # if eval_loss < best_eval_loss: - # best_model_state_dict = model.state_dict() - # best_eval_loss = eval_loss - # patience_stack = 0 - # else: - # patience_stack += 1 - if train_loss < best_train_loss: - best_train_loss = train_loss + if eval_loss <= best_eval_loss: + # best_model_state_dict = copy.deepcopy(model.state_dict()) + best_model_state_dict = model.state_dict() + best_eval_loss = eval_loss patience_stack = 0 else: patience_stack += 1 @@ -343,19 +421,8 @@ def train_model( epoch + 1, train_loss, train_accuracy, eval_loss, eval_accuracy ) ) - - if is_log_dict: - log_dict["epochs"].append(int(epoch)) - log_dict["train_loss"].append(float(train_loss)) - log_dict["eval_loss"].append(float(eval_loss)) - log_dict["train_acc"].append(float(train_accuracy)) - log_dict["eval_acc"].append(float(eval_accuracy)) - # model.load_state_dict(best_model_state_dict) - - if is_log_dict: - return model, log_dict - else: - return model + model.load_state_dict(best_model_state_dict) + return model def save_model(model, model_dir, model_filename): @@ -402,3 +469,227 @@ def print_num_params(model): def log_data(data: dict, filepath: str): log_pd = pd.DataFrame(data) log_pd.to_csv(filepath, index=False) + + +""" +Pruning Utils +""" + + +def compute_final_pruning_rate(pruning_rate, num_iterations): + """A function to compute the final pruning rate for iterative pruning. + Note that this cannot be applied for global pruning rate if the pruning rate is heterogeneous among different layers. + Args: + pruning_rate (float): Pruning rate. + num_iterations (int): Number of iterations. + Returns: + float: Final pruning rate. 
+ """ + + final_pruning_rate = 1 - (1 - pruning_rate) ** num_iterations + + return final_pruning_rate + + +def measure_module_sparsity(module, weight=True, bias=False, use_mask=False): + num_zeros = 0 + num_elements = 0 + + if use_mask == True: + for buffer_name, buffer in module.named_buffers(): + if "weight_mask" in buffer_name and weight == True: + num_zeros += torch.sum(buffer == 0).item() + num_elements += buffer.nelement() + if "bias_mask" in buffer_name and bias == True: + num_zeros += torch.sum(buffer == 0).item() + num_elements += buffer.nelement() + else: + for param_name, param in module.named_parameters(): + if "weight" in param_name and weight == True: + num_zeros += torch.sum(param == 0).item() + num_elements += param.nelement() + if "bias" in param_name and bias == True: + num_zeros += torch.sum(param == 0).item() + num_elements += param.nelement() + + return num_zeros, num_elements, 0 + + +def measure_global_sparsity( + model, weight=True, bias=False, conv2d_use_mask=False, linear_use_mask=False +): + num_zeros = 0 + num_elements = 0 + + for module_name, module in model.named_modules(): + if isinstance(module, torch.nn.Conv2d): + module_num_zeros, module_num_elements, _ = measure_module_sparsity( + module, weight=weight, bias=bias, use_mask=conv2d_use_mask + ) + num_zeros += module_num_zeros + num_elements += module_num_elements + + elif isinstance(module, torch.nn.Linear): + module_num_zeros, module_num_elements, _ = measure_module_sparsity( + module, weight=weight, bias=bias, use_mask=linear_use_mask + ) + num_zeros += module_num_zeros + num_elements += module_num_elements + + sparsity = num_zeros / num_elements + + return num_zeros, num_elements, sparsity + + +def iterative_pruning_finetuning( + model, + train_loader, + test_loader, + device, + learning_rate, + l1_regularization_strength, + l2_regularization_strength, + learning_rate_decay=0.1, + conv2d_prune_amount=0.4, + linear_prune_amount=0.2, + num_iterations=10, + num_epochs_per_iteration=10, + model_filename_prefix="pruned_model", + model_dir="saved_models", + grouped_pruning=False, + is_structured_pruning: bool = False, + structured_dims: int = 0, + is_stop_same_acc: bool = True, +): + _, unpruned_eval_acc, _ = evaluate_model( + model=model, test_loader=test_loader, device=device, criterion=None + ) + TOLERANCE = 0.01 # Tolerance for early stopping pruning + print(f"Unpruned evaluation acc is {unpruned_eval_acc}.") + + pruned_accuracies = [] + best_eval_accuracy = 0 + + for i in range(num_iterations): + print("=" * 20) + print("Pruning and Finetuning {}/{}".format(i + 1, num_iterations)) + + if grouped_pruning == True: + # Global pruning + # I would rather call it grouped pruning. 
+ parameters_to_prune = [] + for module_name, module in model.named_modules(): + if isinstance(module, torch.nn.Conv2d): + parameters_to_prune.append((module, "weight")) + prune.global_unstructured( + parameters_to_prune, + pruning_method=prune.L1Unstructured, + amount=conv2d_prune_amount, + ) + else: + for module_name, module in model.named_modules(): + if isinstance(module, torch.nn.Conv2d): + if is_structured_pruning: + prune.ln_structured( + module, + name="weight", + amount=conv2d_prune_amount, + n=1, + dim=structured_dims, + ) + else: + prune.l1_unstructured( + module, name="weight", amount=conv2d_prune_amount + ) + elif isinstance(module, torch.nn.Linear): + prune.l1_unstructured( + module, name="weight", amount=linear_prune_amount + ) + + _, eval_accuracy, _ = evaluate_model( + model=model, test_loader=test_loader, device=device, criterion=None + ) + + # classification_report = create_classification_report( + # model=model, test_loader=test_loader, device=device) + + num_zeros, num_elements, sparsity = measure_global_sparsity( + model, weight=True, bias=False, conv2d_use_mask=True, linear_use_mask=False + ) + + print(f"Global Sparsity: {sparsity}") + print("Conv2d Sparsity: ", 1 - (1 - conv2d_prune_amount) ** (i + 1)) + print("Test Accuracy: {:.3f}".format(eval_accuracy)) + + # print(model.conv1._forward_pre_hooks) + + print("\nFine-tuning...") + + train_model( + model=model, + train_loader=train_loader, + test_loader=test_loader, + device=device, + l1_regularization_strength=l1_regularization_strength, + l2_regularization_strength=l2_regularization_strength, + learning_rate=learning_rate * (learning_rate_decay**i), + num_epochs=num_epochs_per_iteration, + ) + + _, eval_accuracy, _ = evaluate_model( + model=model, test_loader=test_loader, device=device, criterion=None + ) + + pruned_accuracies.append(eval_accuracy.cpu()) + + # classification_report = create_classification_report( + # model=model, test_loader=test_loader, device=device) + + num_zeros, num_elements, sparsity = measure_global_sparsity( + model, weight=True, bias=False, conv2d_use_mask=True, linear_use_mask=False + ) + + print("Test Accuracy: {:.3f}".format(eval_accuracy)) + + if eval_accuracy > best_eval_accuracy: + best_eval_accuracy = eval_accuracy + + # model_filename = "{}_{}.pt".format(model_filename_prefix, i + 1) + # model_filepath = os.path.join(model_dir, model_filename) + # save_model(model=model, + # model_dir=model_dir, + # model_filename=model_filename) + # model = load_model(model=model, + # model_filepath=model_filepath, + # device=device) + if is_stop_same_acc: + if unpruned_eval_acc - eval_accuracy > TOLERANCE: + print("Stopping Pruning as it exceeds the tolerance.") + return model, pruned_accuracies + print("=" * 20) + + return model, pruned_accuracies + + +def remove_parameters(model): + for module_name, module in model.named_modules(): + if isinstance(module, torch.nn.Conv2d): + try: + prune.remove(module, "weight") + except: + pass + try: + prune.remove(module, "bias") + except: + pass + elif isinstance(module, torch.nn.Linear): + try: + prune.remove(module, "weight") + except: + pass + try: + prune.remove(module, "bias") + except: + pass + + return model diff --git a/eoc/exp_trainability.py b/eoc/exp_trainability.py index 7f2c3f2..041481b 100644 --- a/eoc/exp_trainability.py +++ b/eoc/exp_trainability.py @@ -146,12 +146,12 @@ def exp_trainability(args: argparse.Namespace = None) -> None: # Logging the results orig_log_dir = os.path.join("logs_3d", "run_") - log_dir = orig_log_dir + log_dir_3d = 
orig_log_dir idx = 0 - while os.path.exists(log_dir): - log_dir = orig_log_dir + str(idx) + while os.path.exists(log_dir_3d): + log_dir_3d = orig_log_dir + str(idx) idx += 1 - os.makedirs(log_dir, exist_ok=True) + os.makedirs(log_dir_3d, exist_ok=True) for q_idx in range(len(q_stars)): q_star = q_stars[q_idx] @@ -251,26 +251,7 @@ def exp_trainability(args: argparse.Namespace = None) -> None: utils.log_data(log_dict, filepath) if is_plot: - fig = plt.figure(figsize=plt.figaspect(1.0)) - ax = plt.axes(projection='3d') - - surf = ax.plot_surface(sw_grid, sb_grid, eval_acc_grid,) - # rstride=1, cstride=1, cmap=cm.coolwarm, - # linewidth=0, antialiased=False) - fig.colorbar(surf, shrink=0.5, aspect=10) - ax.set_xlabel("sw") - ax.set_ylabel("sb") - ax.set_zlabel("eval acc") - - # EOC curve - num_taus = sw_grid.shape[0] - eoc_idx = (num_taus - 1) / 2 - eoc_idx = int(eoc_idx) - eoc_sw_list = sw_grid[eoc_idx, :] - eoc_sb_list = sb_grid[eoc_idx, :] - eoc_eval_acc_list = eval_acc_grid[eoc_idx, :] - ax.plot(eoc_sw_list, eoc_sb_list, eoc_eval_acc_list, label='EOC') - plt.show() + fig = plot_3d(log_dir_3d) # logging in Wandb wandb.init( @@ -311,21 +292,28 @@ def exp_trainability(args: argparse.Namespace = None) -> None: with open(graph_log_path, 'w+') as f: json.dump(graph_log_dict, f) -def plot_3d(filepath:str): +def plot_3d(filepath:str, is_flip_xaxis: bool=False): with open(filepath, 'r') as f: loaded_dict = json.load(f) sw_grid = np.array(loaded_dict["sw_grid"]) sb_grid = np.array(loaded_dict["sb_grid"]) - train_acc_grid = np.array(loaded_dict["train_acc_grid"]) + train_acc_grid = np.array(loaded_dict["eval_acc_grid"]) + + # Flip x-axis + if is_flip_xaxis: + sw_grid = np.flip(sw_grid, axis=0) + train_acc_grid = np.flip(train_acc_grid, axis=0) + fig = plt.figure(figsize=(10, 8)) ax = plt.axes(projection='3d') + ax.invert_xaxis() surf = ax.plot_surface(sw_grid, sb_grid, train_acc_grid, rstride=1, cstride=1, cmap=cm.coolwarm, - linewidth=0, antialiased=False) + linewidth=0, antialiased=False, alpha=0.6) fig.colorbar(surf, shrink=0.5, aspect=10) ax.set_xlabel("sw") ax.set_ylabel("sb") - ax.set_zlabel("train acc") + ax.set_zlabel("eval acc") # EOC curve num_taus = sw_grid.shape[0] @@ -333,9 +321,11 @@ def plot_3d(filepath:str): eoc_sw_list = sw_grid[eoc_idx, :] eoc_sb_list = sb_grid[eoc_idx, :] eoc_eval_acc_list = train_acc_grid[eoc_idx, :] - ax.plot(eoc_sw_list, eoc_sb_list, eoc_eval_acc_list, label='EOC') + ax.plot(eoc_sw_list, eoc_sb_list, eoc_eval_acc_list, label='EOC', linewidth=5, alpha=1.0, color='black') plt.show() + return fig + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -436,6 +426,7 @@ def plot_3d(filepath:str): ) args = parser.parse_args() assert args.num_taus % 2 != 0, f"{args.num_taus} should be odd to include tau_per = 1 which is EOC case" + exp_trainability(args) - # filepath = os.path.join("logs_3d", "run_", "3d_graph_log.json") - # plot_3d(filepath) \ No newline at end of file + filepath = os.path.join("logs_3d", "run_0", "3d_graph_log.json") + plot_3d(filepath, is_flip_xaxis=False) \ No newline at end of file diff --git a/exp_block_pruning.py b/exp_block_pruning.py deleted file mode 100644 index e69de29..0000000 diff --git a/exps/exp_block_pruning.py b/exps/exp_block_pruning.py new file mode 100644 index 0000000..acc1b68 --- /dev/null +++ b/exps/exp_block_pruning.py @@ -0,0 +1,88 @@ +""" +Train Models +""" + +import torch +import torch.nn as nn +import os +from typing import List + +import custom_utils.utils as utils + + +def train( + data_type: str = "MNIST", + 
num_layers_list: List[int] = [5], + seed: int = 0, + lr_rate: float = 1e-3, + num_epochs: int = 100, + epoch_rewind: int =3, + weight_decay: float = 5e-4, + patience: int = 30, +): + for num_layers in num_layers_list: + utils.set_random_seeds(seed) + cuda_device = torch.device("cuda:0") + + # Load dataset + train_loader, test_loader, classes = utils.prepare_dataloader( + num_workers=8, + train_batch_size=128, + test_batch_size=256, + data_type=data_type, + is_flatten=True, + ) + input_dims = next(iter(train_loader))[0].size()[-1] + + # Initialize model + model = utils.FCN( + num_layers=num_layers, input_dims=input_dims, num_classes=len(classes) + ) + model_filename = f"FCN_{num_layers}_{data_type}.pt" + model_dir = "saved_models" + model_filepath = os.path.join(model_dir, model_filename) + if not (os.path.exists(model_dir)): + os.makedirs(model_dir) + + if os.path.exists(model_filepath): + print( + "FCN is already trained. To create new pre-trained model, delete the existing model file." + ) + return + + filepath_rewind = os.path.join( + model_dir, f"FCN_{num_layers}_{data_type}_rewind_{epoch_rewind}.pt" + ) + utils.train_model( + model=model, + train_loader=train_loader, + test_loader=test_loader, + device=cuda_device, + optimizer="SGD", + l2_regularization_strength=weight_decay, + learning_rate=lr_rate, + num_epochs=num_epochs, + T_max=num_epochs, + verbose=True, + epoch_rewind=epoch_rewind, + filepath_rewind=filepath_rewind, + patience=patience, + ) + + utils.save_model(model=model, model_dir=model_dir, model_filename=model_filename) + # utils.load_model(model, os.path.join(model_dir, model_filename), cuda_device) + + _, eval_accuracy, _ = utils.evaluate_model( + model=model, test_loader=test_loader, device=cuda_device, criterion=None + ) + print(f"Number of layers: {num_layers}/ Test Accuracy: {eval_accuracy}") + + +# print("==SVHN==") +# train("SVHN") +# print("==FASHION MNIST") +# train("FASHION_MNIST") +# print("==CIFAR_10==") +# train("CIFAR_10") +# print("==MNIST==") +# train("MNIST") diff --git a/logs_3d/run_0/3d_graph_log.json b/logs_3d/run_0/3d_graph_log.json new file mode 100644 index 0000000..2940723 --- /dev/null +++ b/logs_3d/run_0/3d_graph_log.json @@ -0,0 +1 @@ +{"sw_grid": [[0.37238004461059915, 0.7313088689473617, 0.9617137147938714, 1.146396057626332, 1.3051079027486134, 1.446470424466619, 1.5751780883960629, 1.6941248031613738, 1.8052445965419424, 1.9099060371695256], [0.46879870100532883, 0.9206633189882839, 1.2107258344249727, 1.4432271289264396, 1.64303350390362, 1.820998374769722, 1.9830317235831625, 2.132776765450645, 2.27266829709076, 2.4044292443318107], [0.5901825977117042, 1.15904644798111, 1.5242135196732955, 1.8169153075962317, 2.068456630493478, 2.292501128833482, 2.49648902921283, 2.685006867709985, 2.8611198717865336, 3.026997076550363], [0.7429958698579588, 1.4591530268131863, 1.918871132916941, 2.2873608518107122, 2.604032615322378, 2.886087927655285, 3.1428934791413834, 3.3802233766019607, 3.6019365127813376, 3.8107635410959073], [0.9353763814222961, 1.8369648251514956, 2.4157156311873997, 2.8796167022876578, 3.2782828325701665, 3.6333694327976036, 3.956668467453268, 4.255449106344893, 4.534569407609691, 4.797467060224363], [1.1775690961646021, 2.312601698955247, 3.0412057957702037, 3.6252226427368512, 4.127113564971146, 4.574141109385063, 4.981150479721602, 5.3572930185743655, 5.7086846587842635, 6.039653194361909], [1.4824716593051075, 2.9113930461731257, 3.8286512587908126, 4.563884908353031, 5.195728144302595, 5.758502479737221, 6.270896918892233, 
6.744432319510752, 7.186808184863025, 7.603472884806026], [1.8663212441638646, 3.665226689548177, 4.81998636258948, 5.745590687629524, 6.541034193636688, 7.249525105620913, 7.894591485935179, 8.490737235157964, 9.047655453614375, 9.57220523257021], [2.3495592408491954, 4.614247019458411, 6.068003316365231, 7.23327012242473, 8.234674165783794, 9.12661137890596, 9.93870183737767, 10.689204870207305, 11.390323367713219, 12.050692414191733], [2.957919834820864, 5.808992829091688, 7.6391635738234696, 9.106147567491998, 10.366840565150154, 11.489722988474517, 12.51208430332013, 13.45691164297795, 14.339567536167088, 15.170922909941172], [3.7238004461059915, 7.313088689473617, 9.617137147938715, 11.46396057626332, 13.051079027486134, 14.464704244666192, 15.75178088396063, 16.94124803161374, 18.052445965419423, 19.099060371695256]], "sb_grid": [[0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609], [0.0008322804246838761, 0.21122100085919293, 0.6441674382842419, 1.1894344153733893, 1.8048055921439317, 2.4691943250043717, 3.1703060319994543, 3.900240535624925, 4.653569347593386, 5.426374311609609]], "train_acc_grid": [[0.11236666887998581, 0.10975000262260437, 0.10441666841506958, 0.1076333299279213, 0.10436666756868362, 0.10533333569765091, 0.1050499975681305, 0.10525000095367432, 0.10281666368246078, 0.10396666824817657], [0.11236666887998581, 0.9507666826248169, 0.9422666430473328, 0.9172999858856201, 0.9209166765213013, 0.5378000140190125, 0.8891833424568176, 0.11046666651964188, 0.266566663980484, 0.10270000249147415], [0.9351500272750854, 
0.9732666611671448, 0.9726999998092651, 0.9722333550453186, 0.9659333229064941, 0.9567333459854126, 0.9626500010490417, 0.9483500123023987, 0.9518666863441467, 0.9566666483879089], [0.9710833430290222, 0.9828833341598511, 0.979449987411499, 0.9776999950408936, 0.9744333624839783, 0.9699000120162964, 0.9702333211898804, 0.9652166962623596, 0.965233325958252, 0.9632999897003174], [0.9794999957084656, 0.9875333309173584, 0.9831166863441467, 0.9807000160217285, 0.9804166555404663, 0.9732000231742859, 0.9734333157539368, 0.9698833227157593, 0.969083309173584, 0.9645500183105469], [0.98580002784729, 0.9929333329200745, 0.9882833361625671, 0.982283353805542, 0.9784500002861023, 0.9754833579063416, 0.9695333242416382, 0.9674166440963745, 0.9670666456222534, 0.9612666964530945], [0.9905166625976562, 0.9926666617393494, 0.9873833060264587, 0.9801333546638489, 0.9751499891281128, 0.9701833128929138, 0.9639666676521301, 0.964033305644989, 0.958133339881897, 0.9557333588600159], [0.994700014591217, 0.9896666407585144, 0.9816333055496216, 0.9700000286102295, 0.9652833342552185, 0.9578666687011719, 0.9532333612442017, 0.9473166465759277, 0.9424833059310913, 0.9398166537284851], [0.9961833357810974, 0.9798166751861572, 0.9671333432197571, 0.9542833566665649, 0.9469333291053772, 0.942883312702179, 0.9356833100318909, 0.9293166399002075, 0.9283166527748108, 0.9220166802406311], [0.9954500198364258, 0.9640499949455261, 0.9475833177566528, 0.9365166425704956, 0.9292333126068115, 0.914816677570343, 0.900433361530304, 0.8905500173568726, 0.8855166435241699, 0.868066668510437], [0.9877833127975464, 0.9453666806221008, 0.9196000099182129, 0.8787999749183655, 0.8644333481788635, 0.8256666660308838, 0.7794166803359985, 0.10365000367164612, 0.10459999740123749, 0.7007166743278503]], "eval_acc_grid": [[0.11349999904632568, 0.11349999904632568, 0.09799999743700027, 0.11349999904632568, 0.11349999904632568, 0.10320000350475311, 0.10279999673366547, 0.11349999904632568, 0.11349999904632568, 0.09740000218153], [0.11349999904632568, 0.9477999806404114, 0.9401999711990356, 0.926800012588501, 0.9223999977111816, 0.6901000142097473, 0.9034000039100647, 0.11349999904632568, 0.374099999666214, 0.11349999904632568], [0.9369000196456909, 0.9641000032424927, 0.9642999768257141, 0.9627000093460083, 0.9553999900817871, 0.9417999982833862, 0.9555000066757202, 0.9440000057220459, 0.9478999972343445, 0.9440000057220459], [0.9621999859809875, 0.964900016784668, 0.960099995136261, 0.9585999846458435, 0.9578999876976013, 0.9570000171661377, 0.9498000144958496, 0.9510999917984009, 0.9447000026702881, 0.9501000046730042], [0.9660000205039978, 0.9656000137329102, 0.9610999822616577, 0.9545999765396118, 0.9545000195503235, 0.9476000070571899, 0.9501000046730042, 0.9413999915122986, 0.9484000205993652, 0.9409999847412109], [0.9721999764442444, 0.9614999890327454, 0.9495999813079834, 0.9501000046730042, 0.9467999935150146, 0.9415000081062317, 0.9398999810218811, 0.9380999803543091, 0.9375, 0.9398999810218811], [0.9672999978065491, 0.9505000114440918, 0.9444000124931335, 0.9358000159263611, 0.9301000237464905, 0.9302999973297119, 0.925000011920929, 0.9259999990463257, 0.9236000180244446, 0.9243000149726868], [0.9642000198364258, 0.9398000240325928, 0.9348999857902527, 0.9284999966621399, 0.9229999780654907, 0.9169999957084656, 0.911899983882904, 0.9147999882698059, 0.9070000052452087, 0.9117000102996826], [0.9545000195503235, 0.9365000128746033, 0.9277999997138977, 0.9205999970436096, 0.9150999784469604, 0.9111999869346619, 0.9140999913215637, 
0.9106000065803528, 0.9067999720573425, 0.9014999866485596], [0.9453999996185303, 0.9318000078201294, 0.9208999872207642, 0.9212999939918518, 0.9180999994277954, 0.9057000279426575, 0.8984000086784363, 0.8848999738693237, 0.883899986743927, 0.8604999780654907], [0.9444000124931335, 0.9289000034332275, 0.9115999937057495, 0.8726000189781189, 0.839900016784668, 0.8313000202178955, 0.7889000177383423, 0.11349999904632568, 0.0957999974489212, 0.72079998254776]]} \ No newline at end of file diff --git a/logs_3d/run_0/params.json b/logs_3d/run_0/params.json new file mode 100644 index 0000000..699fd53 --- /dev/null +++ b/logs_3d/run_0/params.json @@ -0,0 +1 @@ +{"no_cuda": false, "num_exps": 1, "seed": 0, "model": "FCN", "block_size": null, "data_type": "MNIST", "act_func": "TANH", "num_taus": 11, "tau_range": [-0.5, 0.5], "qstar_range": [0.1, 10.0], "num_qstars": 10, "depth": 20, "width": 300, "batch_size": 128, "epochs": 7, "optimizer": "SGD", "lr": 0.001, "weight_decay": 0, "patience": 20, "debug": false, "weight_var": 6.039653194361909, "bias_var": 5.426374311609609} \ No newline at end of file diff --git a/models/fcn.py b/models/fcn.py index b60b2b0..7e894d1 100644 --- a/models/fcn.py +++ b/models/fcn.py @@ -26,32 +26,15 @@ def __init__( else: assert len(hidden_dims) == num_layers - 1 layer_dims = [input_dims] + hidden_dims + [num_classes] - print(layer_dims) + print("FCN architecture: ", layer_dims) modules = OrderedDict() for idx in range(0, num_layers - 1): modules[f"fc{idx}"] = nn.Linear(layer_dims[idx], layer_dims[idx + 1]) modules[f"act{idx}"] = self.act_func() modules[f"fc{num_layers - 1}"] = nn.Linear(layer_dims[num_layers - 1], num_classes) - # modules["final_softmax"] = nn.Softmax() - self.fcn = nn.Sequential(modules) - # - # modules = [] - # - # # Initial Layer - # modules.append(nn.Linear(input_dims, hidden_dims[0])) - # modules.append(self.act_func()) - # - # # Intermediate Layers - # for idx in range(1, num_layers - 1): - # modules.append(nn.Linear(hidden_dims[idx], hidden_dims[idx])) - # modules.append(self.act_func()) - # # Final Layer - # modules.append(nn.Linear(hidden_dims[num_layers-2], num_classes)) - # modules.append(nn.Softmax()) - # - # self.fcn = nn.Sequential(*modules) + self.fcn = nn.Sequential(modules) def forward(self, x): return self.fcn(x) diff --git a/models/resnet.py b/models/resnet.py index 936a64a..d21bff0 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -5,6 +5,8 @@ Reference: [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 + +Copied from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py ''' import torch import torch.nn as nn diff --git a/models/vggnet.py b/models/vggnet.py new file mode 100644 index 0000000..db1ba8e --- /dev/null +++ b/models/vggnet.py @@ -0,0 +1,112 @@ +''' +Modified from https://github.com/pytorch/vision.git +Copied from https://github.com/chengyangfu/pytorch-vgg-cifar10/blob/master/vgg.py + +VGG net on Cifar-10 dataset +''' +import math + +import torch.nn as nn +import torch.nn.init as init + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +class VGG(nn.Module): + ''' + VGG model + ''' + def __init__(self, features): + super(VGG, self).__init__() + self.features = features + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(512, 512), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(512, 512), + nn.ReLU(True), + nn.Linear(512, 10), + ) + # Initialize weights + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + m.bias.data.zero_() + + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +cfg = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', + 512, 512, 512, 512, 'M'], +} + + +def vgg11(): + """VGG 11-layer model (configuration "A")""" + return VGG(make_layers(cfg['A'])) + + +def vgg11_bn(): + """VGG 11-layer model (configuration "A") with batch normalization""" + return VGG(make_layers(cfg['A'], batch_norm=True)) + + +def vgg13(): + """VGG 13-layer model (configuration "B")""" + return VGG(make_layers(cfg['B'])) + + +def vgg13_bn(): + """VGG 13-layer model (configuration "B") with batch normalization""" + return VGG(make_layers(cfg['B'], batch_norm=True)) + + +def vgg16(): + """VGG 16-layer model (configuration "D")""" + return VGG(make_layers(cfg['D'])) + + +def vgg16_bn(): + """VGG 16-layer model (configuration "D") with batch normalization""" + return VGG(make_layers(cfg['D'], batch_norm=True)) + + +def vgg19(): + """VGG 19-layer model (configuration "E")""" + return VGG(make_layers(cfg['E'])) + + +def vgg19_bn(): + """VGG 19-layer model (configuration 'E') with batch normalization""" + return VGG(make_layers(cfg['E'], batch_norm=True)) +
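
Note (not part of the patch series): a minimal usage sketch of the pieces introduced in this commit, assuming the repository layout shown above (models/vggnet.py, custom_utils/pruning/diag_pruning.py, custom_utils/utils.py). It builds the VGG-19-BN model, applies the diagonal-block pruning helpers to its conv and linear layers, and reports the resulting global sparsity; the block_size value is a hypothetical choice, not one prescribed by the patch.

import torch
import torch.nn as nn

from models.vggnet import vgg19_bn
from custom_utils.pruning.diag_pruning import diag_pruning_conv2d, diag_pruning_linear
from custom_utils.utils import measure_global_sparsity

model = vgg19_bn()

# Apply a block-diagonal mask with a random channel permutation to every
# conv and linear layer (block_size=8 is an arbitrary illustrative value).
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        diag_pruning_conv2d(module, block_size=8, perm_type="RANDOM")
    elif isinstance(module, nn.Linear):
        diag_pruning_linear(module, block_size=8, perm_type="RANDOM")

# measure_global_sparsity reads the weight_mask buffers that
# prune.custom_from_mask creates inside the pruning helpers.
num_zeros, num_elements, sparsity = measure_global_sparsity(
    model, weight=True, bias=False, conv2d_use_mask=True, linear_use_mask=True
)
print(f"Global sparsity after diagonal-block pruning: {sparsity:.3f}")

# The pruned model still runs on CIFAR-10-sized inputs (3x32x32).
x = torch.rand(4, 3, 32, 32)
print(model(x).shape)  # torch.Size([4, 10])

One behavior worth noting: when the channel dimension of a layer (or of a residual slice) is smaller than block_size, diag_pruning finds zero complete blocks and fills that slice with ones, so such layers (e.g. the first 3-channel conv) end up effectively unpruned; the reported global sparsity reflects this.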