Skip to content
33 changes: 33 additions & 0 deletions fake_quant/hadamard_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def get_hadK(n, transpose=False):
assert (is_pow2(n // 20))
K = 20
hadK = get_had20().T if transpose else get_had20()
elif n % 24 == 0: # llama-3.2-3B
assert (is_pow2(n // 24))
K = 24
hadK = get_had24().T if transpose else get_had24()
elif n % 12 == 0:
assert (is_pow2(n // 12))
K = 12
Expand Down Expand Up @@ -165,6 +169,35 @@ def get_had12():
[+1, -1, +1, -1, -1, -1, +1, +1, +1, -1, +1, +1],
])

# hadamard matrices for had24.pal
# print("\n".join(["[" + ", ".join([f"{c}1" for c in l]) + "]," for l in a.split("\n")[:-1]]))
def get_had24():
return torch.FloatTensor([
[+1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[+1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1],
[+1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1],
[+1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1],
[+1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1],
[+1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1],
[+1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1],
[+1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1],
[+1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1],
[+1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1],
[+1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1],
[+1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1],
[+1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1],
[+1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1],
[+1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1],
[+1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1],
[+1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1],
[+1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1],
[+1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1],
[+1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1],
[+1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1],
[+1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1],
[+1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1],
[+1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1],
])

def get_had40():
return torch.FloatTensor([
Expand Down
8 changes: 6 additions & 2 deletions fake_quant/rotation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,11 @@ def rotate_model(model, args):

model_type = model_utils.model_type_extractor(model)
rotate_embeddings(model, Q)
rotate_head(model, Q)

# if the input_embeddings (embeddings) and output_embeddings (lm_head) are tied, avoid rotating twice since they reference the same data.
if not model.config.tie_word_embeddings:
rotate_head(model, Q)

utils.cleanup_memory()
layers = model_utils.get_transformer_layers(model,
model_type=model_type)
Expand Down Expand Up @@ -294,7 +298,7 @@ def forward(self, *args, **kwargs):


if self.k_groupsize == -1: #token-wise quantization
token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size)
token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size * self.config.num_key_value_heads / self.config.num_attention_heads)
self.k_quantizer.find_params(token_wise_k)
k = self.k_quantizer(token_wise_k).reshape((bsz, seq_len, num_heads, head_dim)).transpose(1, 2).to(q)
else: #head-wise quantization
Expand Down
8 changes: 6 additions & 2 deletions fake_quant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from datetime import datetime
import logging


from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

Expand All @@ -17,7 +16,12 @@
'meta-llama/Llama-2-70b-hf',
'meta-llama/Meta-Llama-3-8B',
'meta-llama/Meta-Llama-3-70B',
'facebook/opt-125m'
'meta-llama/Llama-3.1-8B',
'meta-llama/Llama-3.1-70B',
'meta-llama/Llama-3.1-405B',
'meta-llama/Llama-3.2-1B',
'meta-llama/Llama-3.2-3B',
'facebook/opt-125m',
]
supported_datasets = ['wikitext2', 'ptb', 'c4']

Expand Down