Fixes #516

Problem

When using IntWrapper with multimodal models (e.g., Gemma3_4B), quantization fails with:

Could not find parameter named "scale" in scope "/jit()/jit(vision_encoder)/..."

Root cause: IntWrapper applied quantization to ALL modules, including the vision encoder, while peft.quantize() only quantizes text-model parameters. This mismatch left the vision encoder modules expecting quantized parameters that were never created.

Solution

Added a scope-based exclusion mechanism to IntWrapper:

  • New exclude_scopes parameter (default: ('vision_encoder',))
  • Vision encoder stays in full precision (maintains accuracy)
  • Text model is quantized to INT4 (reduces memory)

Changes

gemma/gm/nn/_quantization.py

  1. IntWrapper class:

    • Added exclude_scopes: tuple[str, ...] = ('vision_encoder',) parameter
    • Updated __call__ to pass exclusions to _replace_by_int
  2. _replace_by_int function:

    • Added exclude_scopes parameter
    • Added scope checking logic before quantization
    • Returns the original module when the scope path contains an excluded pattern (see the sketch below)
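
A minimal sketch of the exclusion logic, assuming Flax's module.scope.path attribute; IntWrapper simply forwards its exclude_scopes field into this call, and _quantize_module is a hypothetical stand-in for the existing replacement logic, not the exact implementation:

import flax.linen as nn
import jax.numpy as jnp

def _replace_by_int(
    module: nn.Module,
    *,
    dtype=jnp.int4,
    exclude_scopes: tuple[str, ...] = ('vision_encoder',),
) -> nn.Module:
  # A bound Flax module exposes its position in the module tree as a tuple
  # of names, e.g. ('vision_encoder', 'siglip', ...).
  path = module.scope.path if module.scope is not None else ()
  # Skip quantization when any path component matches an excluded pattern,
  # so excluded scopes (the vision encoder by default) stay in full precision.
  if any(pattern in part for part in path for pattern in exclude_scopes):
    return module
  return _quantize_module(module, dtype=dtype)  # hypothetical quantization helper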

gemma/gm/nn/_quantization_multimodal_test.py

Added comprehensive unit tests (a representative pair is sketched after this list):

  • Default exclusion behavior
  • Custom exclusion scopes
  • Scope checking logic
  • Edge cases (no scope, multiple exclusions)
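
A hedged sketch of the default- and custom-exclusion tests; the test class, method names, and assertions are illustrative stand-ins (only the exclude_scopes parameter itself comes from this PR):

import jax.numpy as jnp
from absl.testing import absltest
from gemma import gm

class IntWrapperExclusionTest(absltest.TestCase):

  def test_default_excludes_vision_encoder(self):
    # The vision encoder should be excluded without any extra configuration.
    wrapper = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(), dtype=jnp.int4)
    self.assertIn('vision_encoder', wrapper.exclude_scopes)

  def test_custom_exclude_scopes(self):
    # Callers can extend the exclusion list with their own scope patterns.
    wrapper = gm.nn.IntWrapper(
        model=gm.nn.Gemma3_4B(),
        dtype=jnp.int4,
        exclude_scopes=('vision_encoder', 'custom_module'),
    )
    self.assertEqual(wrapper.exclude_scopes, ('vision_encoder', 'custom_module'))

if __name__ == '__main__':
  absltest.main()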

Usage

from gemma import gm
from gemma import peft
import jax.numpy as jnp

# Default behavior - vision encoder excluded automatically
model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(), dtype=jnp.int4)
# Load checkpoint params before quantizing
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)
params = peft.quantize(params, method='INT4', checkpoint_kernel_key='w')

# Custom exclusions
model = gm.nn.IntWrapper(
    model=gm.nn.Gemma3_4B(),
    dtype=jnp.int4,
    exclude_scopes=('vision_encoder', 'custom_module'),
)
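
With the quantized params from the first snippet, multimodal sampling should no longer raise the missing-"scale" error. A hedged sketch reusing model and params from above; the placeholder image and the images keyword follow the repo's multimodal sampling examples and are assumptions here:

import jax.numpy as jnp
from gemma import gm

sampler = gm.text.ChatSampler(model=model, params=params)
image = jnp.zeros((896, 896, 3), dtype=jnp.uint8)  # placeholder; load a real image in practice
reply = sampler.chat('Describe the image.', images=image)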

This change is fully backward compatible; existing code works without modification.
