diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index b6daf581294..884608f7e1d 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -1363,6 +1363,11 @@ async def _generate_token_mode(self, request, context, request_id): logger.warning("Initiating Dynamo Runtime shutdown.") self.runtime.shutdown() os._exit(1) + finally: + if multi_modal_data is not None: + images = multi_modal_data.get("image") + count = len(images) if isinstance(images, list) else 1 + self.image_loader.mark_consumed(count) async def _generate_text_mode(self, request, context, request_id): """Generate text using OpenAI-compatible format (text-in-text-out).""" @@ -1455,6 +1460,11 @@ async def _generate_text_mode(self, request, context, request_id): logger.warning("Initiating Dynamo Runtime shutdown.") self.runtime.shutdown() os._exit(1) + finally: + if multi_modal_data is not None: + images = multi_modal_data.get("image") + count = len(images) if isinstance(images, list) else 1 + self.image_loader.mark_consumed(count) class PrefillWorkerHandler(BaseWorkerHandler): @@ -1608,3 +1618,8 @@ async def _generate_token_mode(self, request, context, request_id): raise GeneratorExit( "Prefill engine was shut down during token generation" ) from None + finally: + if multi_modal_data is not None: + images = multi_modal_data.get("image") + count = len(images) if isinstance(images, list) else 1 + self.image_loader.mark_consumed(count) diff --git a/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py b/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py index 2348e8a838a..782bf69ba99 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py @@ -58,7 +58,7 @@ def __init__( self.engine_args = engine_args self.model = self.engine_args.model - self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM) + self.image_loader = ImageLoader() self.image_processor = AutoImageProcessor.from_pretrained( self.model, trust_remote_code=True ) diff --git a/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py b/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py index d423860936f..ed9e084b290 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py @@ -170,6 +170,7 @@ async def generate(self, request: vLLMMultimodalRequest, context): logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") multi_modal_data = defaultdict(list) + num_loaded_images = 0 for mi in request.multimodal_inputs: # ECConnector consumer mode: vLLM loads embeddings automatically from disk # We need to pass multimodal_input so vLLM can generate mm_hash and look up cache @@ -273,6 +274,8 @@ async def generate(self, request: vLLMMultimodalRequest, context): await self.image_loader.load_image(mi.multimodal_input.image_url) ) + num_loaded_images += 1 + # Remove the image features from the request as they are not required request.multimodal_inputs = None @@ -362,3 +365,5 @@ async def generate(self, request: vLLMMultimodalRequest, context): metrics=response.metrics, kv_transfer_params=response.kv_transfer_params, ).model_dump_json() + + self.image_loader.mark_consumed(num_loaded_images) diff --git a/components/src/dynamo/vllm/multimodal_utils/image_loader.py b/components/src/dynamo/vllm/multimodal_utils/image_loader.py index d34fd9863e1..0e3c68f5b8f 100644 --- a/components/src/dynamo/vllm/multimodal_utils/image_loader.py +++ b/components/src/dynamo/vllm/multimodal_utils/image_loader.py @@ -17,32 +17,144 @@ import base64 import binascii import logging +import os +import threading +from concurrent.futures import ThreadPoolExecutor from io import BytesIO +from typing import TypeAlias, Union from urllib.parse import urlparse import httpx +import torch from PIL import Image from .http_client import get_http_client logger = logging.getLogger(__name__) +# Image output can be either PIL Image or Tensor (from nvimgcodec) +ImageOutput: TypeAlias = Union[Image.Image, torch.Tensor] + +# Thread-local storage for nvimgcodec decoders +_thread_local = threading.local() + +# Lazy import for nvimgcodec +_nvimgcodec = None +_nvimgcodec_available: bool | None = None # None = not yet probed + +# Global thread pool for nvimgcodec decoding operations +# Default to 8 workers, configurable via DYN_IMAGE_DECODE_WORKERS env var +_IMAGE_DECODE_WORKERS = int(os.environ.get("DYN_IMAGE_DECODE_WORKERS", 8)) +_decode_thread_pool = ThreadPoolExecutor( + max_workers=_IMAGE_DECODE_WORKERS, + thread_name_prefix="image_decode_", +) + + +def _is_nvimgcodec_available() -> bool: + """Check whether nvimgcodec can be imported. Result is cached.""" + global _nvimgcodec_available + if _nvimgcodec_available is None: + try: + _get_nvimgcodec() + _nvimgcodec_available = True + except (ImportError, ModuleNotFoundError): + _nvimgcodec_available = False + return _nvimgcodec_available + + +def _get_nvimgcodec(): + """Lazy import nvimgcodec. Raises ImportError if not installed.""" + global _nvimgcodec + if _nvimgcodec is None: + from nvidia import nvimgcodec + + _nvimgcodec = nvimgcodec + return _nvimgcodec + + +def get_decoder(): + """Get or create a thread-local nvimgcodec decoder instance.""" + if not hasattr(_thread_local, "decoder"): + nvimgcodec = _get_nvimgcodec() + _thread_local.decoder = nvimgcodec.Decoder() + logger.info("nvimgcodec decoder initialized for thread") + return _thread_local.decoder + class ImageLoader: CACHE_SIZE_MAXIMUM = 8 + DEFAULT_MAX_PENDING = 64 def __init__( - self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0 + self, + cache_size: int = CACHE_SIZE_MAXIMUM, + http_timeout: float = 30.0, + use_nvimgcodec: bool = True, + max_pending: int | None = None, ): self._http_timeout = http_timeout self._image_cache: dict[str, Image.Image] = {} self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size) - async def load_image(self, image_url: str) -> Image.Image: + # Fall back to PIL if nvimgcodec was requested but is not installed + if use_nvimgcodec and not _is_nvimgcodec_available(): + logger.warning( + "nvimgcodec requested but not installed — " + "falling back to PIL for image decoding" + ) + use_nvimgcodec = False + self._use_nvimgcodec = use_nvimgcodec + + if max_pending is None: + max_pending = int( + os.environ.get("DYN_IMAGE_MAX_PENDING", self.DEFAULT_MAX_PENDING) + ) + self._pending_semaphore = asyncio.Semaphore(max_pending) + self._max_pending = max_pending + + def mark_consumed(self, count: int = 1): + """ + Signal that decoded images have been consumed by the vLLM prefill batch. + Call this after the prefill batch completes to allow more images to be decoded. + + Args: + count: Number of images consumed (default: 1) + """ + for _ in range(count): + self._pending_semaphore.release() + + def _decode_with_nvimgcodec(self, data: bytes) -> torch.Tensor: + """ + Decode image bytes using nvimgcodec for GPU-accelerated decoding. + + Returns: + torch.Tensor in NCHW format (4D) on CUDA device. + Shape: (1, C, H, W) - batch dimension added so vLLM treats it as + a batch of images, not as embeddings. + """ + nvimgcodec = _get_nvimgcodec() + decoder = get_decoder() + code_stream = nvimgcodec.CodeStream(data) + decoded = decoder.decode(code_stream) + + device = torch.device("cuda", torch.cuda.current_device()) + tensor = torch.as_tensor(decoded, device=device) + # HWC -> CHW + tensor = tensor.permute(2, 0, 1) + # Add batch dimension: CHW -> NCHW (1, C, H, W) + # This is critical: 3D tensors are interpreted as embeddings by vLLM, + # but 4D tensors are interpreted as a batch of images. + tensor = tensor.unsqueeze(0) + + return tensor + + async def load_image(self, image_url: str) -> ImageOutput: + """Load an image from a URL or data URI.""" parsed_url = urlparse(image_url) - # For HTTP(S) URLs, check cache first - if parsed_url.scheme in ("http", "https"): + # For HTTP(S) URLs, check cache first (PIL path only) + if not self._use_nvimgcodec and parsed_url.scheme in ("http", "https"): image_url_lower = image_url.lower() if image_url_lower in self._image_cache: logger.debug(f"Image found in cache for URL: {image_url}") @@ -61,7 +173,6 @@ async def load_image(self, image_url: str) -> Image.Image: try: image_bytes = base64.b64decode(data) - image_data = BytesIO(image_bytes) except binascii.Error as e: raise ValueError(f"Invalid base64 encoding: {e}") elif parsed_url.scheme in ("http", "https"): @@ -73,31 +184,50 @@ async def load_image(self, image_url: str) -> Image.Image: if not response.content: raise ValueError("Empty response content from image URL") - image_data = BytesIO(response.content) + image_bytes = response.content else: raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}") - # PIL is sync, so offload to a thread to avoid blocking the event loop - image = await asyncio.to_thread(Image.open, image_data) + # Wait if too many decoded images are pending in the vLLM scheduler. + # Released when the caller invokes mark_consumed() after prefill. + await self._pending_semaphore.acquire() + + try: + if self._use_nvimgcodec: + # nvimgcodec decoding (GPU-accelerated, returns 4D tensor) + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + _decode_thread_pool, + self._decode_with_nvimgcodec, + image_bytes, + ) + else: + # Original PIL path + image_data = BytesIO(image_bytes) + image = await asyncio.to_thread(Image.open, image_data) + + # Validate image format and convert to RGB + if image.format not in ("JPEG", "PNG", "WEBP"): + raise ValueError(f"Unsupported image format: {image.format}") - # Validate image format and convert to RGB - if image.format not in ("JPEG", "PNG", "WEBP"): - raise ValueError(f"Unsupported image format: {image.format}") + image_converted = image.convert("RGB") - image_converted = image.convert("RGB") + # Cache HTTP(S) URLs + if parsed_url.scheme in ("http", "https"): + image_url_lower = image_url.lower() + if self._cache_queue.full(): + oldest_image_url = await self._cache_queue.get() + del self._image_cache[oldest_image_url] - # Cache HTTP(S) URLs - if parsed_url.scheme in ("http", "https"): - image_url_lower = image_url.lower() - # Cache the image for future use, and evict the oldest image if the cache is full - if self._cache_queue.full(): - oldest_image_url = await self._cache_queue.get() - del self._image_cache[oldest_image_url] + self._image_cache[image_url_lower] = image_converted + await self._cache_queue.put(image_url_lower) - self._image_cache[image_url_lower] = image_converted - await self._cache_queue.put(image_url_lower) + return image_converted - return image_converted + except Exception: + # Release semaphore on decode failure to prevent leak + self._pending_semaphore.release() + raise except httpx.HTTPError as e: logger.error(f"HTTP error loading image: {e}")