diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 0e69fcfd8feb..6353aad68700 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -503,11 +503,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.multimodal_config = multimodal_config
         self.vocab_size = config.text_config.vocab_size
 
-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+        image_limit = multimodal_config.get_limit_per_prompt("image")
+        if image_limit:
+            self.vision_tower = AutoModel.from_config(config=config.vision_config)
+            self.embed_vision = Gemma3nMultimodalEmbedder(
+                config.vision_config, config.text_config
+            )
+        else:
+            self.vision_tower = None
+            self.embed_vision = None
+
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
-        self.embed_vision = Gemma3nMultimodalEmbedder(
-            config.vision_config, config.text_config
-        )
         self.embed_audio = Gemma3nMultimodalEmbedder(
             config.audio_config, config.text_config
         )
@@ -584,6 +590,7 @@ def _process_image_input(
         image_input: Gemma3nImageInputs,
     ) -> list[torch.Tensor]:
         assert self.vision_tower is not None
+        assert self.embed_vision is not None
         pixel_values = image_input["pixel_values"]
 
         vision_outputs = self.vision_tower(
@@ -735,7 +742,13 @@ def compute_logits(
         return self.language_model.compute_logits(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.vision_tower is None:
+            skip_prefixes.append("vision_tower.")
+        if self.embed_vision is None:
+            skip_prefixes.append("embed_vision.")
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
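
For context, a minimal usage sketch of how the new branch would be exercised; it is not part of the diff, and the model name and sampling settings are illustrative assumptions. Setting vLLM's `limit_mm_per_prompt` image count to 0 makes `multimodal_config.get_limit_per_prompt("image")` return 0, so `vision_tower`/`embed_vision` stay `None` and their checkpoint weights are skipped via `AutoWeightsLoader(skip_prefixes=...)`.

```python
from vllm import LLM, SamplingParams

# Disallow image inputs entirely; with this change, the Gemma3n vision tower
# and vision embedder are never constructed and their weights are not loaded.
llm = LLM(
    model="google/gemma-3n-E2B-it",  # assumed checkpoint name for illustration
    limit_mm_per_prompt={"image": 0},
)

# Text-only (or audio-only) prompts still work without the vision tower.
outputs = llm.generate(
    "Summarize the benefits of skipping unused submodules.",
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)
```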