vllm/model_executor/models/gemma3n_mm.py (18 additions & 5 deletions)
@@ -503,11 +503,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.multimodal_config = multimodal_config
         self.vocab_size = config.text_config.vocab_size

-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+        image_limit = multimodal_config.get_limit_per_prompt("image")
+        if image_limit:
+            self.vision_tower = AutoModel.from_config(config=config.vision_config)
+            self.embed_vision = Gemma3nMultimodalEmbedder(
+                config.vision_config, config.text_config
+            )
+        else:
+            self.vision_tower = None
+            self.embed_vision = None
+
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
-        self.embed_vision = Gemma3nMultimodalEmbedder(
-            config.vision_config, config.text_config
-        )
         self.embed_audio = Gemma3nMultimodalEmbedder(
             config.audio_config, config.text_config
         )
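In effect, a deployment that caps the per-prompt image limit at zero now skips allocating the vision tower and its embedder entirely. A minimal usage sketch (assuming vLLM's public LLM entry point; the checkpoint name is illustrative and not part of this diff):

    # Sketch: disable image inputs so the Gemma3n vision tower and its
    # embedder are never instantiated (text/audio-only serving).
    from vllm import LLM

    llm = LLM(
        model="google/gemma-3n-E2B-it",    # illustrative checkpoint
        limit_mm_per_prompt={"image": 0},  # get_limit_per_prompt("image") -> 0
    )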
@@ -584,6 +590,7 @@ def _process_image_input(
         image_input: Gemma3nImageInputs,
     ) -> list[torch.Tensor]:
         assert self.vision_tower is not None
+        assert self.embed_vision is not None

         pixel_values = image_input["pixel_values"]
         vision_outputs = self.vision_tower(
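An aside, not part of the PR: since both attributes are now Optional, the paired asserts fail fast if image processing is reached with a disabled vision stack, and they also narrow the types for static checkers. Without them, the failure would be less obvious:

    # Sketch of the failure mode the asserts guard against.
    vision_tower = None  # stands in for self.vision_tower on the disabled path
    vision_tower(pixel_values=None)  # TypeError: 'NoneType' object is not callable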
@@ -735,7 +742,13 @@ def compute_logits(
         return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.vision_tower is None:
+            skip_prefixes.append("vision_tower.")
+        if self.embed_vision is None:
+            skip_prefixes.append("embed_vision.")
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

     def get_mm_mapping(self) -> MultiModelKeys:
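For intuition, a simplified sketch of what skip_prefixes buys here (not AutoWeightsLoader's actual implementation): checkpoint tensors belonging to submodules that were never constructed are dropped up front, so loading a full multimodal checkpoint into a text/audio-only instance does not fail on unmatched vision weights.

    # Simplified illustration of prefix-based weight filtering.
    def filter_weights(weights, skip_prefixes):
        for name, tensor in weights:
            if any(name.startswith(p) for p in skip_prefixes):
                continue  # e.g. skip "vision_tower.*" when the tower is None
            yield name, tensor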