diff --git a/fake_quant/hadamard_utils.py b/fake_quant/hadamard_utils.py index 20b35e0..9e3cc93 100644 --- a/fake_quant/hadamard_utils.py +++ b/fake_quant/hadamard_utils.py @@ -44,6 +44,10 @@ def get_hadK(n, transpose=False): assert (is_pow2(n // 20)) K = 20 hadK = get_had20().T if transpose else get_had20() + elif n % 24 == 0: # llama-3.2-3B + assert (is_pow2(n // 24)) + K = 24 + hadK = get_had24().T if transpose else get_had24() elif n % 12 == 0: assert (is_pow2(n // 12)) K = 12 @@ -165,6 +169,35 @@ def get_had12(): [+1, -1, +1, -1, -1, -1, +1, +1, +1, -1, +1, +1], ]) +# hadamard matrices for had24.pal +# print("\n".join(["[" + ", ".join([f"{c}1" for c in l]) + "]," for l in a.split("\n")[:-1]])) +def get_had24(): + return torch.FloatTensor([ + [+1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [+1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1], + [+1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1], + [+1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1], + [+1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1], + [+1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1], + [+1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1], + [+1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1], + [+1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1], + [+1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1], + [+1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1], + [+1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1], + [+1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1], + [+1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1], + [+1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1], + [+1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1], + [+1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1], + [+1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1], + [+1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1], + [+1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1], + [+1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1], + [+1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1], + [+1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1], + [+1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1], + ]) def get_had40(): return torch.FloatTensor([ diff --git a/fake_quant/rotation_utils.py b/fake_quant/rotation_utils.py index 97a73ef..e0eb1a2 100644 --- a/fake_quant/rotation_utils.py +++ b/fake_quant/rotation_utils.py @@ -238,7 +238,11 @@ def rotate_model(model, args): model_type = model_utils.model_type_extractor(model) rotate_embeddings(model, Q) - rotate_head(model, Q) + + # if the input_embeddings (embeddings) and output_embeddings (lm_head) are tied, avoid rotating twice since they reference the same data. + if not model.config.tie_word_embeddings: + rotate_head(model, Q) + utils.cleanup_memory() layers = model_utils.get_transformer_layers(model, model_type=model_type) @@ -294,7 +298,7 @@ def forward(self, *args, **kwargs): if self.k_groupsize == -1: #token-wise quantization - token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size) + token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size * self.config.num_key_value_heads / self.config.num_attention_heads) self.k_quantizer.find_params(token_wise_k) k = self.k_quantizer(token_wise_k).reshape((bsz, seq_len, num_heads, head_dim)).transpose(1, 2).to(q) else: #head-wise quantization diff --git a/fake_quant/utils.py b/fake_quant/utils.py index a31a067..31b8276 100644 --- a/fake_quant/utils.py +++ b/fake_quant/utils.py @@ -7,7 +7,6 @@ from datetime import datetime import logging - from accelerate import dispatch_model, infer_auto_device_map from accelerate.utils import get_balanced_memory @@ -17,7 +16,12 @@ 'meta-llama/Llama-2-70b-hf', 'meta-llama/Meta-Llama-3-8B', 'meta-llama/Meta-Llama-3-70B', - 'facebook/opt-125m' + 'meta-llama/Llama-3.1-8B', + 'meta-llama/Llama-3.1-70B', + 'meta-llama/Llama-3.1-405B', + 'meta-llama/Llama-3.2-1B', + 'meta-llama/Llama-3.2-3B', + 'facebook/opt-125m', ] supported_datasets = ['wikitext2', 'ptb', 'c4']