From 7a0ac034a23126219f6b6261e618e8efde4fc7c2 Mon Sep 17 00:00:00 2001
From: CryVeck
Date: Fri, 22 Nov 2024 14:04:42 +0900
Subject: [PATCH 1/5] Fixing token-wise rotation size when it differs from the value size

---
 fake_quant/rotation_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fake_quant/rotation_utils.py b/fake_quant/rotation_utils.py
index 97a73ef..07add7d 100644
--- a/fake_quant/rotation_utils.py
+++ b/fake_quant/rotation_utils.py
@@ -294,7 +294,7 @@ def forward(self, *args, **kwargs):
 
         if self.k_groupsize == -1: #token-wise quantization
-            token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size)
+            token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size * self.config.num_key_value_heads // self.config.num_attention_heads)
             self.k_quantizer.find_params(token_wise_k)
             k = self.k_quantizer(token_wise_k).reshape((bsz, seq_len, num_heads, head_dim)).transpose(1, 2).to(q)
         else: #head-wise quantization

From a11699abe78434ebffd4eccca03511b407327dd2 Mon Sep 17 00:00:00 2001
From: CryVeck
Date: Sun, 24 Nov 2024 17:51:07 +0900
Subject: [PATCH 2/5] Fixing rotation issue due to tie_word_embeddings

---
 fake_quant/rotation_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/fake_quant/rotation_utils.py b/fake_quant/rotation_utils.py
index 07add7d..e0eb1a2 100644
--- a/fake_quant/rotation_utils.py
+++ b/fake_quant/rotation_utils.py
@@ -238,7 +238,11 @@ def rotate_model(model, args):
     model_type = model_utils.model_type_extractor(model)
     rotate_embeddings(model, Q)
-    rotate_head(model, Q)
+
+    # If the input embeddings (embeddings) and the output embeddings (lm_head) are tied, avoid rotating twice since they reference the same data.
+    if not model.config.tie_word_embeddings:
+        rotate_head(model, Q)
+
     utils.cleanup_memory()
     layers = model_utils.get_transformer_layers(model, model_type=model_type)

From e1ed87482681c1ff7393363ce335d7e2e56e4464 Mon Sep 17 00:00:00 2001
From: CryVeck
Date: Sun, 24 Nov 2024 18:16:38 +0900
Subject: [PATCH 3/5] Support for Hadamard matrix for Llama 3.2 3B

---
 fake_quant/hadamard_utils.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/fake_quant/hadamard_utils.py b/fake_quant/hadamard_utils.py
index 20b35e0..9e3cc93 100644
--- a/fake_quant/hadamard_utils.py
+++ b/fake_quant/hadamard_utils.py
@@ -44,6 +44,10 @@ def get_hadK(n, transpose=False):
         assert (is_pow2(n // 20))
         K = 20
         hadK = get_had20().T if transpose else get_had20()
+    elif n % 24 == 0: # llama-3.2-3B
+        assert (is_pow2(n // 24))
+        K = 24
+        hadK = get_had24().T if transpose else get_had24()
     elif n % 12 == 0:
         assert (is_pow2(n // 12))
         K = 12
@@ -165,6 +169,35 @@ def get_had12():
         [+1, -1, +1, -1, -1, -1, +1, +1, +1, -1, +1, +1],
     ])
 
+# hadamard matrix for had24.pal
+# generated with: print("\n".join(["[" + ", ".join([f"{c}1" for c in l]) + "]," for l in a.split("\n")[:-1]]))
+def get_had24():
+    return torch.FloatTensor([
+        [+1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+        [+1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1],
+        [+1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1],
+        [+1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1],
+        [+1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1],
+        [+1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1],
+        [+1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1],
+        [+1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1],
+        [+1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1],
+        [+1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1],
+        [+1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1],
+        [+1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1],
+        [+1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1],
+        [+1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1],
+        [+1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1],
+        [+1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1],
+        [+1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1],
+        [+1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1],
+        [+1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1],
+        [+1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1],
+        [+1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1],
+        [+1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1],
+        [+1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1],
+        [+1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1],
+    ])
 
 def get_had40():
     return torch.FloatTensor([

From 3f6f88436f95ec01d5253f73f925f47fe91ed2bf Mon Sep 17 00:00:00 2001
From: CryVeck
Date: Sun, 24 Nov 2024 18:17:13 +0900
Subject: [PATCH 4/5] Adding the new models to the list of supported models

---
 fake_quant/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/fake_quant/utils.py b/fake_quant/utils.py
index a31a067..d0c0f7c 100644
--- a/fake_quant/utils.py
+++ b/fake_quant/utils.py
@@ -16,7 +16,8 @@
     'meta-llama/Llama-2-13b-hf',
     'meta-llama/Llama-2-70b-hf',
     'meta-llama/Meta-Llama-3-8B',
-    'meta-llama/Meta-Llama-3-70B',
+    'meta-llama/Llama-3.2-1B',
+    'meta-llama/Llama-3.2-3B',
     'facebook/opt-125m'
 ]
 supported_datasets = ['wikitext2', 'ptb', 'c4']

From f2707f557dd427ed8848c5715cfeeaa48997e461 Mon Sep 17 00:00:00 2001
From: CryVeck
Date: Wed, 18 Dec 2024 14:07:51 +0900
Subject: [PATCH 5/5] Support for Llama 3.1

---
 fake_quant/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/fake_quant/utils.py b/fake_quant/utils.py
index d0c0f7c..31b8276 100644
--- a/fake_quant/utils.py
+++ b/fake_quant/utils.py
@@ -7,7 +7,6 @@
 from datetime import datetime
 import logging
 
-
 from accelerate import dispatch_model, infer_auto_device_map
 from accelerate.utils import get_balanced_memory
 
@@ -16,9 +15,13 @@
     'meta-llama/Llama-2-13b-hf',
     'meta-llama/Llama-2-70b-hf',
     'meta-llama/Meta-Llama-3-8B',
+    'meta-llama/Meta-Llama-3-70B',
+    'meta-llama/Llama-3.1-8B',
+    'meta-llama/Llama-3.1-70B',
+    'meta-llama/Llama-3.1-405B',
     'meta-llama/Llama-3.2-1B',
     'meta-llama/Llama-3.2-3B',
-    'facebook/opt-125m'
+    'facebook/opt-125m',
 ]
 supported_datasets = ['wikitext2', 'ptb', 'c4']
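
Note (illustrative, not part of the patch series): the two substantive code changes above can be sanity-checked with a small standalone script. The sketch below assumes the patched fake_quant package and its dependencies are importable from the repository root; the config values (hidden_size=3072, 24 attention heads, 8 KV heads) are assumed Llama-3.2-3B-like numbers chosen for illustration, not read from a checkpoint.

    # Sanity-check sketch (assumptions: patched repo on PYTHONPATH, example config values).
    import torch

    from fake_quant.hadamard_utils import get_had24  # function added in PATCH 3/5

    H = get_had24()
    n = H.shape[0]
    # A valid Hadamard matrix satisfies H @ H.T == n * I.
    assert torch.equal(H @ H.T, n * torch.eye(n)), "had24 is not a valid Hadamard matrix"

    # Token-wise K width under grouped-query attention (PATCH 1/5): hidden_size scaled
    # by the KV-head / attention-head ratio, using integer division so reshape gets an int.
    hidden_size, num_attention_heads, num_key_value_heads = 3072, 24, 8  # assumed values
    kv_width = hidden_size * num_key_value_heads // num_attention_heads
    assert kv_width == (hidden_size // num_attention_heads) * num_key_value_heads  # 1024
    print(f"had24 OK ({n}x{n}); token-wise K width = {kv_width}")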