86 changes: 79 additions & 7 deletions components/src/dynamo/vllm/handlers.py
@@ -166,21 +166,27 @@ def build_sampling_params(
    input_length = len(token_ids)
    dynamic_default = max(1, model_max_len - input_length)
    model_config_max_tokens = default_sampling_params.get("max_tokens")
-    sampling_params.max_tokens = min(filter(lambda x: x is not None,
-        [provided_max_tokens, dynamic_default, model_config_max_tokens]))
+    sampling_params.max_tokens = min(
+        filter(
+            lambda x: x is not None,
+            [provided_max_tokens, dynamic_default, model_config_max_tokens],
+        )
+    )
    return sampling_params


def build_sampling_params_openai(
    request: Dict[str, Any],
    default_sampling_params: Dict[str, Any],
+    model_max_len: int | None = None,
) -> SamplingParams:
    """
    Build SamplingParams from an OpenAI-compatible request format.

    Args:
        request: The OpenAI-style request dict with parameters like temperature, max_tokens, etc.
        default_sampling_params: Default sampling parameters to initialize with
+        model_max_len: Maximum model context length for computing dynamic max_tokens default

    Returns:
        SamplingParams configured from the request
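
The max_tokens resolution at the top of this hunk picks the smallest non-None value among the client-provided limit, the remaining context window, and the model-config default. A minimal standalone sketch of that selection pattern (the helper name below is illustrative, not part of handlers.py):

def smallest_non_none(*candidates):
    # min() over the non-None candidates; this would raise ValueError if every
    # candidate were None, but dynamic_default above is always an int (max(1, ...)).
    return min(c for c in candidates if c is not None)

# A 4096-token model with a 1000-token prompt leaves a dynamic default of 3096,
# so a request asking for 8000 completion tokens is clamped to 3096.
assert smallest_non_none(8000, max(1, 4096 - 1000), None) == 3096
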
@@ -208,8 +214,15 @@ def build_sampling_params_openai(
setattr(sampling_params, param_key, request[req_key])

    # Handle max_tokens
-    if "max_tokens" in request and request["max_tokens"] is not None:
-        sampling_params.max_tokens = request["max_tokens"]
+    provided_max_tokens = request.get("max_tokens")
+    model_config_max_tokens = default_sampling_params.get("max_tokens")
+
+    sampling_params.max_tokens = min(
+        filter(
+            lambda x: x is not None,
+            [provided_max_tokens, model_max_len, model_config_max_tokens],
+        )
+    )

    # Handle stop sequences
    if "stop" in request and request["stop"] is not None:
@@ -893,6 +906,59 @@ async def _extract_multimodal_data(

return vllm_mm_data if vllm_mm_data else None

    async def _extract_multimodal_from_openai_messages(
        self, request: Dict[str, Any]
    ) -> Dict[str, Any] | None:
        """Collect and load image_url entries from OpenAI-style chat messages."""
        messages = request.get("messages")
        if not messages:
            return None

        image_urls = []
        for message in messages:
            content = message.get("content")
            if not isinstance(content, list):
                continue

            for item in content:
                if not isinstance(item, dict) or item.get("type") != "image_url":
                    continue

                image_url_data = item.get("image_url")
                if isinstance(image_url_data, dict):
                    url = image_url_data.get("url")
                elif isinstance(image_url_data, str):
                    url = image_url_data
                else:
                    continue

                if url:
                    image_urls.append(url)

        if not image_urls:
            return None

        if not self.enable_multimodal:
            raise ValueError(
                "Received multimodal data but multimodal processing is not enabled. "
                "Use --enable-multimodal flag to enable multimodal processing."
            )

        images = []
        for url in image_urls:
            try:
                image = await self.image_loader.load_image(url)
                images.append(image)
                logger.debug(f"Loaded image from OpenAI message: {url[:80]}...")
            except Exception:
                logger.exception(f"Failed to load image from {url[:80]}...")
                raise

        vllm_mm_data = {"image": images[0] if len(images) == 1 else images}
        logger.debug(
            f"Extracted {len(images)} image(s) from OpenAI messages for multimodal processing"
        )
        return vllm_mm_data

    def _build_prompt_from_request(
        self,
        request: Dict[str, Any],
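
For reference, a request shaped like the one below would be accepted by _extract_multimodal_from_openai_messages; it handles both the standard object form of image_url and a bare string (model name and URLs are made up for illustration):

request = {
    "model": "example-vlm",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                # Standard OpenAI form: image_url is an object with a "url" field.
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                # Also handled: image_url given directly as a string.
                {"type": "image_url", "image_url": "https://example.com/dog.png"},
            ],
        }
    ],
}
# Both URLs are collected and loaded; a single image is returned as {"image": image},
# while several images are returned as {"image": [image1, image2, ...]}.
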
Expand Down Expand Up @@ -1300,20 +1366,26 @@ async def _generate_token_mode(self, request, context, request_id):

    async def _generate_text_mode(self, request, context, request_id):
        """Generate text using OpenAI-compatible format (text-in-text-out)."""

        # Get text input using InputParamManager
        input_data = self.input_param_manager.get_input_param(
            request, use_tokenizer=True
        )

        # Extract multimodal data
        multi_modal_data = await self._extract_multimodal_from_openai_messages(request)

        # Build prompt for vLLM
        if isinstance(input_data, list):
-            prompt = TokensPrompt(prompt_token_ids=input_data)
+            prompt = TokensPrompt(
+                prompt_token_ids=input_data, multi_modal_data=multi_modal_data
+            )
        else:
-            prompt = TextPrompt(prompt=input_data)
+            prompt = TextPrompt(prompt=input_data, multi_modal_data=multi_modal_data)

        # Build sampling params from OpenAI-style request
        sampling_params = build_sampling_params_openai(
-            request, self.default_sampling_params
+            request, self.default_sampling_params, self.model_max_len
        )

        dp_rank = request.get("dp_rank", None)
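
As a standalone sketch of the prompt wiring above: vLLM's prompt types accept an optional multi_modal_data mapping alongside the text or token IDs. The import path below is an assumption and may differ across vLLM versions; handlers.py already imports these types.

from vllm.inputs import TextPrompt, TokensPrompt

def build_prompt(input_data, multi_modal_data=None):
    # Mirror the handler: attach the extracted images, if any, to whichever
    # prompt type matches the input format (token IDs vs. raw text).
    if isinstance(input_data, list):
        return TokensPrompt(prompt_token_ids=input_data, multi_modal_data=multi_modal_data)
    return TextPrompt(prompt=input_data, multi_modal_data=multi_modal_data)

# build_prompt([101, 2023, 2003], {"image": img})     -> TokensPrompt carrying the image
# build_prompt("Describe the image.", {"image": img}) -> TextPrompt carrying the image
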