diff --git a/lib/data.py b/lib/data.py
index b6842c40..b15ca375 100644
--- a/lib/data.py
+++ b/lib/data.py
@@ -40,8 +40,8 @@ def get_wikitext2(nsamples, seed, seqlen, tokenizer):
 # Load and process c4 dataset
 def get_c4(nsamples, seed, seqlen, tokenizer):
     # Load train and validation datasets
-    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
-    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
+    traindata = load_dataset('allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
+    valdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
 
     # Generate samples from training set
     random.seed(seed)
@@ -70,4 +70,4 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
     if 'wikitext2' in name:
         return get_wikitext2(nsamples, seed, seqlen, tokenizer)
     if "c4" in name:
-        return get_c4(nsamples, seed, seqlen, tokenizer)
\ No newline at end of file
+        return get_c4(nsamples, seed, seqlen, tokenizer)
\ No newline at end of file
diff --git a/lib/eval.py b/lib/eval.py
index 5214d50f..2eed305f 100644
--- a/lib/eval.py
+++ b/lib/eval.py
@@ -11,7 +11,7 @@
 
 
 # Function to evaluate perplexity (ppl) on a specified model and tokenizer
-def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
+def eval_ppl(model, tokenizer, device=torch.device("cuda:0")):
     # Set dataset
     dataset = "wikitext2"
 
diff --git a/lib/prune.py b/lib/prune.py
index 01d981c4..0bfa9339 100644
--- a/lib/prune.py
+++ b/lib/prune.py
@@ -2,6 +2,7 @@
 import heapq
 import torch
 import torch.nn as nn
+import tqdm
 from .sparsegpt import SparseGPT
 from .layerwrapper import WrappedGPT
 from .data import get_loaders
@@ -128,14 +129,14 @@ def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0
     use_cache = model.config.use_cache
     model.config.use_cache = False
 
-    print("loading calibdation data")
+    print("loading calibration data")
     dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
     print("dataset loading complete")
     with torch.no_grad():
         inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)
 
     layers = model.model.layers
-    for i in range(len(layers)):
+    for i in tqdm.tqdm(range(len(layers))):
         layer = layers[i]
         subset = find_layers(layer)
 
@@ -162,7 +163,7 @@ def tmp(_, inp, out):
             h.remove()
 
         for name in subset:
-            print(f"pruning layer {i} name {name}")
+            #print(f"pruning layer {i} name {name}")
             W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
 
             W_mask = (torch.zeros_like(W_metric) == 1)  ## initialize a mask to be all False
diff --git a/lib/prune_opt.py b/lib/prune_opt.py
index 5910bcf3..804cec7b 100644
--- a/lib/prune_opt.py
+++ b/lib/prune_opt.py
@@ -2,6 +2,7 @@
 import heapq
 import torch
 import torch.nn as nn
+import tqdm
 from .sparsegpt import SparseGPT
 from .layerwrapper import WrappedGPT
 from .data import get_loaders
@@ -125,14 +126,14 @@ def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0
     use_cache = model.config.use_cache
     model.config.use_cache = False
 
-    print("loading calibdation data")
+    print("loading calibration data")
     dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
     print("dataset loading complete")
     with torch.no_grad():
         inps, outs, attention_mask = prepare_calibration_input(model, dataloader, device)
 
     layers = model.model.decoder.layers
-    for i in range(len(layers)):
+    for i in tqdm.tqdm(range(len(layers))):
         layer = layers[i]
         subset = find_layers(layer)
 
@@ -159,7 +160,7 @@ def tmp(_, inp, out):
             h.remove()
 
         for name in subset:
-            print(f"pruning layer {i} name {name}")
+            #print(f"pruning layer {i} name {name}")
             W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
 
             W_mask = (torch.zeros_like(W_metric) == 1)  ## initialize a mask to be all False