diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 884608f7e1d..b6daf581294 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -1363,11 +1363,6 @@ async def _generate_token_mode(self, request, context, request_id): logger.warning("Initiating Dynamo Runtime shutdown.") self.runtime.shutdown() os._exit(1) - finally: - if multi_modal_data is not None: - images = multi_modal_data.get("image") - count = len(images) if isinstance(images, list) else 1 - self.image_loader.mark_consumed(count) async def _generate_text_mode(self, request, context, request_id): """Generate text using OpenAI-compatible format (text-in-text-out).""" @@ -1460,11 +1455,6 @@ async def _generate_text_mode(self, request, context, request_id): logger.warning("Initiating Dynamo Runtime shutdown.") self.runtime.shutdown() os._exit(1) - finally: - if multi_modal_data is not None: - images = multi_modal_data.get("image") - count = len(images) if isinstance(images, list) else 1 - self.image_loader.mark_consumed(count) class PrefillWorkerHandler(BaseWorkerHandler): @@ -1618,8 +1608,3 @@ async def _generate_token_mode(self, request, context, request_id): raise GeneratorExit( "Prefill engine was shut down during token generation" ) from None - finally: - if multi_modal_data is not None: - images = multi_modal_data.get("image") - count = len(images) if isinstance(images, list) else 1 - self.image_loader.mark_consumed(count) diff --git a/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py b/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py index 782bf69ba99..2348e8a838a 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py @@ -58,7 +58,7 @@ def __init__( self.engine_args = engine_args self.model = self.engine_args.model - self.image_loader = ImageLoader() + self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM) self.image_processor = AutoImageProcessor.from_pretrained( self.model, trust_remote_code=True ) diff --git a/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py b/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py index ed9e084b290..d423860936f 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py @@ -170,7 +170,6 @@ async def generate(self, request: vLLMMultimodalRequest, context): logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") multi_modal_data = defaultdict(list) - num_loaded_images = 0 for mi in request.multimodal_inputs: # ECConnector consumer mode: vLLM loads embeddings automatically from disk # We need to pass multimodal_input so vLLM can generate mm_hash and look up cache @@ -274,8 +273,6 @@ async def generate(self, request: vLLMMultimodalRequest, context): await self.image_loader.load_image(mi.multimodal_input.image_url) ) - num_loaded_images += 1 - # Remove the image features from the request as they are not required request.multimodal_inputs = None @@ -365,5 +362,3 @@ async def generate(self, request: vLLMMultimodalRequest, context): metrics=response.metrics, kv_transfer_params=response.kv_transfer_params, ).model_dump_json() - - self.image_loader.mark_consumed(num_loaded_images) diff --git a/components/src/dynamo/vllm/multimodal_utils/image_loader.py b/components/src/dynamo/vllm/multimodal_utils/image_loader.py index 0e3c68f5b8f..d34fd9863e1 100644 --- a/components/src/dynamo/vllm/multimodal_utils/image_loader.py +++ b/components/src/dynamo/vllm/multimodal_utils/image_loader.py @@ -17,144 +17,32 @@ import base64 import binascii import logging -import os -import threading -from concurrent.futures import ThreadPoolExecutor from io import BytesIO -from typing import TypeAlias, Union from urllib.parse import urlparse import httpx -import torch from PIL import Image from .http_client import get_http_client logger = logging.getLogger(__name__) -# Image output can be either PIL Image or Tensor (from nvimgcodec) -ImageOutput: TypeAlias = Union[Image.Image, torch.Tensor] - -# Thread-local storage for nvimgcodec decoders -_thread_local = threading.local() - -# Lazy import for nvimgcodec -_nvimgcodec = None -_nvimgcodec_available: bool | None = None # None = not yet probed - -# Global thread pool for nvimgcodec decoding operations -# Default to 8 workers, configurable via DYN_IMAGE_DECODE_WORKERS env var -_IMAGE_DECODE_WORKERS = int(os.environ.get("DYN_IMAGE_DECODE_WORKERS", 8)) -_decode_thread_pool = ThreadPoolExecutor( - max_workers=_IMAGE_DECODE_WORKERS, - thread_name_prefix="image_decode_", -) - - -def _is_nvimgcodec_available() -> bool: - """Check whether nvimgcodec can be imported. Result is cached.""" - global _nvimgcodec_available - if _nvimgcodec_available is None: - try: - _get_nvimgcodec() - _nvimgcodec_available = True - except (ImportError, ModuleNotFoundError): - _nvimgcodec_available = False - return _nvimgcodec_available - - -def _get_nvimgcodec(): - """Lazy import nvimgcodec. Raises ImportError if not installed.""" - global _nvimgcodec - if _nvimgcodec is None: - from nvidia import nvimgcodec - - _nvimgcodec = nvimgcodec - return _nvimgcodec - - -def get_decoder(): - """Get or create a thread-local nvimgcodec decoder instance.""" - if not hasattr(_thread_local, "decoder"): - nvimgcodec = _get_nvimgcodec() - _thread_local.decoder = nvimgcodec.Decoder() - logger.info("nvimgcodec decoder initialized for thread") - return _thread_local.decoder - class ImageLoader: CACHE_SIZE_MAXIMUM = 8 - DEFAULT_MAX_PENDING = 64 def __init__( - self, - cache_size: int = CACHE_SIZE_MAXIMUM, - http_timeout: float = 30.0, - use_nvimgcodec: bool = True, - max_pending: int | None = None, + self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0 ): self._http_timeout = http_timeout self._image_cache: dict[str, Image.Image] = {} self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size) - # Fall back to PIL if nvimgcodec was requested but is not installed - if use_nvimgcodec and not _is_nvimgcodec_available(): - logger.warning( - "nvimgcodec requested but not installed — " - "falling back to PIL for image decoding" - ) - use_nvimgcodec = False - self._use_nvimgcodec = use_nvimgcodec - - if max_pending is None: - max_pending = int( - os.environ.get("DYN_IMAGE_MAX_PENDING", self.DEFAULT_MAX_PENDING) - ) - self._pending_semaphore = asyncio.Semaphore(max_pending) - self._max_pending = max_pending - - def mark_consumed(self, count: int = 1): - """ - Signal that decoded images have been consumed by the vLLM prefill batch. - Call this after the prefill batch completes to allow more images to be decoded. - - Args: - count: Number of images consumed (default: 1) - """ - for _ in range(count): - self._pending_semaphore.release() - - def _decode_with_nvimgcodec(self, data: bytes) -> torch.Tensor: - """ - Decode image bytes using nvimgcodec for GPU-accelerated decoding. - - Returns: - torch.Tensor in NCHW format (4D) on CUDA device. - Shape: (1, C, H, W) - batch dimension added so vLLM treats it as - a batch of images, not as embeddings. - """ - nvimgcodec = _get_nvimgcodec() - decoder = get_decoder() - code_stream = nvimgcodec.CodeStream(data) - decoded = decoder.decode(code_stream) - - device = torch.device("cuda", torch.cuda.current_device()) - tensor = torch.as_tensor(decoded, device=device) - # HWC -> CHW - tensor = tensor.permute(2, 0, 1) - # Add batch dimension: CHW -> NCHW (1, C, H, W) - # This is critical: 3D tensors are interpreted as embeddings by vLLM, - # but 4D tensors are interpreted as a batch of images. - tensor = tensor.unsqueeze(0) - - return tensor - - async def load_image(self, image_url: str) -> ImageOutput: - """Load an image from a URL or data URI.""" + async def load_image(self, image_url: str) -> Image.Image: parsed_url = urlparse(image_url) - # For HTTP(S) URLs, check cache first (PIL path only) - if not self._use_nvimgcodec and parsed_url.scheme in ("http", "https"): + # For HTTP(S) URLs, check cache first + if parsed_url.scheme in ("http", "https"): image_url_lower = image_url.lower() if image_url_lower in self._image_cache: logger.debug(f"Image found in cache for URL: {image_url}") @@ -173,6 +61,7 @@ async def load_image(self, image_url: str) -> ImageOutput: try: image_bytes = base64.b64decode(data) + image_data = BytesIO(image_bytes) except binascii.Error as e: raise ValueError(f"Invalid base64 encoding: {e}") elif parsed_url.scheme in ("http", "https"): @@ -184,50 +73,31 @@ async def load_image(self, image_url: str) -> ImageOutput: if not response.content: raise ValueError("Empty response content from image URL") - image_bytes = response.content + image_data = BytesIO(response.content) else: raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}") - # Wait if too many decoded images are pending in the vLLM scheduler. - # Released when the caller invokes mark_consumed() after prefill. - await self._pending_semaphore.acquire() - - try: - if self._use_nvimgcodec: - # nvimgcodec decoding (GPU-accelerated, returns 4D tensor) - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - _decode_thread_pool, - self._decode_with_nvimgcodec, - image_bytes, - ) - else: - # Original PIL path - image_data = BytesIO(image_bytes) - image = await asyncio.to_thread(Image.open, image_data) - - # Validate image format and convert to RGB - if image.format not in ("JPEG", "PNG", "WEBP"): - raise ValueError(f"Unsupported image format: {image.format}") + # PIL is sync, so offload to a thread to avoid blocking the event loop + image = await asyncio.to_thread(Image.open, image_data) - image_converted = image.convert("RGB") + # Validate image format and convert to RGB + if image.format not in ("JPEG", "PNG", "WEBP"): + raise ValueError(f"Unsupported image format: {image.format}") - # Cache HTTP(S) URLs - if parsed_url.scheme in ("http", "https"): - image_url_lower = image_url.lower() - if self._cache_queue.full(): - oldest_image_url = await self._cache_queue.get() - del self._image_cache[oldest_image_url] + image_converted = image.convert("RGB") - self._image_cache[image_url_lower] = image_converted - await self._cache_queue.put(image_url_lower) + # Cache HTTP(S) URLs + if parsed_url.scheme in ("http", "https"): + image_url_lower = image_url.lower() + # Cache the image for future use, and evict the oldest image if the cache is full + if self._cache_queue.full(): + oldest_image_url = await self._cache_queue.get() + del self._image_cache[oldest_image_url] - return image_converted + self._image_cache[image_url_lower] = image_converted + await self._cache_queue.put(image_url_lower) - except Exception: - # Release semaphore on decode failure to prevent leak - self._pending_semaphore.release() - raise + return image_converted except httpx.HTTPError as e: logger.error(f"HTTP error loading image: {e}")