From 3d2320e13094a4af56d78d73ecb4664d6a68420e Mon Sep 17 00:00:00 2001
From: Sriram Sowmithri
Date: Thu, 8 Jan 2026 21:42:51 +0530
Subject: [PATCH] Fix SimulateQuantizedEinsum to use einsum_str instead of
 module name for pattern-specific quantization

- Changed line 192 in gemma/peft/_quantization.py from
  einsum_str=self.wrapped.name to einsum_str=einsum_str
- This enables the pattern-specific quantization axis selection logic in
  get_axis_to_reduce_from_einsum_str()
- Previously, module names like 'attention_proj' were passed, always
  returning None and forcing a fallback to generic per-channel scaling
- Now actual einsum equations like 'BTD,NDH->BTNH' are passed, enabling
  optimal pattern-specific scaling
- Improves quantization accuracy for all einsum operations in
  quantization-aware training workflows
---
 gemma/peft/_quantization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gemma/peft/_quantization.py b/gemma/peft/_quantization.py
index aede9c25..6574ee1d 100644
--- a/gemma/peft/_quantization.py
+++ b/gemma/peft/_quantization.py
@@ -189,7 +189,7 @@ def __call__(self, inputs: Array, einsum_str: str | None = None) -> Array:
         kernel,
         self.method,
         axis_to_reduce=get_axis_to_reduce_from_einsum_str(
-            einsum_str=self.wrapped.name
+            einsum_str=einsum_str
        ),
    )
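
Note (reviewer illustration, not part of the patch): a minimal sketch of the
kind of pattern-specific axis selection the commit message describes. The
helper below is hypothetical and only handles two-operand einsums where the
kernel is the right-hand operand; the real get_axis_to_reduce_from_einsum_str()
in gemma/peft/_quantization.py may differ. It shows why passing a module name
such as 'attention_proj' can never yield an axis, while an equation like
'BTD,NDH->BTNH' identifies the contracted kernel axis ('D', axis 1).

    # Illustrative sketch only -- NOT the gemma implementation.
    def get_axis_to_reduce_from_einsum_str_sketch(
        einsum_str: str | None,
    ) -> int | None:
      """Returns the kernel axis that is contracted away, or None if unknown."""
      if einsum_str is None or '->' not in einsum_str or ',' not in einsum_str:
        return None  # A module name like 'attention_proj' lands here.
      operands, output = einsum_str.split('->')
      lhs, kernel = operands.split(',')
      # A label appearing in both inputs but not in the output is contracted.
      contracted = [
          axis for axis, label in enumerate(kernel)
          if label in lhs and label not in output
      ]
      return contracted[0] if len(contracted) == 1 else None

    assert get_axis_to_reduce_from_einsum_str_sketch('BTD,NDH->BTNH') == 1
    assert get_axis_to_reduce_from_einsum_str_sketch('attention_proj') is None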