diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py
index b9a7ba6baa5..b6daf581294 100644
--- a/components/src/dynamo/vllm/handlers.py
+++ b/components/src/dynamo/vllm/handlers.py
@@ -166,14 +166,19 @@ def build_sampling_params(
     input_length = len(token_ids)
     dynamic_default = max(1, model_max_len - input_length)
     model_config_max_tokens = default_sampling_params.get("max_tokens")
-    sampling_params.max_tokens = min(filter(lambda x: x is not None,
-        [provided_max_tokens, dynamic_default, model_config_max_tokens]))
+    sampling_params.max_tokens = min(
+        filter(
+            lambda x: x is not None,
+            [provided_max_tokens, dynamic_default, model_config_max_tokens],
+        )
+    )
     return sampling_params
 
 
 def build_sampling_params_openai(
     request: Dict[str, Any],
     default_sampling_params: Dict[str, Any],
+    model_max_len: int | None = None,
 ) -> SamplingParams:
     """
     Build SamplingParams from an OpenAI-compatible request format.
@@ -181,6 +186,7 @@ def build_sampling_params_openai(
     Args:
         request: The OpenAI-style request dict with parameters like temperature, max_tokens, etc.
         default_sampling_params: Default sampling parameters to initialize with
+        model_max_len: Maximum model context length for computing dynamic max_tokens default
 
     Returns:
         SamplingParams configured from the request
@@ -208,8 +214,15 @@ def build_sampling_params_openai(
             setattr(sampling_params, param_key, request[req_key])
 
     # Handle max_tokens
-    if "max_tokens" in request and request["max_tokens"] is not None:
-        sampling_params.max_tokens = request["max_tokens"]
+    provided_max_tokens = request.get("max_tokens")
+    model_config_max_tokens = default_sampling_params.get("max_tokens")
+
+    sampling_params.max_tokens = min(
+        filter(
+            lambda x: x is not None,
+            [provided_max_tokens, model_max_len, model_config_max_tokens],
+        )
+    )
 
     # Handle stop sequences
     if "stop" in request and request["stop"] is not None:
@@ -893,6 +906,59 @@ async def _extract_multimodal_data(
         return vllm_mm_data if vllm_mm_data else None
 
+    async def _extract_multimodal_from_openai_messages(
+        self, request: Dict[str, Any]
+    ) -> Dict[str, Any] | None:
+        messages = request.get("messages")
+        if not messages:
+            return None
+
+        image_urls = []
+        for message in messages:
+            content = message.get("content")
+            if not isinstance(content, list):
+                continue
+
+            for item in content:
+                if not isinstance(item, dict) or item.get("type") != "image_url":
+                    continue
+
+                image_url_data = item.get("image_url")
+                if isinstance(image_url_data, dict):
+                    url = image_url_data.get("url")
+                elif isinstance(image_url_data, str):
+                    url = image_url_data
+                else:
+                    continue
+
+                if url:
+                    image_urls.append(url)
+
+        if not image_urls:
+            return None
+
+        if not self.enable_multimodal:
+            raise ValueError(
+                "Received multimodal data but multimodal processing is not enabled. "
+                "Use --enable-multimodal flag to enable multimodal processing."
+            )
+
+        images = []
+        for url in image_urls:
+            try:
+                image = await self.image_loader.load_image(url)
+                images.append(image)
+                logger.debug(f"Loaded image from OpenAI message: {url[:80]}...")
+            except Exception:
+                logger.exception(f"Failed to load image from {url[:80]}...")
+                raise
+
+        vllm_mm_data = {"image": images[0] if len(images) == 1 else images}
+        logger.debug(
+            f"Extracted {len(images)} image(s) from OpenAI messages for multimodal processing"
+        )
+        return vllm_mm_data
+
     def _build_prompt_from_request(
         self,
         request: Dict[str, Any],
@@ -1300,20 +1366,26 @@ async def _generate_token_mode(self, request, context, request_id):
 
     async def _generate_text_mode(self, request, context, request_id):
         """Generate text using OpenAI-compatible format (text-in-text-out)."""
+        # Get text input using InputParamManager
         input_data = self.input_param_manager.get_input_param(
            request, use_tokenizer=True
         )
 
+        # Extract multimodal data
+        multi_modal_data = await self._extract_multimodal_from_openai_messages(request)
+
         # Build prompt for vLLM
         if isinstance(input_data, list):
-            prompt = TokensPrompt(prompt_token_ids=input_data)
+            prompt = TokensPrompt(
+                prompt_token_ids=input_data, multi_modal_data=multi_modal_data
+            )
         else:
-            prompt = TextPrompt(prompt=input_data)
+            prompt = TextPrompt(prompt=input_data, multi_modal_data=multi_modal_data)
 
         # Build sampling params from OpenAI-style request
         sampling_params = build_sampling_params_openai(
-            request, self.default_sampling_params
+            request, self.default_sampling_params, self.model_max_len
         )
 
         dp_rank = request.get("dp_rank", None)
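For reference, the max_tokens handling in both hunks resolves to the smallest non-None candidate among the request value, the model context length (or its dynamic default), and the config default. A minimal standalone sketch of that selection rule, with illustrative names that are not the handler's actual variables:

    def resolve_max_tokens(provided, model_max_len, config_default):
        # Smallest non-None candidate wins, mirroring min(filter(...)) in the patch.
        # Assumes at least one candidate is not None; otherwise min() raises
        # ValueError, the same behavior as the patched code.
        return min(c for c in (provided, model_max_len, config_default) if c is not None)

    # A request asking for 4096 tokens against a 2048-token context is clamped to 2048.
    assert resolve_max_tokens(4096, 2048, None) == 2048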
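The new _extract_multimodal_from_openai_messages walk accepts both the nested {"type": "image_url", "image_url": {"url": ...}} content part and a bare string under "image_url". A rough illustration of the request shape it handles, using a hypothetical helper rather than the handler code itself:

    from typing import Any, Dict, List

    def collect_image_urls(request: Dict[str, Any]) -> List[str]:
        # Walk messages -> content parts, keeping only image_url entries,
        # following the same traversal as the method added in the diff.
        urls: List[str] = []
        for message in request.get("messages", []):
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for item in content:
                if not isinstance(item, dict) or item.get("type") != "image_url":
                    continue
                data = item.get("image_url")
                if isinstance(data, dict):
                    url = data.get("url")
                elif isinstance(data, str):
                    url = data
                else:
                    continue
                if url:
                    urls.append(url)
        return urls

    request = {
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": "Describe this image"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ]}
        ]
    }
    assert collect_image_urls(request) == ["https://example.com/cat.png"]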