diff --git a/gemma/peft/_quantization.py b/gemma/peft/_quantization.py index aede9c25..6574ee1d 100644 --- a/gemma/peft/_quantization.py +++ b/gemma/peft/_quantization.py @@ -189,7 +189,7 @@ def __call__(self, inputs: Array, einsum_str: str | None = None) -> Array: kernel, self.method, axis_to_reduce=get_axis_to_reduce_from_einsum_str( - einsum_str=self.wrapped.name + einsum_str=einsum_str ), )