10 changes: 10 additions & 0 deletions custom_utils/exp_results.py
@@ -0,0 +1,10 @@
import numpy as np
q_stars = np.linspace(0.1, 10.0, 15)
tau_pers = np.logspace(-1, 1, 21)
# Keep only the first two q_star values; `accuracy` below has one row per
# kept q_star and one column per tau_per.
q_stars = q_stars[:2]

accuracy = [
    [0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.5165, 0.8742,
     0.96355, 0.9647, 0.9657, 0.9599, 0.94795, 0.9345, 0.9183, 0.9201, 0.9205,
     0.9158, 0.89535, 0.87705],
    [0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.1135, 0.57705, 0.89775, 0.9637,
     0.9633, 0.9589, 0.9506, 0.93465, 0.9174, 0.87675, 0.90865, 0.91655,
     0.8647, 0.81115, 0.87085, 0.6902],
]
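
# A minimal sketch of how these results can be read back (assumption: row i of
# `accuracy` corresponds to q_stars[i] and column j to tau_pers[j], matching
# the 2 x 21 shape above).
if __name__ == "__main__":
    acc = np.asarray(accuracy)
    assert acc.shape == (len(q_stars), len(tau_pers))
    best = np.unravel_index(acc.argmax(), acc.shape)
    print(
        f"best accuracy {acc[best]:.4f} at "
        f"q_star={q_stars[best[0]]:.2f}, tau_per={tau_pers[best[1]]:.3f}"
    )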
129 changes: 102 additions & 27 deletions custom_utils/pruning/diag_pruning.py
@@ -1,29 +1,35 @@
"""
Diagonal Block Pruning
"""
# !pip install betterspy

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np
import matplotlib.pyplot as plt
import betterspy

from scipy import sparse

from custom_utils.constants import LINEAR, CONV2D

def diag_pruning(weight: torch.Tensor, mask: torch.Tensor, block_size: int = 4):
    num_rows, num_cols = mask.shape

    # Compute the number of complete blocks and remaining diagonal size
    num_blocks = min(num_rows // block_size, num_cols // block_size)
    if num_blocks == 0:
        # Matrix smaller than a single block: keep everything.
        mask[:, :] = torch.ones_like(mask)
        return torch.ones_like(mask)

    if num_cols > num_rows:
        residual_diagonal = num_rows % block_size
    else:
        residual_diagonal = num_cols % block_size

    # Create a mask to zero out non-block diagonal elements
    for i in range(num_blocks):
        start_row = i * block_size
        end_row = start_row + block_size
@@ -39,27 +45,96 @@ def diag_pruning_linear
        end_col = start_col + residual_diagonal
        mask[start_row:end_row, start_col:end_col] = 1

    return mask
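
# Example (sketch): for an 8x8 mask with block_size=4, diag_pruning keeps two
# 4x4 blocks on the main diagonal:
#   m = torch.zeros(8, 8)
#   diag_pruning(m, m, block_size=4)
#   # m[0:4, 0:4] and m[4:8, 4:8] are now ones; everything else stays zero.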


def diag_pruning_linear(
    module: nn.Linear,
    block_size: int = 4,
    perm_type: str = None,
    col_perm: torch.Tensor = None,
    row_perm: torch.Tensor = None,
):
    assert isinstance(module, nn.Linear)
    mask = torch.zeros_like(module.weight)
    num_rows, num_cols = module.weight.shape
    max_size = max(num_rows, num_cols)
    min_size = min(num_rows, num_cols)

    # Tile diagonal blocks along the longer dimension so that rectangular
    # weights are covered by repeated block-diagonal stripes.
    num_reps = max_size // min_size + 1
    for j in range(num_reps):
        if num_rows > num_cols:
            diag_pruning(module.weight, mask[j * min_size :, :], block_size=block_size)
        else:
            diag_pruning(module.weight, mask[:, j * min_size :], block_size=block_size)

    if perm_type == "RANDOM":
        col_perm = torch.randperm(num_cols)
        row_perm = torch.randperm(num_rows)
        mask = mask[:, col_perm]
        mask = mask[row_perm, :]
    elif perm_type == "CUSTOM":
        mask = mask[:, col_perm]
        mask = mask[row_perm, :]

    prune.custom_from_mask(module, "weight", mask)


def diag_pruning_conv2d(
    module: nn.Conv2d,
    block_size: int = 4,
    perm_type: str = None,
):
    assert isinstance(module, nn.Conv2d)
    # Conv2d weights have shape (out_channels, in_channels, kH, kW); prune the
    # (out, in) channel matrix like a 2-D weight.
    in_out_channels = module.weight[:, :, 0, 0]
    num_rows, num_cols = in_out_channels.shape
    max_size = max(num_rows, num_cols)
    min_size = min(num_rows, num_cols)
    num_reps = max_size // min_size + 1
    mask = torch.zeros_like(in_out_channels)
    for j in range(num_reps):
        if num_rows > num_cols:
            diag_pruning(
                in_out_channels, mask[j * min_size :, :], block_size=block_size
            )
        else:
            diag_pruning(
                in_out_channels, mask[:, j * min_size :], block_size=block_size
            )
    if perm_type == "RANDOM":
        col_perm = torch.randperm(num_cols)
        row_perm = torch.randperm(num_rows)
        mask = mask[:, col_perm]
        mask = mask[row_perm, :]

    # Lift the 2-D channel mask to the 4-D weight: indexing with the (row, col)
    # index tuple broadcasts over the kernel dimensions, so whole kernels are
    # kept or pruned together.
    conv2d_mask = torch.zeros_like(module.weight)
    non_zero_idcs = torch.nonzero(mask, as_tuple=True)
    conv2d_mask[non_zero_idcs] = 1

    prune.custom_from_mask(module, "weight", conv2d_mask)


def exp():
    layer1 = nn.Linear(37, 91)
    diag_pruning_linear(layer1, block_size=10, perm_type="RANDOM")

    # Show and save the sparsity pattern
    sparse_matrix = sparse.csr_matrix(layer1.weight.detach().numpy())
    betterspy.show(sparse_matrix)

    layer1 = nn.Linear(37, 91)
    diag_pruning_linear(layer1, block_size=10, perm_type="")

    # Show and save the sparsity pattern
    sparse_matrix = sparse.csr_matrix(layer1.weight.detach().numpy())
    betterspy.show(sparse_matrix)
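
    # Quick sanity check (sketch): prune.custom_from_mask registers a
    # `weight_mask` buffer, so the surviving fraction of weights can be read
    # off directly and compared against the block-diagonal budget.
    kept = layer1.weight_mask.mean().item()
    print(f"Fraction of weights kept: {kept:.3f}")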


if __name__ == "__main__":
    exp()

    # Scratch check: indexing a 3-D tensor with a 2-D index tuple broadcasts
    # the assignment over the trailing dimension, mirroring how the conv2d
    # mask is lifted to four dimensions.
    a = torch.rand((2, 2, 2))
    b = torch.rand((2, 2))
    idcs = torch.nonzero(b < 0.7, as_tuple=True)
    a[idcs] = 0

    # Prune a small conv layer; block_size=15 exceeds the 10x10 channel
    # matrix, which exercises the keep-everything fallback in diag_pruning.
    a = torch.rand((10, 100, 100))
    c = nn.Conv2d(10, 10, kernel_size=3)
    b = c(a)
    diag_pruning_conv2d(c, block_size=15)
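
    # Sanity sketch: each k x k kernel should be entirely kept or entirely
    # zeroed, because the 2-D channel mask is broadcast over the kernel dims.
    frac = c.weight_mask.flatten(2).mean(dim=2)  # fraction kept per kernel
    assert bool(((frac == 0) | (frac == 1)).all())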
137 changes: 137 additions & 0 deletions custom_utils/train_models.py
@@ -0,0 +1,137 @@
"""
Train Models
"""

import torch
import torch.nn as nn
import os

import custom_utils.utils as utils
from models.fcn import FCN
from models.resnet import ResNet18
from models.vggnet import vgg19_bn


def train(
    data_type: str = "MNIST",
    model_type: str = "FCN",
    num_layers: int = 5,  # For FCN model only
    seed: int = 0,
    lr_rate: float = 1e-3,
    num_epochs: int = 100,
    epoch_rewind: int = 3,
    weight_decay: float = 5e-4,
    patience: int = 30,
    verbose: bool = True,
    model_dir: str = None,
):
    utils.set_random_seeds(seed)
    cuda_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # The CNN models assume that the dataset is of size (32, 32, 3), so we need
    # to adapt the greyscale datasets to fit this size.
    if (model_type == "FCN") or (data_type in ["CIFAR_10", "SVHN"]):
        is_resize_greyscale = False
    else:
        is_resize_greyscale = True

    # Flatten the inputs when the model is an FCN.
    is_flatten = model_type == "FCN"

    # Load dataset
    train_loader, val_loader, test_loader, classes = utils.prepare_dataloader(
        num_workers=8,
        train_batch_size=128,
        test_batch_size=256,
        data_type=data_type,
        is_flatten=is_flatten,
        is_resize_greyscale=is_resize_greyscale,
        train_eval_prop=[0.9, 0.1],
        seed=seed,
    )
    # Feature dimension of one batch (the flattened input size for FCN models).
    input_dims = next(iter(train_loader))[0].size()[-1]

    # Initialize model
    if model_type == "FCN":
        model = FCN(
            num_layers=num_layers, input_dims=input_dims, num_classes=len(classes)
        )
    elif model_type == "VGG-19":
        model = vgg19_bn()
    elif model_type == "RESNET-18":
        model = ResNet18()
    else:
        raise NotImplementedError(f"{model_type} is not implemented.")

    # Logging directory and filename
    model_filename = f"{model_type}_{num_layers}_{data_type}.pt"
    if model_dir is None:
        model_dir = "saved_models"
    model_filepath = os.path.join(model_dir, model_filename)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if os.path.exists(model_filepath):
        print(
            f"{model_type} is already trained. To create a new pre-trained "
            "model, delete the existing model file."
        )
        return

    filepath_rewind = os.path.join(
        model_dir, f"{model_type}_{num_layers}_{data_type}_rewind_{epoch_rewind}.pt"
    )

    # Train model
    utils.train_model(
        model=model,
        train_loader=train_loader,
        test_loader=val_loader,
        device=cuda_device,
        optimizer="SGD",
        l2_regularization_strength=weight_decay,
        learning_rate=lr_rate,
        num_epochs=num_epochs,
        T_max=num_epochs,
        verbose=verbose,
        epoch_rewind=epoch_rewind,
        filepath_rewind=filepath_rewind,
        patience=patience,
    )

    utils.save_model(model=model, model_dir=model_dir, model_filename=model_filename)
    # utils.load_model(model, os.path.join(model_dir, model_filename), cuda_device)

    _, eval_accuracy, _ = utils.evaluate_model(
        model=model, test_loader=test_loader, device=cuda_device, criterion=None
    )
    print(f"Number of layers: {num_layers} / Test accuracy: {eval_accuracy}")


if __name__ == "__main__":
    dataset_types = ["MNIST", "CIFAR_10", "SVHN", "FASHION_MNIST"]
    model_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "saved_models"
    )
    print(model_dir)
    print("Pretraining Models...")

    # Pre-train FCN
    model_type = "FCN"
    for dataset_type in dataset_types:
        print(f"\n==Training {model_type} on {dataset_type}==")
        num_epochs = 300 if dataset_type == "CIFAR_10" else 100
        train(data_type=dataset_type, model_type=model_type, num_epochs=num_epochs)

    # Pre-train VGG-19bn
    model_type = "VGG-19"
    lr_rate = 5e-3
    for dataset_type in dataset_types:
        print(f"\n==Training {model_type} on {dataset_type}==")
        num_epochs = 300 if dataset_type == "CIFAR_10" else 100
        train(
            data_type=dataset_type,
            model_type=model_type,
            num_epochs=num_epochs,
            lr_rate=lr_rate,
        )
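
    # Sketch (hypothetical follow-up, not part of this script): a pre-trained
    # linear layer could then be pruned with the diagonal-block utility added
    # in this PR, e.g.:
    #   from custom_utils.pruning.diag_pruning import diag_pruning_linear
    #   diag_pruning_linear(some_linear_layer, block_size=4, perm_type="RANDOM")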