components/src/dynamo/vllm/handlers.py (15 additions, 0 deletions)
@@ -1363,6 +1363,11 @@ async def _generate_token_mode(self, request, context, request_id):
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
finally:
if multi_modal_data is not None:
images = multi_modal_data.get("image")
count = len(images) if isinstance(images, list) else 1
self.image_loader.mark_consumed(count)

async def _generate_text_mode(self, request, context, request_id):
"""Generate text using OpenAI-compatible format (text-in-text-out)."""
@@ -1455,6 +1460,11 @@ async def _generate_text_mode(self, request, context, request_id):
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
finally:
if multi_modal_data is not None:
images = multi_modal_data.get("image")
count = len(images) if isinstance(images, list) else 1
self.image_loader.mark_consumed(count)


class PrefillWorkerHandler(BaseWorkerHandler):
@@ -1608,3 +1618,8 @@ async def _generate_token_mode(self, request, context, request_id):
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
finally:
if multi_modal_data is not None:
images = multi_modal_data.get("image")
count = len(images) if isinstance(images, list) else 1
self.image_loader.mark_consumed(count)
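The contract behind these finally blocks: every load_image() call acquires one slot on the loader's pending-image semaphore, and the handler must hand exactly that many back once the prefill batch has consumed the images, even if generation fails. A toy, self-contained model of the pairing, with a plain asyncio.Semaphore standing in for ImageLoader._pending_semaphore (the limit of 2 is illustrative):

import asyncio

async def demo():
    sem = asyncio.Semaphore(2)  # max_pending = 2
    await sem.acquire()         # load_image() for image #1
    await sem.acquire()         # load_image() for image #2
    # A third acquire() would block here until the handler's finally
    # block runs mark_consumed(), i.e. release().
    sem.release()               # mark_consumed(1)
    await sem.acquire()         # load_image() for image #3 now proceeds

asyncio.run(demo())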
(next changed file; filename not shown)
@@ -58,7 +58,7 @@ def __init__(
self.engine_args = engine_args
self.model = self.engine_args.model

self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
self.image_loader = ImageLoader()
self.image_processor = AutoImageProcessor.from_pretrained(
self.model, trust_remote_code=True
)
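The worker now takes the loader's defaults instead of pinning the cache size. Per the updated signature later in this diff, behavior can still be tuned at construction; a hypothetical sketch (values illustrative):

loader = ImageLoader()                          # nvimgcodec decoding; max_pending from DYN_IMAGE_MAX_PENDING or 64
pil_loader = ImageLoader(use_nvimgcodec=False)  # force the PIL path, which keeps the HTTP(S) URL cache
bounded = ImageLoader(max_pending=16)           # explicitly cap in-flight decoded images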
(next changed file; filename not shown)
@@ -170,6 +170,7 @@ async def generate(self, request: vLLMMultimodalRequest, context):
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")

multi_modal_data = defaultdict(list)
num_loaded_images = 0
for mi in request.multimodal_inputs:
# ECConnector consumer mode: vLLM loads embeddings automatically from disk
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
@@ -273,6 +274,8 @@ async def generate(self, request: vLLMMultimodalRequest, context):
await self.image_loader.load_image(mi.multimodal_input.image_url)
)

num_loaded_images += 1

# Remove the image features from the request as they are not required
request.multimodal_inputs = None

@@ -362,3 +365,5 @@ async def generate(self, request: vLLMMultimodalRequest, context):
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()

self.image_loader.mark_consumed(num_loaded_images)
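Seen end to end, the worker's accounting brackets the whole request: one slot is taken per image as it is loaded, and all of them are returned only after the final response has been emitted. A condensed sketch of the flow these hunks add (engine plumbing elided; names follow the diff):

num_loaded_images = 0
for mi in request.multimodal_inputs:
    multi_modal_data["image"].append(
        await self.image_loader.load_image(mi.multimodal_input.image_url)
    )
    num_loaded_images += 1
# ... run the engine and stream responses ...
self.image_loader.mark_consumed(num_loaded_images)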
components/src/dynamo/vllm/multimodal_utils/image_loader.py (152 additions, 22 deletions)
@@ -17,32 +17,144 @@
import base64
import binascii
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import TypeAlias, Union
from urllib.parse import urlparse

import httpx
import torch
from PIL import Image

from .http_client import get_http_client

logger = logging.getLogger(__name__)

# Image output can be either PIL Image or Tensor (from nvimgcodec)
ImageOutput: TypeAlias = Union[Image.Image, torch.Tensor]

# Thread-local storage for nvimgcodec decoders
_thread_local = threading.local()

# Lazy import for nvimgcodec
_nvimgcodec = None
_nvimgcodec_available: bool | None = None # None = not yet probed

# Global thread pool for nvimgcodec decoding operations
# Default to 8 workers, configurable via DYN_IMAGE_DECODE_WORKERS env var
_IMAGE_DECODE_WORKERS = int(os.environ.get("DYN_IMAGE_DECODE_WORKERS", "8"))
_decode_thread_pool = ThreadPoolExecutor(
max_workers=_IMAGE_DECODE_WORKERS,
thread_name_prefix="image_decode_",
)


def _is_nvimgcodec_available() -> bool:
"""Check whether nvimgcodec can be imported. Result is cached."""
global _nvimgcodec_available
if _nvimgcodec_available is None:
try:
_get_nvimgcodec()
_nvimgcodec_available = True
except (ImportError, ModuleNotFoundError):
_nvimgcodec_available = False
return _nvimgcodec_available


def _get_nvimgcodec():
"""Lazy import nvimgcodec. Raises ImportError if not installed."""
global _nvimgcodec
if _nvimgcodec is None:
from nvidia import nvimgcodec

_nvimgcodec = nvimgcodec
return _nvimgcodec


def get_decoder():
"""Get or create a thread-local nvimgcodec decoder instance."""
if not hasattr(_thread_local, "decoder"):
nvimgcodec = _get_nvimgcodec()
_thread_local.decoder = nvimgcodec.Decoder()
logger.info("nvimgcodec decoder initialized for thread")
return _thread_local.decoder


class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
DEFAULT_MAX_PENDING = 64

def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
self,
cache_size: int = CACHE_SIZE_MAXIMUM,
http_timeout: float = 30.0,
use_nvimgcodec: bool = True,
max_pending: int | None = None,
):
self._http_timeout = http_timeout
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)

async def load_image(self, image_url: str) -> Image.Image:
# Fall back to PIL if nvimgcodec was requested but is not installed
if use_nvimgcodec and not _is_nvimgcodec_available():
logger.warning(
"nvimgcodec requested but not installed — "
"falling back to PIL for image decoding"
)
use_nvimgcodec = False
self._use_nvimgcodec = use_nvimgcodec

if max_pending is None:
max_pending = int(
os.environ.get("DYN_IMAGE_MAX_PENDING", self.DEFAULT_MAX_PENDING)
)
self._pending_semaphore = asyncio.Semaphore(max_pending)
self._max_pending = max_pending

def mark_consumed(self, count: int = 1):
"""
Signal that decoded images have been consumed by the vLLM prefill batch.
Call this after the prefill batch completes to allow more images to be decoded.

Args:
count: Number of images consumed (default: 1)
"""
for _ in range(count):
self._pending_semaphore.release()

def _decode_with_nvimgcodec(self, data: bytes) -> torch.Tensor:
"""
Decode image bytes using nvimgcodec for GPU-accelerated decoding.

Returns:
torch.Tensor in NCHW format (4D) on CUDA device.
Shape: (1, C, H, W) - batch dimension added so vLLM treats it as
a batch of images, not as embeddings.
"""
nvimgcodec = _get_nvimgcodec()
decoder = get_decoder()
code_stream = nvimgcodec.CodeStream(data)
decoded = decoder.decode(code_stream)

device = torch.device("cuda", torch.cuda.current_device())
tensor = torch.as_tensor(decoded, device=device)
# HWC -> CHW
tensor = tensor.permute(2, 0, 1)
# Add batch dimension: CHW -> NCHW (1, C, H, W)
# This is critical: 3D tensors are interpreted as embeddings by vLLM,
# but 4D tensors are interpreted as a batch of images.
tensor = tensor.unsqueeze(0)

return tensor

async def load_image(self, image_url: str) -> ImageOutput:
"""Load an image from a URL or data URI."""
parsed_url = urlparse(image_url)

# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
# For HTTP(S) URLs, check cache first (PIL path only)
if not self._use_nvimgcodec and parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
@@ -61,7 +173,6 @@ async def load_image(self, image_url: str) -> Image.Image:

try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
@@ -73,31 +184,50 @@ async def load_image(self, image_url: str) -> Image.Image:
if not response.content:
raise ValueError("Empty response content from image URL")

image_data = BytesIO(response.content)
image_bytes = response.content
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")

# PIL is sync, so offload to a thread to avoid blocking the event loop
image = await asyncio.to_thread(Image.open, image_data)
# Wait if too many decoded images are pending in the vLLM scheduler.
# Released when the caller invokes mark_consumed() after prefill.
await self._pending_semaphore.acquire()

try:
if self._use_nvimgcodec:
# nvimgcodec decoding (GPU-accelerated, returns 4D tensor)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
_decode_thread_pool,
self._decode_with_nvimgcodec,
image_bytes,
)
else:
# Original PIL path
image_data = BytesIO(image_bytes)
image = await asyncio.to_thread(Image.open, image_data)

# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")

# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")

image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]

# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)

self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)
return image_converted

return image_converted
except Exception:
# Release semaphore on decode failure to prevent leak
self._pending_semaphore.release()
raise

except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")