From 54682736b3acc5dc077d0f999a08bf7e0c415cb8 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 3 Feb 2026 13:30:35 -0500 Subject: [PATCH 01/33] upd --- vllm/multimodal/media/image.py | 113 +++++++++++++++++++++++++++++++-- 1 file changed, 109 insertions(+), 4 deletions(-) diff --git a/vllm/multimodal/media/image.py b/vllm/multimodal/media/image.py index 977a67007363..309f8ea15447 100644 --- a/vllm/multimodal/media/image.py +++ b/vllm/multimodal/media/image.py @@ -3,6 +3,7 @@ from io import BytesIO from pathlib import Path +from typing import TypeAlias import pybase64 import torch @@ -15,8 +16,16 @@ logger = init_logger(__file__) +# Image output can be either PIL Image or Tensor (from nvJPEG) +ImageOutput: TypeAlias = Image.Image | torch.Tensor + + +class ImageMediaIO(MediaIO[ImageOutput]): + # Class-level counters for nvJPEG statistics + _nvjpeg_success_count: int = 0 + _nvjpeg_fallback_count: int = 0 + _nvjpeg_available: bool | None = None # Lazy initialization -class ImageMediaIO(MediaIO[Image.Image]): def __init__(self, image_mode: str = "RGB", **kwargs) -> None: super().__init__() @@ -47,6 +56,87 @@ def __init__(self, image_mode: str = "RGB", **kwargs) -> None: ) self.rgba_background_color = rgba_bg + # Check nvJPEG availability on first instantiation + if ImageMediaIO._nvjpeg_available is None: + ImageMediaIO._nvjpeg_available = self._check_nvjpeg_available() + + @staticmethod + def _check_nvjpeg_available() -> bool: + """Check if nvJPEG is available (CUDA + torchvision decode_jpeg).""" + try: + # torch.cuda.is_available() can raise RuntimeError if CUDA driver fails + if not torch.cuda.is_available(): + logger.debug("nvJPEG not available: CUDA not available") + return False + # Check if torchvision decode_jpeg is available + from torchvision.io import decode_jpeg # noqa: F401 + logger.info("nvJPEG available: using GPU-accelerated JPEG decoding") + return True + except ImportError: + logger.debug("nvJPEG not available: torchvision.io.decode_jpeg not found") + return False + except RuntimeError as e: + # CUDA driver initialization can fail with RuntimeError + logger.debug(f"nvJPEG not available: CUDA driver error - {e}") + return False + except Exception as e: + logger.debug(f"nvJPEG not available: {e}") + return False + + @staticmethod + def _is_jpeg(data: bytes) -> bool: + """Detect JPEG format from magic bytes.""" + return len(data) >= 3 and data[:3] == b'\xff\xd8\xff' + + def _decode_with_nvjpeg(self, data: bytes) -> torch.Tensor | None: + """ + Try to decode JPEG using nvJPEG (GPU-accelerated). + + Returns: + torch.Tensor in CHW format on CPU, or None on failure. + Note: Decoding happens on GPU for speed, then moved to CPU + for compatibility with vLLM's memory pinning. 
+ """ + try: + from torchvision.io import decode_jpeg, ImageReadMode + + # Convert bytes to tensor + data_tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8) + + # Select mode based on image_mode + if self.image_mode == "RGB": + mode = ImageReadMode.RGB + elif self.image_mode == "L": + mode = ImageReadMode.GRAY + else: + mode = ImageReadMode.UNCHANGED + + # Decode on GPU using nvJPEG + tensor = decode_jpeg(data_tensor, mode=mode, device='cuda') + + # Move to CPU for compatibility with vLLM's memory pinning + tensor = tensor.cpu() + + # Update success counter and log periodically + ImageMediaIO._nvjpeg_success_count += 1 + self._log_stats_if_needed() + + return tensor # CHW tensor on CPU + + except Exception as e: + logger.debug(f"nvJPEG decode failed, falling back to PIL: {e}") + ImageMediaIO._nvjpeg_fallback_count += 1 + return None + + def _log_stats_if_needed(self) -> None: + """Log nvJPEG statistics periodically.""" + total = ImageMediaIO._nvjpeg_success_count + ImageMediaIO._nvjpeg_fallback_count + if total > 0 and total % 100 == 0: + logger.info( + f"nvJPEG decode stats: {ImageMediaIO._nvjpeg_success_count} successful, " + f"{ImageMediaIO._nvjpeg_fallback_count} fallback to PIL" + ) + def _convert_image_mode( self, image: Image.Image | MediaWithBytes[Image.Image] ) -> Image.Image: @@ -60,16 +150,31 @@ def _convert_image_mode( else: return convert_image_mode(image, self.image_mode) - def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]: + def load_bytes(self, data: bytes) -> MediaWithBytes[ImageOutput]: + # Try nvJPEG for JPEG images when available + if ImageMediaIO._nvjpeg_available and self._is_jpeg(data): + tensor = self._decode_with_nvjpeg(data) + if tensor is not None: + return MediaWithBytes(tensor, data) + + # Fallback to PIL for non-JPEG or when nvJPEG fails image = Image.open(BytesIO(data)) return MediaWithBytes(self._convert_image_mode(image), data) - def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]: + def load_base64(self, media_type: str, data: str) -> MediaWithBytes[ImageOutput]: return self.load_bytes(pybase64.b64decode(data, validate=True)) - def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]: + def load_file(self, filepath: Path) -> MediaWithBytes[ImageOutput]: with open(filepath, "rb") as f: data = f.read() + + # Try nvJPEG for JPEG images when available + if ImageMediaIO._nvjpeg_available and self._is_jpeg(data): + tensor = self._decode_with_nvjpeg(data) + if tensor is not None: + return MediaWithBytes(tensor, data) + + # Fallback to PIL for non-JPEG or when nvJPEG fails image = Image.open(BytesIO(data)) return MediaWithBytes(self._convert_image_mode(image), data) From 60557b2992ac996a382e00633696cb573e00c6d0 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 3 Feb 2026 21:24:56 -0500 Subject: [PATCH 02/33] upd --- requirements/cuda.txt | 4 +++- vllm/v1/utils.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 380bbc30e3d1..fda957258600 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -12,4 +12,6 @@ torchvision==0.24.1 # Required for phi3v processor. 
See https://github.com/pytor # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.5.3 # FA4 -flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@2580b5a4882562640f3cfbffd2bb8d2de9268f9f#subdirectory=flash_attn/cute \ No newline at end of file +flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@2580b5a4882562640f3cfbffd2bb8d2de9268f9f#subdirectory=flash_attn/cute +# nvimgcodec +nvidia-nvimgcodec-cu13==0.7.0.11 \ No newline at end of file diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 75ad304ddf1a..b3ebd72af992 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -413,7 +413,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview: Returns: A memoryview of the tensor data as uint8. """ - return tensor.flatten().contiguous().view(torch.uint8).numpy().data + return tensor.cpu().flatten().contiguous().view(torch.uint8).numpy().data @dataclass From 54968a859b7ce11a23a3aebe3ed9f5e8ee9664ac Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sat, 10 Jan 2026 22:28:32 +0000 Subject: [PATCH 03/33] Add tensor IPC transfer mechanism for multimodal data Make tensor IPC datapath optional/config-based Signed-off-by: Brandon Pelfrey --- tests/v1/test_tensor_ipc_queue.py | 621 ++++++++++++++++++++++++++++++ vllm/config/model.py | 6 + vllm/config/multimodal.py | 11 + vllm/engine/arg_utils.py | 23 ++ vllm/entrypoints/cli/serve.py | 1 + vllm/envs.py | 11 + vllm/multimodal/inputs.py | 2 +- vllm/v1/engine/core.py | 46 ++- vllm/v1/engine/core_client.py | 33 +- vllm/v1/engine/utils.py | 45 ++- vllm/v1/serial_utils.py | 150 +++++++- vllm/v1/utils.py | 4 + 12 files changed, 926 insertions(+), 27 deletions(-) create mode 100644 tests/v1/test_tensor_ipc_queue.py diff --git a/tests/v1/test_tensor_ipc_queue.py b/tests/v1/test_tensor_ipc_queue.py new file mode 100644 index 000000000000..62ff697c2724 --- /dev/null +++ b/tests/v1/test_tensor_ipc_queue.py @@ -0,0 +1,621 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Tests for tensor IPC queue functionality.""" + +import multiprocessing as mp +from typing import Any + +import pytest +import torch +import torch.multiprocessing as torch_mp + +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, TensorIpcData, TensorIpcHandle + + +@pytest.fixture(scope="module", autouse=True) +def setup_multiprocessing(): + """Set multiprocessing start method to 'spawn' for compatibility.""" + try: + torch_mp.set_start_method('spawn', force=True) + except RuntimeError: + # Already set, which is fine + pass + yield + + +def encoder_process( + tensor_queues: list[torch_mp.Queue], + result_queue: mp.Queue, + target_engine: int, + tensor_data: dict[str, Any], + ready_event: mp.Event, +): + """Process that encodes and sends CUDA tensors via queue.""" + try: + # Create encoder with tensor queues + encoder = MsgpackEncoder(tensor_queues=tensor_queues) + encoder.set_target_engine(target_engine) + + # Create a CUDA tensor if available + if torch.cuda.is_available(): + device = "cuda:0" + tensor = torch.randn( + *tensor_data["shape"], dtype=tensor_data["dtype"], device=device + ) + else: + # Fall back to CPU for testing + device = "cpu" + tensor = torch.randn(*tensor_data["shape"], dtype=tensor_data["dtype"]) + + # Encode the tensor + encoded = encoder.encode({"test_tensor": tensor}) + + # Signal that encoding is complete before sending result + ready_event.set() + + result_queue.put( + { + "success": True, + "encoded_length": len(encoded), + 
"device": str(device), + "tensor_shape": tuple(tensor.shape), + } + ) + except Exception as e: + import traceback + ready_event.set() # Signal even on failure + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +def decoder_process( + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + expected_shape: tuple, + encoder_ready: mp.Event, +): + """Process that decodes and receives CUDA tensors from queue.""" + try: + # Create decoder with tensor queue + decoder = MsgpackDecoder(tensor_queue=tensor_queue) + + # Wait for encoder to finish sending + if not encoder_ready.wait(timeout=10.0): + raise TimeoutError("Encoder did not signal ready") + + # Try to get tensor from queue directly for testing + ipc_data = tensor_queue.get(timeout=5.0) + + result_queue.put( + { + "success": True, + "tensor_id": ipc_data.tensor_id, + "tensor_shape": tuple(ipc_data.tensor.shape), + "device": str(ipc_data.tensor.device), + "matches_expected": tuple(ipc_data.tensor.shape) == expected_shape, + } + ) + except Exception as e: + import traceback + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_tensor_queue_basic(): + """Test basic CUDA tensor sharing via queue.""" + # Set up queues and synchronization + num_engines = 2 + tensor_queues = [torch_mp.Queue() for _ in range(num_engines)] + result_queue = mp.Queue() + encoder_ready = mp.Event() + + target_engine = 0 + tensor_shape = (4, 8, 16) + tensor_dtype = torch.float32 + + # Start encoder process + encoder_proc = mp.Process( + target=encoder_process, + args=( + tensor_queues, + result_queue, + target_engine, + {"shape": tensor_shape, "dtype": tensor_dtype}, + encoder_ready, + ), + ) + encoder_proc.start() + + # Start decoder process + decoder_proc = mp.Process( + target=decoder_process, + args=(tensor_queues[target_engine], result_queue, tensor_shape, encoder_ready), + ) + decoder_proc.start() + + # Wait for processes and collect results + encoder_result = result_queue.get(timeout=10.0) + decoder_result = result_queue.get(timeout=10.0) + + encoder_proc.join(timeout=5.0) + decoder_proc.join(timeout=5.0) + + # Verify results + assert encoder_result["success"], f"Encoder failed: {encoder_result.get('error')}\n{encoder_result.get('traceback', '')}" + assert decoder_result["success"], f"Decoder failed: {decoder_result.get('error')}\n{decoder_result.get('traceback', '')}" + assert decoder_result["matches_expected"], "Tensor shape mismatch" + assert "cuda" in decoder_result["device"], "Tensor not on CUDA device" + + +def test_cpu_tensor_fallback(): + """Test that CPU tensors use standard serialization path.""" + encoder = MsgpackEncoder(tensor_queues=None) + + # Create a CPU tensor + tensor = torch.randn(3, 4, dtype=torch.float32) + + # Encode the tensor (should use standard path, not queue) + encoded = encoder.encode({"test_tensor": tensor}) + + # Verify encoding succeeded + assert len(encoded) > 0 + assert isinstance(encoded, (list, tuple)) + + # Basic check: no queue should be used, so tensor goes through standard path + # This is mainly to ensure no exceptions are raised + + +def test_encoder_without_target_engine(): + """Test that encoder handles missing target engine gracefully.""" + tensor_queues = [torch_mp.Queue()] + encoder = MsgpackEncoder(tensor_queues=tensor_queues) + + # Don't set target engine + if torch.cuda.is_available(): + tensor = torch.randn(2, 
3, device="cuda:0") + else: + tensor = torch.randn(2, 3) + + # Should fall back to standard serialization + encoded = encoder.encode({"test_tensor": tensor}) + assert len(encoded) > 0 + + +def test_decoder_buffer_management(): + """Test decoder's tensor buffer management when draining queue.""" + tensor_queue = torch_mp.Queue() + + # Put multiple tensors in queue using TensorIpcData + tensors = { + "tensor_1": torch.randn(2, 3), + "tensor_2": torch.randn(4, 5), + "tensor_3": torch.randn(6, 7), + } + + for tensor_id, tensor in tensors.items(): + ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=tensor) + tensor_queue.put(ipc_data) + + # Create decoder + decoder = MsgpackDecoder(tensor_queue=tensor_queue) + + # Request tensor_3 (should buffer tensor_1 and tensor_2) + handle = TensorIpcHandle( + tensor_id="tensor_3", + shape=[6, 7], + dtype="float32", + device="cpu", + ) + + result = decoder._decode_cuda_queue_tensor(handle) + assert result.shape == (6, 7) + + # Verify buffer has tensor_1 and tensor_2 + assert "tensor_1" in decoder._tensor_buffer + assert "tensor_2" in decoder._tensor_buffer + + # Request buffered tensor + handle2 = TensorIpcHandle( + tensor_id="tensor_1", + shape=[2, 3], + dtype="float32", + device="cpu", + ) + + result2 = decoder._decode_cuda_queue_tensor(handle2) + assert result2.shape == (2, 3) + # tensor_1 should be removed from buffer + assert "tensor_1" not in decoder._tensor_buffer + + +def api_server_worker( + server_id: int, + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + barrier: mp.Barrier, + retrieval_done: mp.Event, +): + """Worker simulating an API server sending tensors.""" + try: + # Each server sends a unique tensor + tensor = torch.ones(server_id + 1, server_id + 2) * server_id + tensor_id = f"server_{server_id}_tensor" + + # Wait for all servers to be ready + barrier.wait() + + # Send tensor using TensorIpcData + ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=tensor) + tensor_queue.put(ipc_data) + + result_queue.put({"server_id": server_id, "success": True}) + + # Keep process alive until main process has retrieved all tensors + # This prevents shared memory handles from being invalidated + retrieval_done.wait(timeout=30.0) + except Exception as e: + import traceback + result_queue.put({ + "server_id": server_id, + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +def test_multiple_api_servers_to_engine(): + """Test multiple API servers sending to one engine core via multiprocessing.""" + num_api_servers = 3 + tensor_queue = torch_mp.Queue() + result_queue = mp.Queue() + barrier = mp.Barrier(num_api_servers) + retrieval_done = mp.Event() + + # Start multiple API server processes + processes = [] + for server_id in range(num_api_servers): + proc = mp.Process( + target=api_server_worker, + args=(server_id, tensor_queue, result_queue, barrier, retrieval_done), + ) + proc.start() + processes.append(proc) + + # Collect results from all servers + results = [] + for _ in range(num_api_servers): + result = result_queue.get(timeout=10.0) + results.append(result) + + # Verify all servers succeeded + for result in results: + assert result["success"], f"Server {result['server_id']} failed: {result.get('error')}" + + # Verify all tensors are in queue + received_tensors = [] + for _ in range(num_api_servers): + ipc_data = tensor_queue.get(timeout=1.0) + received_tensors.append((ipc_data.tensor_id, ipc_data.tensor)) + + assert len(received_tensors) == num_api_servers + + # Verify tensor content (order may vary with 
multiprocessing) + tensor_by_id = {tid: t for tid, t in received_tensors} + for server_id in range(num_api_servers): + expected_id = f"server_{server_id}_tensor" + assert expected_id in tensor_by_id, f"Missing tensor from server {server_id}" + expected_tensor = torch.ones(server_id + 1, server_id + 2) * server_id + assert torch.allclose(tensor_by_id[expected_id], expected_tensor) + + # Signal workers that retrieval is complete + retrieval_done.set() + + # Wait for all processes to complete + for proc in processes: + proc.join(timeout=5.0) + + +def mixed_tensor_encoder_process( + tensor_queues: list[torch_mp.Queue], + result_queue: mp.Queue, + ready_event: mp.Event, + retrieval_done: mp.Event, +): + """Process that encodes mixed CPU/CUDA tensors (old behavior: only CUDA via IPC).""" + try: + # Use old behavior: multimodal_tensor_ipc defaults to True but only CUDA went through + # For this test, we want to test the old behavior where only CUDA uses IPC + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=False) + encoder.set_target_engine(0) + + # Create only CUDA tensor for IPC (CPU will be serialized) + # But actually, let's just send CUDA tensor directly + cuda_tensor = torch.randn(4, 5, device="cuda:0") + + # Manually send via IPC to test the mechanism + tensor_id = "test_cuda_tensor" + cuda_tensor_shared = cuda_tensor.share_memory_() + from vllm.v1.serial_utils import TensorIpcData + ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=cuda_tensor_shared) + tensor_queues[0].put(ipc_data, timeout=10.0) + + ready_event.set() + + result_queue.put({ + "success": True, + "sent_cuda": True, + }) + + # Keep process alive until decoder has retrieved the tensor + retrieval_done.wait(timeout=30.0) + except Exception as e: + import traceback + ready_event.set() + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +def mixed_tensor_decoder_process( + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + encoder_ready: mp.Event, + retrieval_done: mp.Event, +): + """Process that retrieves mixed tensors from queue.""" + try: + # Wait for encoder to finish + if not encoder_ready.wait(timeout=10.0): + raise TimeoutError("Encoder did not signal ready") + + # Try to get CUDA tensor from queue + ipc_data = tensor_queue.get(timeout=5.0) + + result_queue.put({ + "success": True, + "is_cuda": ipc_data.tensor.is_cuda, + "shape": tuple(ipc_data.tensor.shape), + }) + + # Signal that retrieval is complete + retrieval_done.set() + except Exception as e: + import traceback + retrieval_done.set() # Signal even on failure + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_mixed_cpu_cuda_tensors(): + """Test encoding with mixed CPU and CUDA tensors using multiprocessing.""" + tensor_queues = [torch_mp.Queue()] + result_queue = mp.Queue() + encoder_ready = mp.Event() + retrieval_done = mp.Event() + + # Start encoder process + encoder_proc = mp.Process( + target=mixed_tensor_encoder_process, + args=(tensor_queues, result_queue, encoder_ready, retrieval_done), + ) + encoder_proc.start() + + # Start decoder process + decoder_proc = mp.Process( + target=mixed_tensor_decoder_process, + args=(tensor_queues[0], result_queue, encoder_ready, retrieval_done), + ) + decoder_proc.start() + + # Get results + encoder_result = result_queue.get(timeout=10.0) + decoder_result = result_queue.get(timeout=10.0) + 
+ encoder_proc.join(timeout=5.0) + decoder_proc.join(timeout=5.0) + + # Verify encoder succeeded + assert encoder_result["success"], f"Encoder failed: {encoder_result.get('error')}\n{encoder_result.get('traceback', '')}" + + # Verify decoder succeeded and got CUDA tensor + assert decoder_result["success"], f"Decoder failed: {decoder_result.get('error')}\n{decoder_result.get('traceback', '')}" + assert decoder_result["is_cuda"], "Retrieved tensor is not on CUDA" + assert decoder_result["shape"] == (4, 5), f"Unexpected shape: {decoder_result['shape']}" + + +def cpu_tensor_ipc_encoder_process( + tensor_queues: list[torch_mp.Queue], + result_queue: mp.Queue, + target_engine: int, + tensor_shape: tuple, + ready_event: mp.Event, + retrieval_done: mp.Event, +): + """Process that encodes and sends CPU tensors via IPC queue.""" + try: + # Create encoder with IPC enabled for all tensors + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) + encoder.set_target_engine(target_engine) + + # Create a CPU tensor + tensor = torch.randn(*tensor_shape, dtype=torch.float32) + + # Encode the tensor (should use IPC queue, not standard serialization) + encoded = encoder.encode({"test_tensor": tensor}) + + # Signal that encoding is complete + ready_event.set() + + result_queue.put( + { + "success": True, + "encoded_length": len(encoded), + "device": str(tensor.device), + "tensor_shape": tuple(tensor.shape), + } + ) + + # Keep process alive until decoder has retrieved the tensor + # This is necessary for CPU tensor shared memory to remain valid + retrieval_done.wait(timeout=30.0) + except Exception as e: + import traceback + ready_event.set() + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +def cpu_tensor_ipc_decoder_process( + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + expected_shape: tuple, + encoder_ready: mp.Event, + retrieval_done: mp.Event, +): + """Process that decodes and receives CPU tensors from IPC queue.""" + try: + # Create decoder with tensor queue + decoder = MsgpackDecoder(tensor_queue=tensor_queue) + + # Wait for encoder to finish sending + if not encoder_ready.wait(timeout=10.0): + raise TimeoutError("Encoder did not signal ready") + + # Get tensor from queue + ipc_data = tensor_queue.get(timeout=5.0) + + result_queue.put( + { + "success": True, + "tensor_id": ipc_data.tensor_id, + "tensor_shape": tuple(ipc_data.tensor.shape), + "device": str(ipc_data.tensor.device), + "matches_expected": tuple(ipc_data.tensor.shape) == expected_shape, + "is_cpu": ipc_data.tensor.device.type == "cpu", + } + ) + + # Signal that retrieval is complete + retrieval_done.set() + except Exception as e: + import traceback + retrieval_done.set() # Signal even on failure + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +def test_cpu_tensor_ipc(): + """Test CPU tensor sharing via IPC queue when multimodal_tensor_ipc is enabled.""" + # Set up queues and synchronization + num_engines = 2 + tensor_queues = [torch_mp.Queue() for _ in range(num_engines)] + result_queue = mp.Queue() + encoder_ready = mp.Event() + retrieval_done = mp.Event() + + target_engine = 0 + tensor_shape = (3, 5, 7) + + # Start encoder process + encoder_proc = mp.Process( + target=cpu_tensor_ipc_encoder_process, + args=( + tensor_queues, + result_queue, + target_engine, + tensor_shape, + encoder_ready, + retrieval_done, + ), + ) + encoder_proc.start() + + # Start decoder process + decoder_proc = 
mp.Process( + target=cpu_tensor_ipc_decoder_process, + args=(tensor_queues[target_engine], result_queue, tensor_shape, encoder_ready, retrieval_done), + ) + decoder_proc.start() + + # Wait for processes and collect results + encoder_result = result_queue.get(timeout=10.0) + decoder_result = result_queue.get(timeout=10.0) + + encoder_proc.join(timeout=5.0) + decoder_proc.join(timeout=5.0) + + # Verify results + assert encoder_result["success"], f"Encoder failed: {encoder_result.get('error')}\n{encoder_result.get('traceback', '')}" + assert decoder_result["success"], f"Decoder failed: {decoder_result.get('error')}\n{decoder_result.get('traceback', '')}" + assert decoder_result["matches_expected"], "Tensor shape mismatch" + assert decoder_result["is_cpu"], "Tensor not on CPU device" + + +def test_ipc_disabled_mode(): + """Test that IPC is disabled when multimodal_tensor_ipc=False.""" + tensor_queues = [torch_mp.Queue()] + + # Create encoder with IPC disabled + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=False) + encoder.set_target_engine(0) + + # Create a CPU tensor + cpu_tensor = torch.randn(2, 3, dtype=torch.float32) + + # Encode the tensor (should use standard serialization, not IPC) + encoded = encoder.encode({"test_tensor": cpu_tensor}) + + # Verify encoding succeeded + assert len(encoded) > 0 + assert isinstance(encoded, (list, tuple)) + + # Verify queue is empty (no IPC was used) + assert tensor_queues[0].empty(), "Tensor queue should be empty when IPC is disabled" + + # If CUDA is available, test with CUDA tensor too + if torch.cuda.is_available(): + cuda_tensor = torch.randn(4, 5, device="cuda:0") + encoded_cuda = encoder.encode({"cuda_tensor": cuda_tensor}) + assert len(encoded_cuda) > 0 + assert tensor_queues[0].empty(), "Tensor queue should be empty for CUDA tensor when IPC is disabled" + + +def test_mixed_cpu_cuda_with_ipc_enabled(): + """Test that encoder is configured correctly for IPC with all tensor types.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + tensor_queues = [torch_mp.Queue()] + + # Create encoder with IPC enabled for all tensors + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) + encoder.set_target_engine(0) + + # Verify encoder configuration + assert encoder.multimodal_tensor_ipc is True, "IPC should be enabled" + assert encoder.tensor_queues is not None, "Tensor queues should be set" + assert encoder.target_engine_index == 0, "Target engine should be set" + + # Note: Actual IPC transfer only works across processes (tested in test_cpu_tensor_ipc) + # This test just verifies the configuration is correct + diff --git a/vllm/config/model.py b/vllm/config/model.py index df25e900c354..962ca3caccb9 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -311,6 +311,8 @@ class ModelConfig: interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None + maximum_concurrent_videos: InitVar[int | None] = None + multimodal_tensor_ipc: InitVar[bool | None] = None def compute_hash(self) -> str: """ @@ -425,6 +427,8 @@ def __post_init__( interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, + maximum_concurrent_videos: int | None, + multimodal_tensor_ipc: bool | None, ) -> None: # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name( @@ -588,6 +592,8 @@ def __post_init__( 
interleave_mm_strings=interleave_mm_strings, skip_mm_profiling=skip_mm_profiling, video_pruning_rate=video_pruning_rate, + max_concurrent_videos=maximum_concurrent_videos, + multimodal_tensor_ipc=multimodal_tensor_ipc, ) mm_config_kwargs = { diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index ecb346af8f3c..6e3bcddda4c9 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -140,6 +140,17 @@ class MultiModalConfig: Value sits in range [0;1) and determines fraction of media tokens from each video to be pruned. """ + max_concurrent_videos: int | None = Field(default=None, gt=0) + """Maximum number of videos that can be preprocessed concurrently in this + process. This limits VRAM usage from video decoding libraries like + PyNvVideoCodec that allocate VRAM separately from PyTorch.""" + multimodal_tensor_ipc: bool | None = None + """Enable IPC (inter-process communication) for multimodal tensors. + When enabled, all multimodal tensors (CUDA and CPU) are transferred + via torch.multiprocessing shared memory for zero-copy IPC. + When disabled, all tensors use standard serialization. + If None, defaults to the value of VLLM_MULTIMODAL_TENSOR_IPC environment + variable (default: True).""" @field_validator("limit_per_prompt", mode="before") @classmethod diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cb82be6b6b6f..c1103d2c4dca 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -474,6 +474,8 @@ class EngineArgs: io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling video_pruning_rate: float = MultiModalConfig.video_pruning_rate + maximum_concurrent_videos: int | None = MultiModalConfig.max_concurrent_videos + multimodal_tensor_ipc: bool | None = MultiModalConfig.multimodal_tensor_ipc # LoRA fields enable_lora: bool = False max_loras: int = LoRAConfig.max_loras @@ -990,6 +992,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] ) + multimodal_group.add_argument( + "--maximum-concurrent-videos", + type=int, + default=None, + help="Maximum number of videos that can be preprocessed concurrently. " + "This limits VRAM usage from video decoding. The count is spread " + "evenly over API server processes.", + ) + multimodal_group.add_argument( + "--enable-multimodal-tensor-ipc", + "--disable-multimodal-tensor-ipc", + action=argparse.BooleanOptionalAction, + default=None, + help="Enable IPC (inter-process communication) for multimodal tensors. " + "When enabled, all multimodal tensors (CUDA and CPU) are transferred " + "via torch.multiprocessing shared memory for zero-copy IPC. " + "When disabled, all tensors use standard serialization. 
" + "If not specified, defaults to VLLM_MULTIMODAL_TENSOR_IPC env var (default: True).", + ) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -1267,6 +1288,8 @@ def create_model_config(self) -> ModelConfig: override_attention_dtype=self.override_attention_dtype, logits_processors=self.logits_processors, video_pruning_rate=self.video_pruning_rate, + maximum_concurrent_videos=self.maximum_concurrent_videos, + multimodal_tensor_ipc=self.multimodal_tensor_ipc, io_processor_plugin=self.io_processor_plugin, ) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index f06a9391321f..96e59b08eb95 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -209,6 +209,7 @@ def run_multi_api_server(args: argparse.Namespace): stats_update_address=coordinator.get_stats_publish_address() if coordinator else None, + tensor_queues=addresses.tensor_queues, ) # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the diff --git a/vllm/envs.py b/vllm/envs.py index f82dae108f6a..7d953aecbfb8 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -74,6 +74,7 @@ VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MEDIA_CONNECTOR: str = "http" VLLM_MM_HASHER_ALGORITHM: str = "blake3" + VLLM_MULTIMODAL_TENSOR_IPC: bool = True VLLM_TARGET_DEVICE: str = "cuda" VLLM_MAIN_CUDA_VERSION: str = "12.9" VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest" @@ -807,6 +808,7 @@ def get_vllm_port() -> int | None: # imported at runtime. # If a non-existing backend is used, an AssertionError will be thrown. "VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"), +<<<<<<< HEAD # Hash algorithm for multimodal content hashing. # - "blake3": Default, fast cryptographic hash (not FIPS 140-3 compliant) # - "sha256": FIPS 140-3 compliant, widely supported @@ -817,6 +819,14 @@ def get_vllm_port() -> int | None: "blake3", ["blake3", "sha256", "sha512"], case_sensitive=False, +======= + # Enable IPC (inter-process communication) for multimodal tensors. + # When enabled, all multimodal tensors (CUDA and CPU) are transferred + # via torch.multiprocessing shared memory for zero-copy IPC. + # When disabled, all tensors use standard serialization. + "VLLM_MULTIMODAL_TENSOR_IPC": lambda: bool( + int(os.getenv("VLLM_MULTIMODAL_TENSOR_IPC", "1")) +>>>>>>> 5a8cec31e (Add tensor IPC transfer mechanism for multimodal data) ), # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. 
@@ -1758,6 +1768,7 @@ def compile_factors() -> dict[str, object]: "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "VLLM_VIDEO_LOADER_BACKEND", "VLLM_MEDIA_CONNECTOR", + "VLLM_MULTIMODAL_TENSOR_IPC", "VLLM_ASSETS_CACHE", "VLLM_ASSETS_CACHE_MODEL_CLEAN", "VLLM_WORKER_MULTIPROC_METHOD", diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 7b12158763c3..a5e802e328a7 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -575,7 +575,7 @@ def _shape_before_after(tensor: torch.Tensor): (*shape_before, shape_concat, *shape_after), dtype=batch[0].dtype, device=batch[0].device, - pin_memory=pin_memory, + pin_memory=pin_memory and batch[0].device.type == 'cpu', ) return torch.concat(batch, dim=self.dim, out=out) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 141e5a459c5b..3d4b28356380 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -648,6 +648,7 @@ def __init__( client_handshake_address: str | None = None, *, engine_index: int = 0, + tensor_queues: list[Any] | None = None, ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]() @@ -668,6 +669,16 @@ def __init__( ) as addresses: self.client_count = len(addresses.outputs) + # Get this engine's tensor IPC queue for receiving multimodal tensors + # Queues are passed directly via constructor since they can't be serialized + self.tensor_queue = None + if tensor_queues and addresses.tensor_queue_index is not None: + self.tensor_queue = tensor_queues[addresses.tensor_queue_index] + logger.info( + "Engine %d using tensor IPC queue for multimodal tensor sharing", + self.engine_index, + ) + # Set up data parallel environment. self.has_coordinator = addresses.coordinator_output is not None self.frontend_stats_publish_address = ( @@ -875,10 +886,20 @@ def startup_handshake( for key, value in init_message.parallel_config.items(): setattr(parallel_config, key, value) - return init_message.addresses + # Store tensor_queue_index for engine to access + addresses = init_message.addresses + addresses.tensor_queue_index = init_message.tensor_queue_index + + return addresses @staticmethod - def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): + def run_engine_core( + *args, + dp_rank: int = 0, + local_dp_rank: int = 0, + tensor_queues: list[Any] | None = None, + **kwargs + ): """Launch EngineCore busy loop in background process.""" # Signal handler used for graceful termination. @@ -915,15 +936,10 @@ def signal_handler(signum, frame): if data_parallel and vllm_config.model_config.is_moe: # Set data parallel rank for this engine process. parallel_config.data_parallel_rank = dp_rank - engine_core = DPEngineCoreProc(*args, **kwargs) + parallel_config.data_parallel_rank_local = local_dp_rank + engine_core = DPEngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) else: - # Non-MoE DP ranks are completely independent, so treat like DP=1. - # Note that parallel_config.data_parallel_index will still reflect - # the original DP rank. - parallel_config.data_parallel_size = 1 - parallel_config.data_parallel_size_local = 1 - parallel_config.data_parallel_rank = 0 - engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs) + engine_core = EngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) engine_core.run_busy_loop() @@ -1073,9 +1089,11 @@ def process_input_sockets( ): """Input socket IO thread.""" - # Msgpack serialization decoding. 
- add_request_decoder = MsgpackDecoder(EngineCoreRequest) - generic_decoder = MsgpackDecoder() + # Msgpack serialization decoding with tensor queue for CUDA tensors. + add_request_decoder = MsgpackDecoder( + EngineCoreRequest, tensor_queue=self.tensor_queue + ) + generic_decoder = MsgpackDecoder(tensor_queue=self.tensor_queue) with ExitStack() as stack, zmq.Context() as ctx: input_sockets = [ @@ -1252,6 +1270,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, client_handshake_address: str | None = None, + tensor_queues: list[Any] | None = None, ): assert vllm_config.model_config.is_moe, ( "DPEngineCoreProc should only be used for MoE models" @@ -1273,6 +1292,7 @@ def __init__( log_stats, client_handshake_address, engine_index=dp_rank, + tensor_queues=tensor_queues, ) def _init_data_parallel(self, vllm_config: VllmConfig): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c9a1d53c8fb7..90dc4be18afa 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -450,10 +450,7 @@ def __init__( client_addresses: dict[str, str] | None = None, ): self.vllm_config = vllm_config - # Serialization setup. - self.encoder = MsgpackEncoder() - self.decoder = MsgpackDecoder(EngineCoreOutputs) - + # ZMQ setup. sync_ctx = zmq.Context(io_threads=2) self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx @@ -469,11 +466,14 @@ def __init__( self.engines_running = False self.stats_update_address: str | None = None + tensor_queues = None if client_addresses: # Engines are managed externally to this client. input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] self.stats_update_address = client_addresses.get("stats_update_address") + # Tensor queues passed via client_addresses for multi-API-server case + tensor_queues = client_addresses.get("tensor_queues") else: # Engines are managed by this client. with launch_core_engines(vllm_config, executor_class, log_stats) as ( @@ -487,11 +487,32 @@ def __init__( (input_address,) = addresses.inputs (output_address,) = addresses.outputs self.stats_update_address = addresses.frontend_stats_publish_address + tensor_queues = addresses.tensor_queues if coordinator is not None: assert self.stats_update_address == ( coordinator.get_stats_publish_address() ) + # Serialization setup with tensor queues for multimodal tensor IPC. + # Get IPC config from multimodal_config, falling back to env var + multimodal_tensor_ipc = True # Default + if vllm_config.model_config.multimodal_config is not None: + mm_ipc = vllm_config.model_config.multimodal_config.multimodal_tensor_ipc + if mm_ipc is not None: + multimodal_tensor_ipc = mm_ipc + else: + # Fall back to environment variable + from vllm import envs + multimodal_tensor_ipc = envs.VLLM_MULTIMODAL_TENSOR_IPC + + self.encoder = MsgpackEncoder( + tensor_queues=tensor_queues, + multimodal_tensor_ipc=multimodal_tensor_ipc, + ) + self.decoder = MsgpackDecoder(EngineCoreOutputs) + # Store tensor queues for routing + self.resources.tensor_queues = tensor_queues + # Create input and output sockets. 
self.input_socket = self.resources.input_socket = make_zmq_socket( self.ctx, input_address, zmq.ROUTER, bind=True @@ -903,6 +924,10 @@ def _send_input( if engine is None: engine = self.core_engine + # Set target engine index for CUDA tensor routing + engine_index = int.from_bytes(engine, "little") + self.encoder.set_target_engine(engine_index) + message = (request_type.value, *self.encoder.encode(request)) return self._send_input_message(message, engine, request) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 5db3a53266f0..39503ba539bb 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -9,10 +9,11 @@ from enum import Enum, auto from multiprocessing import Process, connection from multiprocessing.process import BaseProcess -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest.mock import patch import msgspec +import torch.multiprocessing as torch_mp import zmq from vllm import envs @@ -64,6 +65,11 @@ class EngineZmqAddresses: # Not used by engine, just relayed to front-end in handshake response. # Only required for external DP LB case. frontend_stats_publish_address: str | None = None + # Tensor IPC queues for sharing CUDA tensors between API servers and engines + # One queue per engine core for direct GPU tensor transfer + tensor_queues: list[Any] | None = None + # Index of this engine's tensor queue (set during handshake) + tensor_queue_index: int | None = None @dataclass @@ -75,6 +81,8 @@ class EngineHandshakeMetadata: addresses: EngineZmqAddresses parallel_config: dict[str, int | str | list[int]] + # Index of this engine's tensor queue in addresses.tensor_queues + tensor_queue_index: int | None = None class CoreEngineProcManager: @@ -95,6 +103,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, client_handshake_address: str | None = None, + tensor_queues: list[Any] | None = None, ): context = get_mp_context() common_kwargs = { @@ -108,6 +117,9 @@ def __init__( if client_handshake_address: common_kwargs["client_handshake_address"] = client_handshake_address + # Store tensor_queues for passing to engine processes + self.tensor_queues = tensor_queues + self.processes: list[BaseProcess] = [] local_dp_ranks = [] for index in range(local_engine_count): @@ -124,6 +136,7 @@ def __init__( | { "dp_rank": global_index, "local_dp_rank": local_index, + "tensor_queues": tensor_queues, }, ) ) @@ -800,6 +813,15 @@ def launch_core_engines( offline_mode or local_engines_only or (local_engine_count == dp_size) ) + # Create tensor IPC queues for sharing multimodal tensors between API servers + # and engine cores. One queue per engine core. + # Use torch.multiprocessing for tensor sharing via IPC/shared memory. + # Set start method to 'spawn' for compatibility with multiprocessing. + torch_mp.set_start_method('spawn', force=True) + tensor_queues: list[torch_mp.Queue] = [ + torch_mp.Queue() for _ in range(dp_size) + ] + # Set up input and output addresses. addresses = EngineZmqAddresses( inputs=[ @@ -810,6 +832,7 @@ def launch_core_engines( get_engine_client_zmq_addr(client_local_only, host) for _ in range(num_api_servers) ], + tensor_queues=tensor_queues, ) # Run the DP Coordinator process with rank 0 when in online DP mode. 
@@ -908,6 +931,7 @@ def launch_core_engines( local_engine_count=local_engine_count, start_index=dp_rank, local_start_index=local_start_index or 0, + tensor_queues=tensor_queues, ) else: local_engine_manager = None @@ -1015,9 +1039,21 @@ def wait_for_engine_startup( if status == "HELLO" and engine.state == CoreEngineState.NEW: # Send init message with DP config info. + # Note: tensor_queues are excluded from serialization as they can't be + # serialized by msgspec. They are passed directly to engine processes + # when spawning them. + addresses_for_handshake = EngineZmqAddresses( + inputs=addresses.inputs, + outputs=addresses.outputs, + coordinator_input=addresses.coordinator_input, + coordinator_output=addresses.coordinator_output, + frontend_stats_publish_address=addresses.frontend_stats_publish_address, + tensor_queues=None, # Don't serialize queues + tensor_queue_index=None, # Will be set separately + ) init_message = msgspec.msgpack.encode( EngineHandshakeMetadata( - addresses=addresses, + addresses=addresses_for_handshake, parallel_config={ k: getattr(parallel_config, k) for k in ( @@ -1026,9 +1062,8 @@ def wait_for_engine_startup( "_data_parallel_master_port_list", "data_parallel_size", ) - } - if coordinated_dp - else {}, + } if coordinated_dp else {}, + tensor_queue_index=eng_index, ) ) handshake_socket.send_multipart((eng_identity, init_message), copy=False) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index a3c30e368b82..89af004ad360 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -41,6 +41,35 @@ CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 + +@dataclasses.dataclass +class TensorIpcData: + """ + Data sent via torch.multiprocessing.Queue for zero-copy IPC. + + Contains the tensor_id and the actual tensor. The tensor is shared + in memory (GPU or CPU) for efficient inter-process communication. + """ + tensor_id: str + tensor: torch.Tensor + + +@dataclasses.dataclass +class TensorIpcHandle: + """ + Handle for a tensor sent via IPC queue (zero-copy transfer). + + Contains only metadata about the tensor. This is serialized via msgpack + and used by the decoder to retrieve the actual tensor from the queue. + The actual tensor is sent separately via torch.multiprocessing.Queue + as TensorIpcData. Works for both CUDA and CPU tensors. + """ + tensor_id: str + shape: list[int] + dtype: str + device: str + + # MultiModalField class serialization type map. # These need to list all possible field types and match them # to factory methods in `MultiModalFieldConfig`. @@ -119,9 +148,18 @@ class MsgpackEncoder: By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. + + When multimodal_tensor_ipc is enabled and tensor_queues is provided, + all multimodal tensors (CUDA and CPU) will be sent via + torch.multiprocessing.Queue for zero-copy IPC instead of serialization. """ - def __init__(self, size_threshold: int | None = None): + def __init__( + self, + size_threshold: int | None = None, + tensor_queues: list[Any] | None = None, + multimodal_tensor_ipc: bool = True, + ): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) @@ -130,9 +168,21 @@ def __init__(self, size_threshold: int | None = None): # pass custom data to the hook otherwise. 
self.aux_buffers: list[bytestr] | None = None self.size_threshold = size_threshold + # Tensor IPC queues for sharing multimodal tensors (one per engine core) + self.tensor_queues = tensor_queues + # Enable IPC for all multimodal tensors (CUDA and CPU) + self.multimodal_tensor_ipc = multimodal_tensor_ipc + # Target engine index for routing tensors to the correct queue + self.target_engine_index: int | None = None + # Counter for generating unique tensor IDs + self._tensor_id_counter = 0 if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() + def set_target_engine(self, engine_index: int | None) -> None: + """Set the target engine index for routing multimodal tensors to IPC queues.""" + self.target_engine_index = engine_index + def encode(self, obj: Any) -> Sequence[bytestr]: try: self.aux_buffers = bufs = [b""] @@ -168,7 +218,7 @@ def enc_hook(self, obj: Any) -> Any: int(v) if v is not None else None for v in (obj.start, obj.stop, obj.step) ) - + if isinstance(obj, MultiModalKwargsItem): return self._encode_mm_item(obj) @@ -222,8 +272,64 @@ def _encode_ndarray( def _encode_tensor( self, obj: torch.Tensor - ) -> tuple[str, tuple[int, ...], int | memoryview]: + ) -> tuple[str, tuple[int, ...], int | memoryview] | dict[str, Any]: assert self.aux_buffers is not None + + # Check if we should use IPC for this tensor + # IPC is used when: multimodal_tensor_ipc is enabled, queues are available, + # and we have a target engine + if ( + self.multimodal_tensor_ipc + and self.tensor_queues is not None + and self.target_engine_index is not None + ): + # Send tensor via torch.multiprocessing.Queue for zero-copy IPC + # This works for both CUDA and CPU tensors + # Generate unique tensor ID + tensor_id = f"{id(self)}_{self._tensor_id_counter}" + self._tensor_id_counter += 1 + + try: + # Move tensor to shared memory for IPC + # This is required for proper inter-process communication + if not obj.is_shared(): + obj = obj.share_memory_() + + # Put TensorIpcData (tensor_id + tensor) into the target engine's queue + target_queue = self.tensor_queues[self.target_engine_index] + ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=obj) + # Use a timeout to avoid blocking indefinitely + target_queue.put(ipc_data, timeout=10.0) + + logger.debug( + "Sent tensor %s (shape=%s, device=%s) to engine %d via IPC queue (shared memory)", + tensor_id, + obj.shape, + obj.device, + self.target_engine_index, + ) + + return TensorIpcHandle(tensor_id=tensor_id, shape=list(obj.shape), dtype=str(obj.dtype).removeprefix("torch."), device=str(obj.device)) + except Exception as e: + logger.warning( + "Failed to send tensor via IPC queue: %s. " + "Falling back to standard serialization.", + e, + ) + # Fall through to standard serialization + + + # Standard serialization fallback + # For CUDA tensors without IPC support, we need to move to CPU first + if obj.is_cuda: + if self.multimodal_tensor_ipc and self.tensor_queues is not None: + # Only warn if IPC was expected but unavailable + logger.warning( + "CUDA tensor without IPC support encountered (no target engine set). " + "Moving to CPU for serialization. This will be slow." + ) + obj = obj.cpu() + # view the tensor as a contiguous 1D array of bytes arr_data = tensor_data(obj) if obj.nbytes < self.size_threshold: @@ -281,9 +387,17 @@ class MsgpackDecoder: Note that unlike vanilla `msgspec` Decoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. 
+ + For multimodal tensors sent via torch.multiprocessing.Queue (when IPC is enabled), + they will be retrieved from the queue during decoding. Works for both CUDA and CPU tensors. """ - def __init__(self, t: Any | None = None, share_mem: bool = True): + def __init__( + self, + t: Any | None = None, + share_mem: bool = True, + tensor_queue: Any | None = None, + ): self.share_mem = share_mem self.pin_tensors = is_pin_memory_available() args = () if t is None else (t,) @@ -291,6 +405,11 @@ def __init__(self, t: Any | None = None, share_mem: bool = True): *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook ) self.aux_buffers: Sequence[bytestr] = () + # Tensor IPC queue for receiving multimodal tensors from API servers + self.tensor_queue = tensor_queue + # Buffer for temporarily storing tensors retrieved from queue + # that don't match the current request + self._tensor_buffer: dict[str, torch.Tensor] = {} if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -309,6 +428,8 @@ def dec_hook(self, t: type, obj: Any) -> Any: if isclass(t): if issubclass(t, np.ndarray): return self._decode_ndarray(obj) + if issubclass(t, TensorIpcHandle): + return self._decode_cuda_queue_tensor(obj) if issubclass(t, torch.Tensor): return self._decode_tensor(obj) if t is slice: @@ -354,6 +475,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: return arr.reshape(shape) def _decode_tensor(self, arr: Any) -> torch.Tensor: + # Standard tensor decoding dtype, shape, data = arr is_aux = isinstance(data, int) buffer = self.aux_buffers[data] if is_aux else data @@ -374,6 +496,18 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: arr = arr.pin_memory() if self.pin_tensors else arr.clone() # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) + + def _decode_cuda_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: + """Retrieve a tensor from the torch.multiprocessing.Queue (works for CUDA and CPU).""" + + # Drain all available tensors. We save them regardless if this is the one + # we're waiting for as they may arrive out of order from multiple producers. + while handle.tensor_id not in self._tensor_buffer: + ipc_data: TensorIpcData = self.tensor_queue.get(timeout=10.0) + self._tensor_buffer[ipc_data.tensor_id] = ipc_data.tensor + + tensor = self._tensor_buffer.pop(handle.tensor_id) + return tensor def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: return MultiModalKwargsItems( @@ -409,6 +543,14 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: # Although it violates NestedTensors type, MultiModalKwargs # values are sometimes floats. 
return obj + if isinstance(obj, TensorIpcHandle): + return self._decode_cuda_queue_tensor(obj) + # Check if this is a dict that represents a TensorIpcHandle + # (msgspec serializes dataclasses as dicts without type info in nested structures) + if isinstance(obj, dict) and 'tensor_id' in obj and 'shape' in obj and 'dtype' in obj and 'device' in obj: + # Convert dict to TensorIpcHandle and decode it + handle = TensorIpcHandle(**obj) + return self._decode_cuda_queue_tensor(handle) if not isinstance(obj, list): raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") if obj and isinstance(obj[0], str): diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index b3ebd72af992..666e5c5d313d 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -174,6 +174,7 @@ def __init__( input_addresses: list[str], output_addresses: list[str], stats_update_address: str | None = None, + tensor_queues: list[Any] | None = None, ): """Initialize and start API server worker processes. @@ -186,6 +187,7 @@ def __init__( input_addresses: Input addresses for each API server output_addresses: Output addresses for each API server stats_update_address: Optional stats update address + tensor_queues: Optional tensor IPC queues for CUDA tensor sharing """ self.listen_address = listen_address self.sock = sock @@ -206,6 +208,8 @@ def __init__( } if stats_update_address is not None: client_config["stats_update_address"] = stats_update_address + if tensor_queues is not None: + client_config["tensor_queues"] = tensor_queues proc = spawn_context.Process( target=target_server_fn, From 84c8d65fd9b11fcb96449145f8ef03c575184a72 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 11 Jan 2026 04:03:20 +0000 Subject: [PATCH 04/33] Default to not use Tensor IPC datapath Signed-off-by: Brandon Pelfrey --- vllm/engine/arg_utils.py | 2 +- vllm/envs.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c1103d2c4dca..532d64fff53a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1009,7 +1009,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "When enabled, all multimodal tensors (CUDA and CPU) are transferred " "via torch.multiprocessing shared memory for zero-copy IPC. " "When disabled, all tensors use standard serialization. " - "If not specified, defaults to VLLM_MULTIMODAL_TENSOR_IPC env var (default: True).", + "If not specified, defaults to VLLM_MULTIMODAL_TENSOR_IPC env var (default: False).", ) # LoRA related configs diff --git a/vllm/envs.py b/vllm/envs.py index 7d953aecbfb8..d22718ef67f5 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -75,6 +75,7 @@ VLLM_MEDIA_CONNECTOR: str = "http" VLLM_MM_HASHER_ALGORITHM: str = "blake3" VLLM_MULTIMODAL_TENSOR_IPC: bool = True + VLLM_MULTIMODAL_TENSOR_IPC: bool = False VLLM_TARGET_DEVICE: str = "cuda" VLLM_MAIN_CUDA_VERSION: str = "12.9" VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest" @@ -808,7 +809,6 @@ def get_vllm_port() -> int | None: # imported at runtime. # If a non-existing backend is used, an AssertionError will be thrown. "VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"), -<<<<<<< HEAD # Hash algorithm for multimodal content hashing. 
# - "blake3": Default, fast cryptographic hash (not FIPS 140-3 compliant) # - "sha256": FIPS 140-3 compliant, widely supported @@ -819,14 +819,13 @@ def get_vllm_port() -> int | None: "blake3", ["blake3", "sha256", "sha512"], case_sensitive=False, -======= + ), # Enable IPC (inter-process communication) for multimodal tensors. # When enabled, all multimodal tensors (CUDA and CPU) are transferred # via torch.multiprocessing shared memory for zero-copy IPC. # When disabled, all tensors use standard serialization. "VLLM_MULTIMODAL_TENSOR_IPC": lambda: bool( int(os.getenv("VLLM_MULTIMODAL_TENSOR_IPC", "1")) ->>>>>>> 5a8cec31e (Add tensor IPC transfer mechanism for multimodal data) ), # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. From e104cbda89bf39c84f9da78df9529fc6c157b35d Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sat, 10 Jan 2026 20:05:10 -0800 Subject: [PATCH 05/33] Update vllm/v1/engine/core.py Missed as part of rebase. This suggestion makes sense Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Brandon Pelfrey Signed-off-by: Brandon Pelfrey --- vllm/v1/engine/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 3d4b28356380..d41dd65d4b1f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -939,7 +939,7 @@ def signal_handler(signum, frame): parallel_config.data_parallel_rank_local = local_dp_rank engine_core = DPEngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) else: - engine_core = EngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) + engine_core = EngineCoreProc(*args, engine_index=dp_rank, tensor_queues=tensor_queues, **kwargs) engine_core.run_busy_loop() From cb0893f9097e6d8853edb50773671c8556c712ff Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 11 Jan 2026 04:19:37 +0000 Subject: [PATCH 06/33] Enable/Disable Tensor IPC datapath via args with explicit dest Signed-off-by: Brandon Pelfrey --- vllm/engine/arg_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 532d64fff53a..366805d33834 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1010,6 +1010,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "via torch.multiprocessing shared memory for zero-copy IPC. " "When disabled, all tensors use standard serialization. " "If not specified, defaults to VLLM_MULTIMODAL_TENSOR_IPC env var (default: False).", + dest="multimodal_tensor_ipc", ) # LoRA related configs From ba500df8d407156195a4f330f4839c978b0dbd5f Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 11 Jan 2026 04:34:22 +0000 Subject: [PATCH 07/33] Normalize DP config in engine/core Signed-off-by: Brandon Pelfrey --- vllm/v1/engine/core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d41dd65d4b1f..f68785b78a34 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -936,10 +936,15 @@ def signal_handler(signum, frame): if data_parallel and vllm_config.model_config.is_moe: # Set data parallel rank for this engine process. 
parallel_config.data_parallel_rank = dp_rank - parallel_config.data_parallel_rank_local = local_dp_rank engine_core = DPEngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) else: - engine_core = EngineCoreProc(*args, engine_index=dp_rank, tensor_queues=tensor_queues, **kwargs) + # Non-MoE DP ranks are completely independent, so treat like DP=1. + # Note that parallel_config.data_parallel_index will still reflect + # the original DP rank. + parallel_config.data_parallel_size = 1 + parallel_config.data_parallel_size_local = 1 + parallel_config.data_parallel_rank = 0 + engine_core = EngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) engine_core.run_busy_loop() From 12fbbf4a7d2a2275785ed10d86380833f8000d13 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 11 Jan 2026 04:35:20 +0000 Subject: [PATCH 08/33] Handling TensorIpcHandle for dec_hook Signed-off-by: Brandon Pelfrey --- vllm/v1/serial_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 89af004ad360..9d5edecbff31 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -429,6 +429,9 @@ def dec_hook(self, t: type, obj: Any) -> Any: if issubclass(t, np.ndarray): return self._decode_ndarray(obj) if issubclass(t, TensorIpcHandle): + # msgspec deserializes dataclasses to dicts, so convert to TensorIpcHandle + if isinstance(obj, dict): + obj = TensorIpcHandle(**obj) return self._decode_cuda_queue_tensor(obj) if issubclass(t, torch.Tensor): return self._decode_tensor(obj) From c8c3daf41acb7b85777b129cbb260118de0c9b02 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 11 Jan 2026 05:03:26 +0000 Subject: [PATCH 09/33] formatting, type fixes, additional issues from CI review bots Signed-off-by: Brandon Pelfrey --- tests/v1/test_tensor_ipc_queue.py | 258 +++++++++++++++++------------- vllm/engine/arg_utils.py | 9 +- vllm/multimodal/inputs.py | 2 +- vllm/v1/engine/core.py | 10 +- vllm/v1/engine/core_client.py | 16 +- vllm/v1/engine/utils.py | 10 +- vllm/v1/serial_utils.py | 70 +++++--- 7 files changed, 221 insertions(+), 154 deletions(-) diff --git a/tests/v1/test_tensor_ipc_queue.py b/tests/v1/test_tensor_ipc_queue.py index 62ff697c2724..5c50a0457b30 100644 --- a/tests/v1/test_tensor_ipc_queue.py +++ b/tests/v1/test_tensor_ipc_queue.py @@ -3,24 +3,30 @@ """Tests for tensor IPC queue functionality.""" +import contextlib import multiprocessing as mp +from multiprocessing.synchronize import Barrier as BarrierType +from multiprocessing.synchronize import Event as EventType from typing import Any import pytest import torch import torch.multiprocessing as torch_mp -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, TensorIpcData, TensorIpcHandle +from vllm.v1.serial_utils import ( + MsgpackDecoder, + MsgpackEncoder, + TensorIpcData, + TensorIpcHandle, +) @pytest.fixture(scope="module", autouse=True) def setup_multiprocessing(): """Set multiprocessing start method to 'spawn' for compatibility.""" - try: - torch_mp.set_start_method('spawn', force=True) - except RuntimeError: + with contextlib.suppress(RuntimeError): # Already set, which is fine - pass + torch_mp.set_start_method("spawn", force=True) yield @@ -29,7 +35,7 @@ def encoder_process( result_queue: mp.Queue, target_engine: int, tensor_data: dict[str, Any], - ready_event: mp.Event, + ready_event: EventType, ): """Process that encodes and sends CUDA tensors via queue.""" try: @@ -64,25 +70,21 @@ def encoder_process( ) except Exception as e: import traceback + ready_event.set() 
# Signal even on failure - result_queue.put({ - "success": False, - "error": str(e), - "traceback": traceback.format_exc() - }) + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) def decoder_process( tensor_queue: torch_mp.Queue, result_queue: mp.Queue, expected_shape: tuple, - encoder_ready: mp.Event, + encoder_ready: EventType, ): """Process that decodes and receives CUDA tensors from queue.""" try: - # Create decoder with tensor queue - decoder = MsgpackDecoder(tensor_queue=tensor_queue) - # Wait for encoder to finish sending if not encoder_ready.wait(timeout=10.0): raise TimeoutError("Encoder did not signal ready") @@ -101,11 +103,10 @@ def decoder_process( ) except Exception as e: import traceback - result_queue.put({ - "success": False, - "error": str(e), - "traceback": traceback.format_exc() - }) + + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -114,7 +115,7 @@ def test_cuda_tensor_queue_basic(): # Set up queues and synchronization num_engines = 2 tensor_queues = [torch_mp.Queue() for _ in range(num_engines)] - result_queue = mp.Queue() + result_queue: mp.Queue = mp.Queue() encoder_ready = mp.Event() target_engine = 0 @@ -149,8 +150,14 @@ def test_cuda_tensor_queue_basic(): decoder_proc.join(timeout=5.0) # Verify results - assert encoder_result["success"], f"Encoder failed: {encoder_result.get('error')}\n{encoder_result.get('traceback', '')}" - assert decoder_result["success"], f"Decoder failed: {decoder_result.get('error')}\n{decoder_result.get('traceback', '')}" + assert encoder_result["success"], ( + f"Encoder failed: {encoder_result.get('error')}\n" + f"{encoder_result.get('traceback', '')}" + ) + assert decoder_result["success"], ( + f"Decoder failed: {decoder_result.get('error')}\n" + f"{decoder_result.get('traceback', '')}" + ) assert decoder_result["matches_expected"], "Tensor shape mismatch" assert "cuda" in decoder_result["device"], "Tensor not on CUDA device" @@ -168,7 +175,7 @@ def test_cpu_tensor_fallback(): # Verify encoding succeeded assert len(encoded) > 0 assert isinstance(encoded, (list, tuple)) - + # Basic check: no queue should be used, so tensor goes through standard path # This is mainly to ensure no exceptions are raised @@ -240,42 +247,45 @@ def api_server_worker( server_id: int, tensor_queue: torch_mp.Queue, result_queue: mp.Queue, - barrier: mp.Barrier, - retrieval_done: mp.Event, + barrier: BarrierType, + retrieval_done: EventType, ): """Worker simulating an API server sending tensors.""" try: # Each server sends a unique tensor tensor = torch.ones(server_id + 1, server_id + 2) * server_id tensor_id = f"server_{server_id}_tensor" - + # Wait for all servers to be ready barrier.wait() - + # Send tensor using TensorIpcData ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=tensor) tensor_queue.put(ipc_data) - + result_queue.put({"server_id": server_id, "success": True}) - + # Keep process alive until main process has retrieved all tensors # This prevents shared memory handles from being invalidated retrieval_done.wait(timeout=30.0) except Exception as e: import traceback - result_queue.put({ - "server_id": server_id, - "success": False, - "error": str(e), - "traceback": traceback.format_exc() - }) + + result_queue.put( + { + "server_id": server_id, + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + } + ) def test_multiple_api_servers_to_engine(): 
"""Test multiple API servers sending to one engine core via multiprocessing.""" num_api_servers = 3 tensor_queue = torch_mp.Queue() - result_queue = mp.Queue() + result_queue: mp.Queue = mp.Queue() barrier = mp.Barrier(num_api_servers) retrieval_done = mp.Event() @@ -297,7 +307,9 @@ def test_multiple_api_servers_to_engine(): # Verify all servers succeeded for result in results: - assert result["success"], f"Server {result['server_id']} failed: {result.get('error')}" + assert result["success"], ( + f"Server {result['server_id']} failed: {result.get('error')}" + ) # Verify all tensors are in queue received_tensors = [] @@ -306,7 +318,7 @@ def test_multiple_api_servers_to_engine(): received_tensors.append((ipc_data.tensor_id, ipc_data.tensor)) assert len(received_tensors) == num_api_servers - + # Verify tensor content (order may vary with multiprocessing) tensor_by_id = {tid: t for tid, t in received_tensors} for server_id in range(num_api_servers): @@ -314,10 +326,10 @@ def test_multiple_api_servers_to_engine(): assert expected_id in tensor_by_id, f"Missing tensor from server {server_id}" expected_tensor = torch.ones(server_id + 1, server_id + 2) * server_id assert torch.allclose(tensor_by_id[expected_id], expected_tensor) - + # Signal workers that retrieval is complete retrieval_done.set() - + # Wait for all processes to complete for proc in processes: proc.join(timeout=5.0) @@ -326,51 +338,59 @@ def test_multiple_api_servers_to_engine(): def mixed_tensor_encoder_process( tensor_queues: list[torch_mp.Queue], result_queue: mp.Queue, - ready_event: mp.Event, - retrieval_done: mp.Event, + ready_event: EventType, + retrieval_done: EventType, ): - """Process that encodes mixed CPU/CUDA tensors (old behavior: only CUDA via IPC).""" + """Process that encodes mixed CPU/CUDA tensors. + + Old behavior: only CUDA via IPC. + """ try: - # Use old behavior: multimodal_tensor_ipc defaults to True but only CUDA went through - # For this test, we want to test the old behavior where only CUDA uses IPC - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=False) + # Use old behavior: multimodal_tensor_ipc defaults to True but only CUDA went + # through. For this test, we want to test the old behavior where only CUDA + # uses IPC. 
+ encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc=False + ) encoder.set_target_engine(0) # Create only CUDA tensor for IPC (CPU will be serialized) # But actually, let's just send CUDA tensor directly cuda_tensor = torch.randn(4, 5, device="cuda:0") - + # Manually send via IPC to test the mechanism tensor_id = "test_cuda_tensor" cuda_tensor_shared = cuda_tensor.share_memory_() from vllm.v1.serial_utils import TensorIpcData + ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=cuda_tensor_shared) tensor_queues[0].put(ipc_data, timeout=10.0) - + ready_event.set() - - result_queue.put({ - "success": True, - "sent_cuda": True, - }) - + + result_queue.put( + { + "success": True, + "sent_cuda": True, + } + ) + # Keep process alive until decoder has retrieved the tensor retrieval_done.wait(timeout=30.0) except Exception as e: import traceback + ready_event.set() - result_queue.put({ - "success": False, - "error": str(e), - "traceback": traceback.format_exc() - }) + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) def mixed_tensor_decoder_process( tensor_queue: torch_mp.Queue, result_queue: mp.Queue, - encoder_ready: mp.Event, - retrieval_done: mp.Event, + encoder_ready: EventType, + retrieval_done: EventType, ): """Process that retrieves mixed tensors from queue.""" try: @@ -380,30 +400,31 @@ def mixed_tensor_decoder_process( # Try to get CUDA tensor from queue ipc_data = tensor_queue.get(timeout=5.0) - - result_queue.put({ - "success": True, - "is_cuda": ipc_data.tensor.is_cuda, - "shape": tuple(ipc_data.tensor.shape), - }) - + + result_queue.put( + { + "success": True, + "is_cuda": ipc_data.tensor.is_cuda, + "shape": tuple(ipc_data.tensor.shape), + } + ) + # Signal that retrieval is complete retrieval_done.set() except Exception as e: import traceback + retrieval_done.set() # Signal even on failure - result_queue.put({ - "success": False, - "error": str(e), - "traceback": traceback.format_exc() - }) + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_mixed_cpu_cuda_tensors(): """Test encoding with mixed CPU and CUDA tensors using multiprocessing.""" tensor_queues = [torch_mp.Queue()] - result_queue = mp.Queue() + result_queue: mp.Queue = mp.Queue() encoder_ready = mp.Event() retrieval_done = mp.Event() @@ -429,12 +450,20 @@ def test_mixed_cpu_cuda_tensors(): decoder_proc.join(timeout=5.0) # Verify encoder succeeded - assert encoder_result["success"], f"Encoder failed: {encoder_result.get('error')}\n{encoder_result.get('traceback', '')}" - + assert encoder_result["success"], ( + f"Encoder failed: {encoder_result.get('error')}\n" + f"{encoder_result.get('traceback', '')}" + ) + # Verify decoder succeeded and got CUDA tensor - assert decoder_result["success"], f"Decoder failed: {decoder_result.get('error')}\n{decoder_result.get('traceback', '')}" + assert decoder_result["success"], ( + f"Decoder failed: {decoder_result.get('error')}\n" + f"{decoder_result.get('traceback', '')}" + ) assert decoder_result["is_cuda"], "Retrieved tensor is not on CUDA" - assert decoder_result["shape"] == (4, 5), f"Unexpected shape: {decoder_result['shape']}" + assert decoder_result["shape"] == (4, 5), ( + f"Unexpected shape: {decoder_result['shape']}" + ) def cpu_tensor_ipc_encoder_process( @@ -442,13 +471,15 @@ def cpu_tensor_ipc_encoder_process( result_queue: mp.Queue, target_engine: int, tensor_shape: 
tuple, - ready_event: mp.Event, - retrieval_done: mp.Event, + ready_event: EventType, + retrieval_done: EventType, ): """Process that encodes and sends CPU tensors via IPC queue.""" try: # Create encoder with IPC enabled for all tensors - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc=True + ) encoder.set_target_engine(target_engine) # Create a CPU tensor @@ -468,32 +499,28 @@ def cpu_tensor_ipc_encoder_process( "tensor_shape": tuple(tensor.shape), } ) - + # Keep process alive until decoder has retrieved the tensor # This is necessary for CPU tensor shared memory to remain valid retrieval_done.wait(timeout=30.0) except Exception as e: import traceback + ready_event.set() - result_queue.put({ - "success": False, - "error": str(e), - "traceback": traceback.format_exc() - }) + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) def cpu_tensor_ipc_decoder_process( tensor_queue: torch_mp.Queue, result_queue: mp.Queue, expected_shape: tuple, - encoder_ready: mp.Event, - retrieval_done: mp.Event, + encoder_ready: EventType, + retrieval_done: EventType, ): """Process that decodes and receives CPU tensors from IPC queue.""" try: - # Create decoder with tensor queue - decoder = MsgpackDecoder(tensor_queue=tensor_queue) - # Wait for encoder to finish sending if not encoder_ready.wait(timeout=10.0): raise TimeoutError("Encoder did not signal ready") @@ -511,17 +538,16 @@ def cpu_tensor_ipc_decoder_process( "is_cpu": ipc_data.tensor.device.type == "cpu", } ) - + # Signal that retrieval is complete retrieval_done.set() except Exception as e: import traceback + retrieval_done.set() # Signal even on failure - result_queue.put({ - "success": False, - "error": str(e), - "traceback": traceback.format_exc() - }) + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) def test_cpu_tensor_ipc(): @@ -529,7 +555,7 @@ def test_cpu_tensor_ipc(): # Set up queues and synchronization num_engines = 2 tensor_queues = [torch_mp.Queue() for _ in range(num_engines)] - result_queue = mp.Queue() + result_queue: mp.Queue = mp.Queue() encoder_ready = mp.Event() retrieval_done = mp.Event() @@ -553,7 +579,13 @@ def test_cpu_tensor_ipc(): # Start decoder process decoder_proc = mp.Process( target=cpu_tensor_ipc_decoder_process, - args=(tensor_queues[target_engine], result_queue, tensor_shape, encoder_ready, retrieval_done), + args=( + tensor_queues[target_engine], + result_queue, + tensor_shape, + encoder_ready, + retrieval_done, + ), ) decoder_proc.start() @@ -565,8 +597,14 @@ def test_cpu_tensor_ipc(): decoder_proc.join(timeout=5.0) # Verify results - assert encoder_result["success"], f"Encoder failed: {encoder_result.get('error')}\n{encoder_result.get('traceback', '')}" - assert decoder_result["success"], f"Decoder failed: {decoder_result.get('error')}\n{decoder_result.get('traceback', '')}" + assert encoder_result["success"], ( + f"Encoder failed: {encoder_result.get('error')}\n" + f"{encoder_result.get('traceback', '')}" + ) + assert decoder_result["success"], ( + f"Decoder failed: {decoder_result.get('error')}\n" + f"{decoder_result.get('traceback', '')}" + ) assert decoder_result["matches_expected"], "Tensor shape mismatch" assert decoder_result["is_cpu"], "Tensor not on CPU device" @@ -574,7 +612,7 @@ def test_cpu_tensor_ipc(): def test_ipc_disabled_mode(): """Test that IPC is disabled when multimodal_tensor_ipc=False.""" 
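    # With IPC disabled, every tensor (CPU or CUDA) goes through the standard
    # msgpack/aux-buffer serialization path, so nothing should ever be placed on
    # the tensor queue; the assertions below check exactly that.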
tensor_queues = [torch_mp.Queue()] - + # Create encoder with IPC disabled encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=False) encoder.set_target_engine(0) @@ -597,7 +635,9 @@ def test_ipc_disabled_mode(): cuda_tensor = torch.randn(4, 5, device="cuda:0") encoded_cuda = encoder.encode({"cuda_tensor": cuda_tensor}) assert len(encoded_cuda) > 0 - assert tensor_queues[0].empty(), "Tensor queue should be empty for CUDA tensor when IPC is disabled" + assert tensor_queues[0].empty(), ( + "Tensor queue should be empty for CUDA tensor when IPC is disabled" + ) def test_mixed_cpu_cuda_with_ipc_enabled(): @@ -606,7 +646,7 @@ def test_mixed_cpu_cuda_with_ipc_enabled(): pytest.skip("CUDA not available") tensor_queues = [torch_mp.Queue()] - + # Create encoder with IPC enabled for all tensors encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) encoder.set_target_engine(0) @@ -615,7 +655,7 @@ def test_mixed_cpu_cuda_with_ipc_enabled(): assert encoder.multimodal_tensor_ipc is True, "IPC should be enabled" assert encoder.tensor_queues is not None, "Tensor queues should be set" assert encoder.target_engine_index == 0, "Target engine should be set" - - # Note: Actual IPC transfer only works across processes (tested in test_cpu_tensor_ipc) - # This test just verifies the configuration is correct + # Note: Actual IPC transfer only works across processes + # (tested in test_cpu_tensor_ipc) + # This test just verifies the configuration is correct diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 366805d33834..715cafc1d3cd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1006,10 +1006,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action=argparse.BooleanOptionalAction, default=None, help="Enable IPC (inter-process communication) for multimodal tensors. " - "When enabled, all multimodal tensors (CUDA and CPU) are transferred " - "via torch.multiprocessing shared memory for zero-copy IPC. " - "When disabled, all tensors use standard serialization. " - "If not specified, defaults to VLLM_MULTIMODAL_TENSOR_IPC env var (default: False).", + "When enabled, all multimodal tensors (CUDA and CPU) are " + "transferred via torch.multiprocessing shared memory for " + "zero-copy IPC. When disabled, all tensors use standard " + "serialization. 
If not specified, defaults to " + "VLLM_MULTIMODAL_TENSOR_IPC env var (default: False).", dest="multimodal_tensor_ipc", ) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index a5e802e328a7..814e92dd41ff 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -575,7 +575,7 @@ def _shape_before_after(tensor: torch.Tensor): (*shape_before, shape_concat, *shape_after), dtype=batch[0].dtype, device=batch[0].device, - pin_memory=pin_memory and batch[0].device.type == 'cpu', + pin_memory=pin_memory and batch[0].device.type == "cpu", ) return torch.concat(batch, dim=self.dim, out=out) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f68785b78a34..dfa50e336270 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -889,7 +889,7 @@ def startup_handshake( # Store tensor_queue_index for engine to access addresses = init_message.addresses addresses.tensor_queue_index = init_message.tensor_queue_index - + return addresses @staticmethod @@ -898,7 +898,7 @@ def run_engine_core( dp_rank: int = 0, local_dp_rank: int = 0, tensor_queues: list[Any] | None = None, - **kwargs + **kwargs, ): """Launch EngineCore busy loop in background process.""" @@ -936,7 +936,7 @@ def signal_handler(signum, frame): if data_parallel and vllm_config.model_config.is_moe: # Set data parallel rank for this engine process. parallel_config.data_parallel_rank = dp_rank - engine_core = DPEngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) + engine_core = DPEngineCoreProc(**kwargs, tensor_queues=tensor_queues) else: # Non-MoE DP ranks are completely independent, so treat like DP=1. # Note that parallel_config.data_parallel_index will still reflect @@ -944,7 +944,9 @@ def signal_handler(signum, frame): parallel_config.data_parallel_size = 1 parallel_config.data_parallel_size_local = 1 parallel_config.data_parallel_rank = 0 - engine_core = EngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) + engine_core = EngineCoreProc( + **kwargs, engine_index=dp_rank, tensor_queues=tensor_queues + ) engine_core.run_busy_loop() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 90dc4be18afa..4041e988fcea 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -360,6 +360,7 @@ class BackgroundResources: output_queue_task: asyncio.Task | None = None stats_update_task: asyncio.Task | None = None shutdown_path: str | None = None + tensor_queues: list[Any] | None = None # Set if any of the engines are dead. Here so that the output # processing threads can access it without holding a ref to the client. @@ -450,7 +451,7 @@ def __init__( client_addresses: dict[str, str] | None = None, ): self.vllm_config = vllm_config - + # ZMQ setup. sync_ctx = zmq.Context(io_threads=2) self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx @@ -466,14 +467,14 @@ def __init__( self.engines_running = False self.stats_update_address: str | None = None - tensor_queues = None + tensor_queues: list[Any] | None = None if client_addresses: # Engines are managed externally to this client. input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] self.stats_update_address = client_addresses.get("stats_update_address") # Tensor queues passed via client_addresses for multi-API-server case - tensor_queues = client_addresses.get("tensor_queues") + tensor_queues = client_addresses.get("tensor_queues") # type: ignore[assignment] else: # Engines are managed by this client. 
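            # launch_core_engines() also creates the per-engine-core
            # torch.multiprocessing queues used for tensor IPC
            # (see vllm/v1/engine/utils.py in this series).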
with launch_core_engines(vllm_config, executor_class, log_stats) as ( @@ -497,14 +498,17 @@ def __init__( # Get IPC config from multimodal_config, falling back to env var multimodal_tensor_ipc = True # Default if vllm_config.model_config.multimodal_config is not None: - mm_ipc = vllm_config.model_config.multimodal_config.multimodal_tensor_ipc + mm_ipc = ( + vllm_config.model_config.multimodal_config.multimodal_tensor_ipc + ) if mm_ipc is not None: multimodal_tensor_ipc = mm_ipc else: # Fall back to environment variable from vllm import envs + multimodal_tensor_ipc = envs.VLLM_MULTIMODAL_TENSOR_IPC - + self.encoder = MsgpackEncoder( tensor_queues=tensor_queues, multimodal_tensor_ipc=multimodal_tensor_ipc, @@ -927,7 +931,7 @@ def _send_input( # Set target engine index for CUDA tensor routing engine_index = int.from_bytes(engine, "little") self.encoder.set_target_engine(engine_index) - + message = (request_type.value, *self.encoder.encode(request)) return self._send_input_message(message, engine, request) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 39503ba539bb..eb24fe368499 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -817,10 +817,8 @@ def launch_core_engines( # and engine cores. One queue per engine core. # Use torch.multiprocessing for tensor sharing via IPC/shared memory. # Set start method to 'spawn' for compatibility with multiprocessing. - torch_mp.set_start_method('spawn', force=True) - tensor_queues: list[torch_mp.Queue] = [ - torch_mp.Queue() for _ in range(dp_size) - ] + torch_mp.set_start_method("spawn", force=True) + tensor_queues: list[torch_mp.Queue] = [torch_mp.Queue() for _ in range(dp_size)] # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -1062,7 +1060,9 @@ def wait_for_engine_startup( "_data_parallel_master_port_list", "data_parallel_size", ) - } if coordinated_dp else {}, + } + if coordinated_dp + else {}, tensor_queue_index=eng_index, ) ) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 9d5edecbff31..ed0a424f57ef 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -46,10 +46,11 @@ class TensorIpcData: """ Data sent via torch.multiprocessing.Queue for zero-copy IPC. - + Contains the tensor_id and the actual tensor. The tensor is shared in memory (GPU or CPU) for efficient inter-process communication. """ + tensor_id: str tensor: torch.Tensor @@ -58,12 +59,13 @@ class TensorIpcData: class TensorIpcHandle: """ Handle for a tensor sent via IPC queue (zero-copy transfer). - + Contains only metadata about the tensor. This is serialized via msgpack and used by the decoder to retrieve the actual tensor from the queue. The actual tensor is sent separately via torch.multiprocessing.Queue as TensorIpcData. Works for both CUDA and CPU tensors. """ + tensor_id: str shape: list[int] dtype: str @@ -148,7 +150,7 @@ class MsgpackEncoder: By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. - + When multimodal_tensor_ipc is enabled and tensor_queues is provided, all multimodal tensors (CUDA and CPU) will be sent via torch.multiprocessing.Queue for zero-copy IPC instead of serialization. 
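The docstring above summarizes the intended data path. Below is a minimal,
single-process sketch of that path, using only the constructor arguments and
types introduced in this series; it is illustrative rather than part of the
patch. In vLLM the producer and consumer live in separate API-server and
engine-core processes, the queue list is indexed via set_target_engine(), and
the "pixel_values" payload key here is an arbitrary placeholder.

import torch
import torch.multiprocessing as torch_mp

from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder

queue = torch_mp.Queue()

# Producer side (API server): the tensor is moved into shared memory and pushed
# onto the target engine's queue as TensorIpcData; the msgpack frames carry only
# a small TensorIpcHandle (tensor id, shape, dtype, device).
encoder = MsgpackEncoder(tensor_queues=[queue], multimodal_tensor_ipc=True)
encoder.set_target_engine(0)
frames = encoder.encode({"pixel_values": torch.randn(3, 8, 8)})
assert len(frames) > 0  # frames hold metadata only, not the tensor payload

# Consumer side (engine core): MsgpackDecoder(..., tensor_queue=queue) resolves
# handles back to tensors during decode; here the TensorIpcData is read off the
# queue directly, as the unit tests above do.
decoder = MsgpackDecoder(dict, tensor_queue=queue)
ipc_data = queue.get(timeout=10.0)
assert ipc_data.tensor.shape == (3, 8, 8)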
@@ -218,7 +220,7 @@ def enc_hook(self, obj: Any) -> Any: int(v) if v is not None else None for v in (obj.start, obj.stop, obj.step) ) - + if isinstance(obj, MultiModalKwargsItem): return self._encode_mm_item(obj) @@ -274,7 +276,7 @@ def _encode_tensor( self, obj: torch.Tensor ) -> tuple[str, tuple[int, ...], int | memoryview] | dict[str, Any]: assert self.aux_buffers is not None - + # Check if we should use IPC for this tensor # IPC is used when: multimodal_tensor_ipc is enabled, queues are available, # and we have a target engine @@ -288,28 +290,34 @@ def _encode_tensor( # Generate unique tensor ID tensor_id = f"{id(self)}_{self._tensor_id_counter}" self._tensor_id_counter += 1 - + try: # Move tensor to shared memory for IPC # This is required for proper inter-process communication if not obj.is_shared(): obj = obj.share_memory_() - + # Put TensorIpcData (tensor_id + tensor) into the target engine's queue target_queue = self.tensor_queues[self.target_engine_index] ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=obj) # Use a timeout to avoid blocking indefinitely target_queue.put(ipc_data, timeout=10.0) - + logger.debug( - "Sent tensor %s (shape=%s, device=%s) to engine %d via IPC queue (shared memory)", + "Sent tensor %s (shape=%s, device=%s) to engine %d " + "via IPC queue (shared memory)", tensor_id, obj.shape, obj.device, self.target_engine_index, ) - - return TensorIpcHandle(tensor_id=tensor_id, shape=list(obj.shape), dtype=str(obj.dtype).removeprefix("torch."), device=str(obj.device)) + + return TensorIpcHandle( + tensor_id=tensor_id, + shape=list(obj.shape), + dtype=str(obj.dtype).removeprefix("torch."), + device=str(obj.device), + ) except Exception as e: logger.warning( "Failed to send tensor via IPC queue: %s. " @@ -317,19 +325,19 @@ def _encode_tensor( e, ) # Fall through to standard serialization - - + # Standard serialization fallback # For CUDA tensors without IPC support, we need to move to CPU first if obj.is_cuda: if self.multimodal_tensor_ipc and self.tensor_queues is not None: # Only warn if IPC was expected but unavailable logger.warning( - "CUDA tensor without IPC support encountered (no target engine set). " - "Moving to CPU for serialization. This will be slow." + "CUDA tensor without IPC support encountered " + "(no target engine set). Moving to CPU for " + "serialization. This will be slow." ) obj = obj.cpu() - + # view the tensor as a contiguous 1D array of bytes arr_data = tensor_data(obj) if obj.nbytes < self.size_threshold: @@ -387,9 +395,10 @@ class MsgpackDecoder: Note that unlike vanilla `msgspec` Decoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. - - For multimodal tensors sent via torch.multiprocessing.Queue (when IPC is enabled), - they will be retrieved from the queue during decoding. Works for both CUDA and CPU tensors. + + For multimodal tensors sent via torch.multiprocessing.Queue (when IPC + is enabled), they will be retrieved from the queue during decoding. + Works for both CUDA and CPU tensors. 
""" def __init__( @@ -429,7 +438,8 @@ def dec_hook(self, t: type, obj: Any) -> Any: if issubclass(t, np.ndarray): return self._decode_ndarray(obj) if issubclass(t, TensorIpcHandle): - # msgspec deserializes dataclasses to dicts, so convert to TensorIpcHandle + # msgspec deserializes dataclasses to dicts, so convert + # to TensorIpcHandle if isinstance(obj, dict): obj = TensorIpcHandle(**obj) return self._decode_cuda_queue_tensor(obj) @@ -499,10 +509,13 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: arr = arr.pin_memory() if self.pin_tensors else arr.clone() # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) - + def _decode_cuda_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: - """Retrieve a tensor from the torch.multiprocessing.Queue (works for CUDA and CPU).""" - + """Retrieve a tensor from torch.multiprocessing.Queue. + + Works for CUDA and CPU. + """ + # Drain all available tensors. We save them regardless if this is the one # we're waiting for as they may arrive out of order from multiple producers. while handle.tensor_id not in self._tensor_buffer: @@ -549,8 +562,15 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if isinstance(obj, TensorIpcHandle): return self._decode_cuda_queue_tensor(obj) # Check if this is a dict that represents a TensorIpcHandle - # (msgspec serializes dataclasses as dicts without type info in nested structures) - if isinstance(obj, dict) and 'tensor_id' in obj and 'shape' in obj and 'dtype' in obj and 'device' in obj: + # (msgspec serializes dataclasses as dicts without type info + # in nested structures) + if ( + isinstance(obj, dict) + and "tensor_id" in obj + and "shape" in obj + and "dtype" in obj + and "device" in obj + ): # Convert dict to TensorIpcHandle and decode it handle = TensorIpcHandle(**obj) return self._decode_cuda_queue_tensor(handle) From a16a093642c85e0629e8f7a6505e9fd2653343bb Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 11 Jan 2026 05:26:03 +0000 Subject: [PATCH 10/33] Handle orphaned tensors during timeout Signed-off-by: Brandon Pelfrey --- tests/v1/test_tensor_ipc_queue.py | 140 ++++++++++++++++++++++++++++-- vllm/multimodal/inputs.py | 2 +- vllm/v1/engine/core.py | 16 ++++ vllm/v1/engine/core_client.py | 16 ++++ vllm/v1/serial_utils.py | 93 +++++++++++++++++--- 5 files changed, 249 insertions(+), 18 deletions(-) diff --git a/tests/v1/test_tensor_ipc_queue.py b/tests/v1/test_tensor_ipc_queue.py index 5c50a0457b30..d08c9ef8edd4 100644 --- a/tests/v1/test_tensor_ipc_queue.py +++ b/tests/v1/test_tensor_ipc_queue.py @@ -208,7 +208,7 @@ def test_decoder_buffer_management(): } for tensor_id, tensor in tensors.items(): - ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=tensor) + ipc_data = TensorIpcData(request_id=None, tensor_id=tensor_id, tensor=tensor) tensor_queue.put(ipc_data) # Create decoder @@ -216,6 +216,7 @@ def test_decoder_buffer_management(): # Request tensor_3 (should buffer tensor_1 and tensor_2) handle = TensorIpcHandle( + request_id=None, tensor_id="tensor_3", shape=[6, 7], dtype="float32", @@ -225,12 +226,13 @@ def test_decoder_buffer_management(): result = decoder._decode_cuda_queue_tensor(handle) assert result.shape == (6, 7) - # Verify buffer has tensor_1 and tensor_2 - assert "tensor_1" in decoder._tensor_buffer - assert "tensor_2" in decoder._tensor_buffer + # Verify buffer has tensor_1 and tensor_2 using tuple keys + assert (None, "tensor_1") in decoder._tensor_buffer + assert (None, "tensor_2") in decoder._tensor_buffer # Request buffered 
tensor handle2 = TensorIpcHandle( + request_id=None, tensor_id="tensor_1", shape=[2, 3], dtype="float32", @@ -240,7 +242,7 @@ def test_decoder_buffer_management(): result2 = decoder._decode_cuda_queue_tensor(handle2) assert result2.shape == (2, 3) # tensor_1 should be removed from buffer - assert "tensor_1" not in decoder._tensor_buffer + assert (None, "tensor_1") not in decoder._tensor_buffer def api_server_worker( @@ -260,7 +262,7 @@ def api_server_worker( barrier.wait() # Send tensor using TensorIpcData - ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=tensor) + ipc_data = TensorIpcData(request_id=None, tensor_id=tensor_id, tensor=tensor) tensor_queue.put(ipc_data) result_queue.put({"server_id": server_id, "success": True}) @@ -363,7 +365,9 @@ def mixed_tensor_encoder_process( cuda_tensor_shared = cuda_tensor.share_memory_() from vllm.v1.serial_utils import TensorIpcData - ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=cuda_tensor_shared) + ipc_data = TensorIpcData( + request_id=None, tensor_id=tensor_id, tensor=cuda_tensor_shared + ) tensor_queues[0].put(ipc_data, timeout=10.0) ready_event.set() @@ -659,3 +663,125 @@ def test_mixed_cpu_cuda_with_ipc_enabled(): # Note: Actual IPC transfer only works across processes # (tested in test_cpu_tensor_ipc) # This test just verifies the configuration is correct + + +def test_tensor_cleanup_on_abort(): + """Test that orphaned tensors are cleaned up when requests are aborted.""" + # Create a tensor queue (not actually used in this simplified test) + tensor_queue = torch_mp.Queue() + + # Create decoder + decoder = MsgpackDecoder(dict, tensor_queue=tensor_queue) + + # Simulate tensors in the buffer for multiple requests + request_ids = ["req1", "req2", "req3"] + + for request_id in request_ids: + # Simulate 2 tensors per request using tuple keys + for i in range(2): + tensor_id = f"encoder_{i}" + tensor_key = (request_id, tensor_id) + tensor = torch.randn(10, 10) + + # Manually add to buffer and tracking (simulating decode behavior) + decoder._tensor_buffer[tensor_key] = tensor + + if request_id not in decoder._request_to_tensors: + decoder._request_to_tensors[request_id] = [] + decoder._request_to_tensors[request_id].append(tensor_key) + + # Verify tensors are in the buffer + initial_buffer_size = len(decoder._tensor_buffer) + assert initial_buffer_size == 6, "Buffer should contain 6 tensors (2 per request)" + + # Verify request tracking + assert len(decoder._request_to_tensors) == 3, "Should track 3 requests" + assert len(decoder._request_to_tensors["req1"]) == 2, "req1 should have 2 tensors" + + # Cleanup tensors for req1 + removed_count_1 = decoder.cleanup_request_tensors("req1") + assert removed_count_1 == 2, "Should have removed 2 tensors for req1" + assert len(decoder._tensor_buffer) == 4, "Buffer should have 4 tensors left" + assert "req1" not in decoder._request_to_tensors, ( + "req1 should be removed from tracking" + ) + + # Cleanup tensors for req2 + removed_count_2 = decoder.cleanup_request_tensors("req2") + assert removed_count_2 == 2, "Should have removed 2 tensors for req2" + assert len(decoder._tensor_buffer) == 2, "Buffer should have 2 tensors left" + + # Cleanup req3 + removed_count_3 = decoder.cleanup_request_tensors("req3") + assert removed_count_3 == 2, "Should have removed 2 tensors for req3" + + # Verify all tensors are cleaned up + assert len(decoder._tensor_buffer) == 0, "Buffer should be empty" + assert len(decoder._request_to_tensors) == 0, "Request tracking should be empty" + + # Cleanup for non-existent 
request should return 0 + removed_count_4 = decoder.cleanup_request_tensors("nonexistent") + assert removed_count_4 == 0, "Should return 0 for non-existent request" + + +def test_tensor_cleanup_after_decode(): + """Test that tensors are removed from tracking after successful decode.""" + # Create a tensor queue + tensor_queue = torch_mp.Queue() + + # Create and encode a tensor + tensor = torch.randn(5, 5) + # Move to shared memory for IPC + if not tensor.is_shared(): + tensor.share_memory_() + + # Manually create a TensorIpcData and put it in the queue + request_id = "test_req" + tensor_id = "encoder_0" + ipc_data = TensorIpcData(request_id=request_id, tensor_id=tensor_id, tensor=tensor) + tensor_queue.put(ipc_data) + + # Create decoder + decoder = MsgpackDecoder(dict, tensor_queue=tensor_queue) + + # Create a TensorIpcHandle to decode + handle = TensorIpcHandle( + request_id=request_id, + tensor_id=tensor_id, + shape=list(tensor.shape), + dtype=str(tensor.dtype).removeprefix("torch."), + device=str(tensor.device), + ) + + # Decode the tensor - this should retrieve it from the queue + decoded_tensor = decoder._decode_cuda_queue_tensor(handle) + + # Verify the tensor was decoded + assert decoded_tensor.shape == tensor.shape, "Decoded tensor should match shape" + + # Verify the tensor was removed from buffer after decode + tensor_key = (request_id, tensor_id) + assert tensor_key not in decoder._tensor_buffer, ( + "Tensor should be removed from buffer" + ) + + # Verify the request tracking was cleaned up + assert request_id not in decoder._request_to_tensors, ( + "Request tracking should be cleaned up" + ) + + +def test_request_context_in_encoder(): + """Test that encoder properly sets and clears request context.""" + encoder = MsgpackEncoder() + + # Initially no request context + assert encoder._current_request_id is None + + # Set request context + encoder.set_request_context("req123") + assert encoder._current_request_id == "req123" + + # Clear request context + encoder.set_request_context(None) + assert encoder._current_request_id is None diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 814e92dd41ff..7b12158763c3 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -575,7 +575,7 @@ def _shape_before_after(tensor: torch.Tensor): (*shape_before, shape_concat, *shape_after), dtype=batch[0].dtype, device=batch[0].device, - pin_memory=pin_memory and batch[0].device.type == "cpu", + pin_memory=pin_memory, ) return torch.concat(batch, dim=self.dim, out=out) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index dfa50e336270..8b5c20672a9d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -660,6 +660,9 @@ def __init__( identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False + # Decoder for cleanup (set by process_input_sockets thread) + self.tensor_decoder: MsgpackDecoder | None = None + with self._perform_handshakes( handshake_address, identity, @@ -1024,6 +1027,16 @@ def _process_engine_step(self) -> bool: return model_executed + def abort_requests(self, request_ids: list[str]): + """Abort requests and cleanup any orphaned tensors.""" + # First, abort the requests in the scheduler + super().abort_requests(request_ids) + + # Then cleanup any orphaned tensors for these requests + if self.tensor_decoder is not None: + for request_id in request_ids: + self.tensor_decoder.cleanup_request_tensors(request_id) + def _handle_client_request( self, request_type: EngineCoreRequestType, 
request: Any ) -> None: @@ -1102,6 +1115,9 @@ def process_input_sockets( ) generic_decoder = MsgpackDecoder(tensor_queue=self.tensor_queue) + # Store decoder reference for tensor cleanup on abort + self.tensor_decoder = add_request_decoder + with ExitStack() as stack, zmq.Context() as ctx: input_sockets = [ stack.enter_context( diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 4041e988fcea..afc4566bdd95 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -746,9 +746,17 @@ def get_output(self) -> EngineCoreOutputs: def _send_input(self, request_type: EngineCoreRequestType, request: Any): self.ensure_alive() self.free_pending_messages() + + # Set request context if this is an ADD request with a request_id + if request_type == EngineCoreRequestType.ADD and hasattr(request, "request_id"): + self.encoder.set_request_context(request.request_id) + # (Identity, RequestType, SerializedRequest) msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) + # Clear request context after encoding + self.encoder.set_request_context(None) + if len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. self.input_socket.send_multipart(msg, copy=False) @@ -932,7 +940,15 @@ def _send_input( engine_index = int.from_bytes(engine, "little") self.encoder.set_target_engine(engine_index) + # Set request context if this is an ADD request with a request_id + if request_type == EngineCoreRequestType.ADD and hasattr(request, "request_id"): + self.encoder.set_request_context(request.request_id) + message = (request_type.value, *self.encoder.encode(request)) + + # Clear request context after encoding + self.encoder.set_request_context(None) + return self._send_input_message(message, engine, request) def _send_input_message( diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index ed0a424f57ef..c836336db088 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -47,10 +47,11 @@ class TensorIpcData: """ Data sent via torch.multiprocessing.Queue for zero-copy IPC. - Contains the tensor_id and the actual tensor. The tensor is shared + Contains the request_id, tensor_id and the actual tensor. The tensor is shared in memory (GPU or CPU) for efficient inter-process communication. """ + request_id: str | None tensor_id: str tensor: torch.Tensor @@ -66,6 +67,7 @@ class TensorIpcHandle: as TensorIpcData. Works for both CUDA and CPU tensors. 
""" + request_id: str | None tensor_id: str shape: list[int] dtype: str @@ -178,6 +180,8 @@ def __init__( self.target_engine_index: int | None = None # Counter for generating unique tensor IDs self._tensor_id_counter = 0 + # Current request ID being encoded (for associating tensors with requests) + self._current_request_id: str | None = None if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -185,6 +189,10 @@ def set_target_engine(self, engine_index: int | None) -> None: """Set the target engine index for routing multimodal tensors to IPC queues.""" self.target_engine_index = engine_index + def set_request_context(self, request_id: str | None) -> None: + """Set the current request ID being encoded (for tensor association).""" + self._current_request_id = request_id + def encode(self, obj: Any) -> Sequence[bytestr]: try: self.aux_buffers = bufs = [b""] @@ -287,7 +295,7 @@ def _encode_tensor( ): # Send tensor via torch.multiprocessing.Queue for zero-copy IPC # This works for both CUDA and CPU tensors - # Generate unique tensor ID + # Generate unique tensor ID (without request ID embedded) tensor_id = f"{id(self)}_{self._tensor_id_counter}" self._tensor_id_counter += 1 @@ -297,22 +305,27 @@ def _encode_tensor( if not obj.is_shared(): obj = obj.share_memory_() - # Put TensorIpcData (tensor_id + tensor) into the target engine's queue target_queue = self.tensor_queues[self.target_engine_index] - ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=obj) + ipc_data = TensorIpcData( + request_id=self._current_request_id, + tensor_id=tensor_id, + tensor=obj, + ) # Use a timeout to avoid blocking indefinitely target_queue.put(ipc_data, timeout=10.0) logger.debug( - "Sent tensor %s (shape=%s, device=%s) to engine %d " + "Sent tensor %s for request %s (shape=%s, device=%s) to engine %d " "via IPC queue (shared memory)", tensor_id, + self._current_request_id, obj.shape, obj.device, self.target_engine_index, ) return TensorIpcHandle( + request_id=self._current_request_id, tensor_id=tensor_id, shape=list(obj.shape), dtype=str(obj.dtype).removeprefix("torch."), @@ -417,8 +430,10 @@ def __init__( # Tensor IPC queue for receiving multimodal tensors from API servers self.tensor_queue = tensor_queue # Buffer for temporarily storing tensors retrieved from queue - # that don't match the current request - self._tensor_buffer: dict[str, torch.Tensor] = {} + # Keys are tuples: (request_id, tensor_id) or (None, tensor_id) for legacy + self._tensor_buffer: dict[tuple[str | None, str], torch.Tensor] = {} + # Mapping from request_id to list of tensor keys for cleanup + self._request_to_tensors: dict[str, list[tuple[str | None, str]]] = {} if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -516,13 +531,40 @@ def _decode_cuda_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: Works for CUDA and CPU. """ + # Create lookup key from handle + lookup_key = (handle.request_id, handle.tensor_id) + # Drain all available tensors. We save them regardless if this is the one # we're waiting for as they may arrive out of order from multiple producers. 
- while handle.tensor_id not in self._tensor_buffer: + while lookup_key not in self._tensor_buffer: ipc_data: TensorIpcData = self.tensor_queue.get(timeout=10.0) - self._tensor_buffer[ipc_data.tensor_id] = ipc_data.tensor - tensor = self._tensor_buffer.pop(handle.tensor_id) + # Store tensor with tuple key (request_id, tensor_id) + tensor_key = (ipc_data.request_id, ipc_data.tensor_id) + self._tensor_buffer[tensor_key] = ipc_data.tensor + + # Track which request this tensor belongs to for cleanup + if ipc_data.request_id is not None: + if ipc_data.request_id not in self._request_to_tensors: + self._request_to_tensors[ipc_data.request_id] = [] + self._request_to_tensors[ipc_data.request_id].append(tensor_key) + + # Retrieve and remove tensor from buffer + tensor = self._tensor_buffer.pop(lookup_key) + + # Remove from request tracking when consumed + if ( + handle.request_id is not None + and handle.request_id in self._request_to_tensors + ): + try: + self._request_to_tensors[handle.request_id].remove(lookup_key) + if not self._request_to_tensors[handle.request_id]: + del self._request_to_tensors[handle.request_id] + except ValueError: + # Tensor was already removed, ignore + pass + return tensor def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: @@ -572,6 +614,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: and "device" in obj ): # Convert dict to TensorIpcHandle and decode it + # Handle both new format (with request_id) and old format (without) handle = TensorIpcHandle(**obj) return self._decode_cuda_queue_tensor(handle) if not isinstance(obj, list): @@ -586,6 +629,36 @@ def _decode_nested_slices(self, obj: Any) -> Any: return slice(*obj) return [self._decode_nested_slices(x) for x in obj] + def cleanup_request_tensors(self, request_id: str) -> int: + """Remove all orphaned tensors associated with a request. + + This should be called when a request is aborted, times out, or fails + to ensure tensors in the buffer don't accumulate indefinitely. + + Args: + request_id: The request ID whose tensors should be cleaned up. + + Returns: + The number of tensors that were removed from the buffer. 
+ """ + if request_id not in self._request_to_tensors: + return 0 + + tensor_keys = self._request_to_tensors.pop(request_id) + removed_count = 0 + + for tensor_key in tensor_keys: + if tensor_key in self._tensor_buffer: + del self._tensor_buffer[tensor_key] + removed_count += 1 + logger.debug( + "Cleaned up orphaned tensor %s for request %s", + tensor_key[1], # Just log the tensor_id part + request_id, + ) + + return removed_count + def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data From 2ab00746f7ed2d1eec5dba98f2346372428edf8c Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 11 Jan 2026 05:37:00 +0000 Subject: [PATCH 11/33] remove references to maximum_concurrent_videos Signed-off-by: Brandon Pelfrey --- vllm/config/model.py | 3 --- vllm/engine/arg_utils.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index 962ca3caccb9..afc35436d1d6 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -311,7 +311,6 @@ class ModelConfig: interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None - maximum_concurrent_videos: InitVar[int | None] = None multimodal_tensor_ipc: InitVar[bool | None] = None def compute_hash(self) -> str: @@ -427,7 +426,6 @@ def __post_init__( interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, - maximum_concurrent_videos: int | None, multimodal_tensor_ipc: bool | None, ) -> None: # Keep set served_model_name before maybe_model_redirect(self.model) @@ -592,7 +590,6 @@ def __post_init__( interleave_mm_strings=interleave_mm_strings, skip_mm_profiling=skip_mm_profiling, video_pruning_rate=video_pruning_rate, - max_concurrent_videos=maximum_concurrent_videos, multimodal_tensor_ipc=multimodal_tensor_ipc, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 715cafc1d3cd..714bddd6205b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -474,7 +474,6 @@ class EngineArgs: io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling video_pruning_rate: float = MultiModalConfig.video_pruning_rate - maximum_concurrent_videos: int | None = MultiModalConfig.max_concurrent_videos multimodal_tensor_ipc: bool | None = MultiModalConfig.multimodal_tensor_ipc # LoRA fields enable_lora: bool = False @@ -1290,7 +1289,6 @@ def create_model_config(self) -> ModelConfig: override_attention_dtype=self.override_attention_dtype, logits_processors=self.logits_processors, video_pruning_rate=self.video_pruning_rate, - maximum_concurrent_videos=self.maximum_concurrent_videos, multimodal_tensor_ipc=self.multimodal_tensor_ipc, io_processor_plugin=self.io_processor_plugin, ) From 1107bcc16709a6e5132b6c34a8df15a58e2bfb33 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 05:18:13 +0000 Subject: [PATCH 12/33] Handle race condition between tensor cleanup and decode threads Signed-off-by: Brandon Pelfrey --- vllm/v1/serial_utils.py | 100 ++++++++++++++++++++++------------------ 1 file changed, 56 insertions(+), 44 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index c836336db088..ab5bf6f214d9 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -4,6 +4,7 @@ import dataclasses import importlib import pickle +import threading from collections.abc import Callable, Sequence from functools import partial from inspect import 
isclass @@ -434,6 +435,8 @@ def __init__( self._tensor_buffer: dict[tuple[str | None, str], torch.Tensor] = {} # Mapping from request_id to list of tensor keys for cleanup self._request_to_tensors: dict[str, list[tuple[str | None, str]]] = {} + # Lock to synchronize access between cleanup and decode threads + self._buffer_lock = threading.Lock() if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -536,36 +539,44 @@ def _decode_cuda_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: # Drain all available tensors. We save them regardless if this is the one # we're waiting for as they may arrive out of order from multiple producers. - while lookup_key not in self._tensor_buffer: + while True: + # Check if tensor is already in buffer (with lock) + with self._buffer_lock: + if lookup_key in self._tensor_buffer: + # Retrieve and remove tensor from buffer + tensor = self._tensor_buffer.pop(lookup_key) + + # Remove from request tracking when consumed + if ( + handle.request_id is not None + and handle.request_id in self._request_to_tensors + ): + try: + self._request_to_tensors[handle.request_id].remove( + lookup_key + ) + if not self._request_to_tensors[handle.request_id]: + del self._request_to_tensors[handle.request_id] + except ValueError: + # Tensor was already removed, ignore + pass + + return tensor + + # Release lock while waiting on queue (important to avoid blocking cleanup) ipc_data: TensorIpcData = self.tensor_queue.get(timeout=10.0) - # Store tensor with tuple key (request_id, tensor_id) - tensor_key = (ipc_data.request_id, ipc_data.tensor_id) - self._tensor_buffer[tensor_key] = ipc_data.tensor + # Store the received tensor (with lock) + with self._buffer_lock: + # Store tensor with tuple key (request_id, tensor_id) + tensor_key = (ipc_data.request_id, ipc_data.tensor_id) + self._tensor_buffer[tensor_key] = ipc_data.tensor - # Track which request this tensor belongs to for cleanup - if ipc_data.request_id is not None: - if ipc_data.request_id not in self._request_to_tensors: - self._request_to_tensors[ipc_data.request_id] = [] - self._request_to_tensors[ipc_data.request_id].append(tensor_key) - - # Retrieve and remove tensor from buffer - tensor = self._tensor_buffer.pop(lookup_key) - - # Remove from request tracking when consumed - if ( - handle.request_id is not None - and handle.request_id in self._request_to_tensors - ): - try: - self._request_to_tensors[handle.request_id].remove(lookup_key) - if not self._request_to_tensors[handle.request_id]: - del self._request_to_tensors[handle.request_id] - except ValueError: - # Tensor was already removed, ignore - pass - - return tensor + # Track which request this tensor belongs to for cleanup + if ipc_data.request_id is not None: + if ipc_data.request_id not in self._request_to_tensors: + self._request_to_tensors[ipc_data.request_id] = [] + self._request_to_tensors[ipc_data.request_id].append(tensor_key) def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: return MultiModalKwargsItems( @@ -641,23 +652,24 @@ def cleanup_request_tensors(self, request_id: str) -> int: Returns: The number of tensors that were removed from the buffer. 
""" - if request_id not in self._request_to_tensors: - return 0 - - tensor_keys = self._request_to_tensors.pop(request_id) - removed_count = 0 - - for tensor_key in tensor_keys: - if tensor_key in self._tensor_buffer: - del self._tensor_buffer[tensor_key] - removed_count += 1 - logger.debug( - "Cleaned up orphaned tensor %s for request %s", - tensor_key[1], # Just log the tensor_id part - request_id, - ) - - return removed_count + with self._buffer_lock: + if request_id not in self._request_to_tensors: + return 0 + + tensor_keys = self._request_to_tensors.pop(request_id) + removed_count = 0 + + for tensor_key in tensor_keys: + if tensor_key in self._tensor_buffer: + del self._tensor_buffer[tensor_key] + removed_count += 1 + logger.debug( + "Cleaned up orphaned tensor %s for request %s", + tensor_key[1], # Just log the tensor_id part + request_id, + ) + + return removed_count def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: From d11a8a66e3ce0b92de345e629b1a3f06403e7493 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 05:19:42 +0000 Subject: [PATCH 13/33] Ensure tensor queue is non-null Signed-off-by: Brandon Pelfrey --- vllm/v1/serial_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index ab5bf6f214d9..2c9e8299534a 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -534,6 +534,8 @@ def _decode_cuda_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: Works for CUDA and CPU. """ + assert self.tensor_queue is not None, "Tensor queue is not set" + # Create lookup key from handle lookup_key = (handle.request_id, handle.tensor_id) From 2fecb85a4ba6a95a9d76bb88bf24a86af9db0473 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 05:21:51 +0000 Subject: [PATCH 14/33] SyncMPClient: set target engine for IPC routing Signed-off-by: Brandon Pelfrey --- vllm/v1/engine/core_client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index afc4566bdd95..a52b2e0b4bc2 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -747,6 +747,10 @@ def _send_input(self, request_type: EngineCoreRequestType, request: Any): self.ensure_alive() self.free_pending_messages() + # Set target engine index for tensor routing + engine_index = int.from_bytes(self.core_engine, "little") + self.encoder.set_target_engine(engine_index) + # Set request context if this is an ADD request with a request_id if request_type == EngineCoreRequestType.ADD and hasattr(request, "request_id"): self.encoder.set_request_context(request.request_id) From 4fbc3c083b1ca2c2ba96f8f06944e4268b3e4a58 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 05:23:07 +0000 Subject: [PATCH 15/33] Remove video-related options leftover from other PR Signed-off-by: Brandon Pelfrey --- vllm/config/multimodal.py | 4 ---- vllm/engine/arg_utils.py | 8 -------- 2 files changed, 12 deletions(-) diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 6e3bcddda4c9..6cb81f37a49c 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -140,10 +140,6 @@ class MultiModalConfig: Value sits in range [0;1) and determines fraction of media tokens from each video to be pruned. """ - max_concurrent_videos: int | None = Field(default=None, gt=0) - """Maximum number of videos that can be preprocessed concurrently in this - process. 
This limits VRAM usage from video decoding libraries like - PyNvVideoCodec that allocate VRAM separately from PyTorch.""" multimodal_tensor_ipc: bool | None = None """Enable IPC (inter-process communication) for multimodal tensors. When enabled, all multimodal tensors (CUDA and CPU) are transferred diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 714bddd6205b..faef50b9ee4f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -991,14 +991,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] ) - multimodal_group.add_argument( - "--maximum-concurrent-videos", - type=int, - default=None, - help="Maximum number of videos that can be preprocessed concurrently. " - "This limits VRAM usage from video decoding. The count is spread " - "evenly over API server processes.", - ) multimodal_group.add_argument( "--enable-multimodal-tensor-ipc", "--disable-multimodal-tensor-ipc", From 1df4745084f61142c90f04186c11da18ddec7211 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 18:16:30 +0000 Subject: [PATCH 16/33] remove --disable-multimodal-tensor-ipc Signed-off-by: Brandon Pelfrey --- vllm/engine/arg_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index faef50b9ee4f..3e592e6d9703 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -993,7 +993,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) multimodal_group.add_argument( "--enable-multimodal-tensor-ipc", - "--disable-multimodal-tensor-ipc", action=argparse.BooleanOptionalAction, default=None, help="Enable IPC (inter-process communication) for multimodal tensors. " From c3b78567706ce26d43bb4b78973c43a42fcea58d Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 18:17:36 +0000 Subject: [PATCH 17/33] multimodal_tensor_ipc = False Signed-off-by: Brandon Pelfrey --- vllm/v1/engine/core_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index a52b2e0b4bc2..13f6f43c1909 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -496,7 +496,7 @@ def __init__( # Serialization setup with tensor queues for multimodal tensor IPC. 
# Get IPC config from multimodal_config, falling back to env var - multimodal_tensor_ipc = True # Default + multimodal_tensor_ipc = False # Default if vllm_config.model_config.multimodal_config is not None: mm_ipc = ( vllm_config.model_config.multimodal_config.multimodal_tensor_ipc From 500dc8c4f4054cac674c15b7ded84841750205b6 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 18:18:05 +0000 Subject: [PATCH 18/33] rename _decode_ipc_queue_tensor Signed-off-by: Brandon Pelfrey --- tests/v1/test_tensor_ipc_queue.py | 6 +++--- vllm/v1/serial_utils.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/v1/test_tensor_ipc_queue.py b/tests/v1/test_tensor_ipc_queue.py index d08c9ef8edd4..c63e1c9dfa32 100644 --- a/tests/v1/test_tensor_ipc_queue.py +++ b/tests/v1/test_tensor_ipc_queue.py @@ -223,7 +223,7 @@ def test_decoder_buffer_management(): device="cpu", ) - result = decoder._decode_cuda_queue_tensor(handle) + result = decoder._decode_ipc_queue_tensor(handle) assert result.shape == (6, 7) # Verify buffer has tensor_1 and tensor_2 using tuple keys @@ -239,7 +239,7 @@ def test_decoder_buffer_management(): device="cpu", ) - result2 = decoder._decode_cuda_queue_tensor(handle2) + result2 = decoder._decode_ipc_queue_tensor(handle2) assert result2.shape == (2, 3) # tensor_1 should be removed from buffer assert (None, "tensor_1") not in decoder._tensor_buffer @@ -754,7 +754,7 @@ def test_tensor_cleanup_after_decode(): ) # Decode the tensor - this should retrieve it from the queue - decoded_tensor = decoder._decode_cuda_queue_tensor(handle) + decoded_tensor = decoder._decode_ipc_queue_tensor(handle) # Verify the tensor was decoded assert decoded_tensor.shape == tensor.shape, "Decoded tensor should match shape" diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 2c9e8299534a..76d6bb78e761 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -460,7 +460,7 @@ def dec_hook(self, t: type, obj: Any) -> Any: # to TensorIpcHandle if isinstance(obj, dict): obj = TensorIpcHandle(**obj) - return self._decode_cuda_queue_tensor(obj) + return self._decode_ipc_queue_tensor(obj) if issubclass(t, torch.Tensor): return self._decode_tensor(obj) if t is slice: @@ -528,7 +528,7 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) - def _decode_cuda_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: + def _decode_ipc_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: """Retrieve a tensor from torch.multiprocessing.Queue. Works for CUDA and CPU. @@ -615,7 +615,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: # values are sometimes floats. 
return obj if isinstance(obj, TensorIpcHandle): - return self._decode_cuda_queue_tensor(obj) + return self._decode_ipc_queue_tensor(obj) # Check if this is a dict that represents a TensorIpcHandle # (msgspec serializes dataclasses as dicts without type info # in nested structures) @@ -629,7 +629,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: # Convert dict to TensorIpcHandle and decode it # Handle both new format (with request_id) and old format (without) handle = TensorIpcHandle(**obj) - return self._decode_cuda_queue_tensor(handle) + return self._decode_ipc_queue_tensor(handle) if not isinstance(obj, list): raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") if obj and isinstance(obj[0], str): From 809fe3855b3061f389a2614848db6f89e4fa4629 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 18:20:10 +0000 Subject: [PATCH 19/33] Use encoder_request_context across MP/Async Clients Signed-off-by: Brandon Pelfrey --- vllm/v1/engine/core_client.py | 57 ++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 13f6f43c1909..d07f7e93fdb0 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -59,6 +59,33 @@ EngineIdentity = bytes +@contextlib.contextmanager +def encoder_request_context( + encoder: MsgpackEncoder, + engine: EngineIdentity, + request_type: EngineCoreRequestType, + request: Any, +): + """Context manager for setting encoder state during request encoding. + + Sets the target engine and request context (for ADD requests) on entry, + and clears the request context on exit. + """ + # Set target engine index for tensor routing + engine_index = int.from_bytes(engine, "little") + encoder.set_target_engine(engine_index) + + # Set request context if this is an ADD request with a request_id + if request_type == EngineCoreRequestType.ADD and hasattr(request, "request_id"): + encoder.set_request_context(request.request_id) + + try: + yield encoder + finally: + # Clear request context after encoding + encoder.set_request_context(None) + + class EngineCoreClient(ABC): """ EngineCoreClient: subclasses handle different methods for pushing @@ -747,19 +774,11 @@ def _send_input(self, request_type: EngineCoreRequestType, request: Any): self.ensure_alive() self.free_pending_messages() - # Set target engine index for tensor routing - engine_index = int.from_bytes(self.core_engine, "little") - self.encoder.set_target_engine(engine_index) - - # Set request context if this is an ADD request with a request_id - if request_type == EngineCoreRequestType.ADD and hasattr(request, "request_id"): - self.encoder.set_request_context(request.request_id) - # (Identity, RequestType, SerializedRequest) - msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) - - # Clear request context after encoding - self.encoder.set_request_context(None) + with encoder_request_context( + self.encoder, self.core_engine, request_type, request + ): + msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) if len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. 
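
The `encoder_request_context` helper above centralizes the encode-time bookkeeping that both the sync and async clients need: pick the target engine for tensor routing, tag the request id for ADD requests, and always clear the request context afterwards. A minimal, illustrative driver of the helper; the `encoder`, `engine_id`, and `request` objects are stand-ins for the real client attributes, not vLLM API:

    # Illustrative sketch only; mirrors how _send_input wraps encode().
    from vllm.v1.engine import EngineCoreRequestType
    from vllm.v1.engine.core_client import encoder_request_context

    def encode_for_engine(encoder, engine_id: bytes, request) -> tuple:
        with encoder_request_context(
            encoder, engine_id, EngineCoreRequestType.ADD, request
        ):
            frames = encoder.encode(request)  # tensors routed to engine_id's queue
        # The request context is cleared on exit, even if encode() raises.
        return (engine_id, EngineCoreRequestType.ADD.value, *frames)
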
@@ -940,18 +959,8 @@ def _send_input( if engine is None: engine = self.core_engine - # Set target engine index for CUDA tensor routing - engine_index = int.from_bytes(engine, "little") - self.encoder.set_target_engine(engine_index) - - # Set request context if this is an ADD request with a request_id - if request_type == EngineCoreRequestType.ADD and hasattr(request, "request_id"): - self.encoder.set_request_context(request.request_id) - - message = (request_type.value, *self.encoder.encode(request)) - - # Clear request context after encoding - self.encoder.set_request_context(None) + with encoder_request_context(self.encoder, engine, request_type, request): + message = (request_type.value, *self.encoder.encode(request)) return self._send_input_message(message, engine, request) From 0140d4f710199f943af7268472bd0636e9e58383 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 19:50:23 +0000 Subject: [PATCH 20/33] Symmetric _encode/_decode methods for tensor queues Signed-off-by: Brandon Pelfrey --- vllm/v1/serial_utils.py | 79 ++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 76d6bb78e761..aefcd38b3c2a 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -281,6 +281,47 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data + def _encode_ipc_queue_tensor(self, obj: torch.Tensor) -> TensorIpcHandle: + """Send tensor via torch.multiprocessing.Queue for zero-copy IPC. + + This works for both CUDA and CPU tensors. + """ + # Generate unique tensor ID (without request ID embedded) + tensor_id = f"{id(self)}_{self._tensor_id_counter}" + self._tensor_id_counter += 1 + + # Move tensor to shared memory for IPC + # This is required for proper inter-process communication + if not obj.is_shared(): + obj = obj.share_memory_() + + target_queue = self.tensor_queues[self.target_engine_index] + ipc_data = TensorIpcData( + request_id=self._current_request_id, + tensor_id=tensor_id, + tensor=obj, + ) + # Use a timeout to avoid blocking indefinitely + target_queue.put(ipc_data, timeout=10.0) + + logger.debug( + "Sent tensor %s for request %s (shape=%s, device=%s) to engine %d " + "via IPC queue (shared memory)", + tensor_id, + self._current_request_id, + obj.shape, + obj.device, + self.target_engine_index, + ) + + return TensorIpcHandle( + request_id=self._current_request_id, + tensor_id=tensor_id, + shape=list(obj.shape), + dtype=str(obj.dtype).removeprefix("torch."), + device=str(obj.device), + ) + def _encode_tensor( self, obj: torch.Tensor ) -> tuple[str, tuple[int, ...], int | memoryview] | dict[str, Any]: @@ -294,44 +335,8 @@ def _encode_tensor( and self.tensor_queues is not None and self.target_engine_index is not None ): - # Send tensor via torch.multiprocessing.Queue for zero-copy IPC - # This works for both CUDA and CPU tensors - # Generate unique tensor ID (without request ID embedded) - tensor_id = f"{id(self)}_{self._tensor_id_counter}" - self._tensor_id_counter += 1 - try: - # Move tensor to shared memory for IPC - # This is required for proper inter-process communication - if not obj.is_shared(): - obj = obj.share_memory_() - - target_queue = self.tensor_queues[self.target_engine_index] - ipc_data = TensorIpcData( - request_id=self._current_request_id, - tensor_id=tensor_id, - tensor=obj, - ) - # Use a timeout to avoid blocking indefinitely - target_queue.put(ipc_data, timeout=10.0) - - 
logger.debug( - "Sent tensor %s for request %s (shape=%s, device=%s) to engine %d " - "via IPC queue (shared memory)", - tensor_id, - self._current_request_id, - obj.shape, - obj.device, - self.target_engine_index, - ) - - return TensorIpcHandle( - request_id=self._current_request_id, - tensor_id=tensor_id, - shape=list(obj.shape), - dtype=str(obj.dtype).removeprefix("torch."), - device=str(obj.device), - ) + return self._encode_ipc_queue_tensor(obj) except Exception as e: logger.warning( "Failed to send tensor via IPC queue: %s. " From 68e5bc6d247afc87ace2b49668b0d7c222eea8ae Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Mon, 12 Jan 2026 19:52:37 +0000 Subject: [PATCH 21/33] Handle _decode_tensor calls for both TensorIpcHandle/dict cases Signed-off-by: Brandon Pelfrey --- tests/v1/test_serial_utils.py | 79 +++++++++++++++++++++++++++++++++++ vllm/v1/serial_utils.py | 17 ++++++++ 2 files changed, 96 insertions(+) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index dbbbfce97d28..78e3188d69db 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -283,3 +283,82 @@ def test_custom_class_serialization_disallowed_without_pickle(): with pytest.raises(TypeError): # Attempt to encode the custom class encoder.encode(obj) + + +@dataclass +class RequestWithTensor: + """Mock request with non-multimodal tensor field like EngineCoreRequest.""" + + prompt_embeds: torch.Tensor | None + data: str + + +def test_non_multimodal_tensor_with_ipc(): + """Test that non-multimodal tensor fields work correctly with IPC enabled. + + This reproduces the bug where fields like prompt_embeds: torch.Tensor | None + would fail to decode when IPC is enabled because _decode_tensor expected a tuple + but received a TensorIpcHandle dict. + """ + import torch.multiprocessing as torch_mp + + # Create tensor queues for IPC + tensor_queues = [torch_mp.Queue()] + + # Create encoder with IPC enabled + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) + encoder.set_target_engine(0) + encoder.set_request_context("test_request_123") + + # Create decoder with IPC queue + decoder = MsgpackDecoder(RequestWithTensor, tensor_queue=tensor_queues[0]) + + # Create a request with a non-multimodal tensor + original_tensor = torch.randn(5, 10, dtype=torch.float32) + request = RequestWithTensor(prompt_embeds=original_tensor, data="test_data") + + # Encode the request - this should send the tensor via IPC + encoded = encoder.encode(request) + + # Verify encoding succeeded + assert len(encoded) > 0 + + # Decode the request - this should retrieve the tensor from IPC queue + # Previously this would fail with: TypeError: cannot unpack non-iterable dict object + decoded = decoder.decode(encoded) + + # Verify the decoded request matches the original + assert isinstance(decoded, RequestWithTensor) + assert decoded.data == "test_data" + assert decoded.prompt_embeds is not None + assert torch.allclose(decoded.prompt_embeds, original_tensor), ( + "Decoded tensor does not match the original tensor." 
+ ) + + +def test_non_multimodal_tensor_with_ipc_none_value(): + """Test that None values for tensor fields work correctly with IPC enabled.""" + import torch.multiprocessing as torch_mp + + # Create tensor queues for IPC + tensor_queues = [torch_mp.Queue()] + + # Create encoder with IPC enabled + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) + encoder.set_target_engine(0) + encoder.set_request_context("test_request_456") + + # Create decoder with IPC queue + decoder = MsgpackDecoder(RequestWithTensor, tensor_queue=tensor_queues[0]) + + # Create a request with None for the tensor field + request = RequestWithTensor(prompt_embeds=None, data="test_data_with_none") + + # Encode and decode the request + encoded = encoder.encode(request) + decoded = decoder.decode(encoded) + + # Verify the decoded request matches the original + assert isinstance(decoded, RequestWithTensor) + assert decoded.data == "test_data_with_none" + assert decoded.prompt_embeds is None diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index aefcd38b3c2a..1953fbb22a86 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -511,6 +511,23 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: return arr.reshape(shape) def _decode_tensor(self, arr: Any) -> torch.Tensor: + # Check if this is a TensorIpcHandle (sent via IPC queue) + # This can happen when IPC is enabled for non-multimodal tensor fields + if isinstance(arr, TensorIpcHandle): + return self._decode_ipc_queue_tensor(arr) + # Check if this is a dict that represents a TensorIpcHandle + # (msgspec serializes dataclasses as dicts without type info) + if ( + isinstance(arr, dict) + and "tensor_id" in arr + and "shape" in arr + and "dtype" in arr + and "device" in arr + ): + # Convert dict to TensorIpcHandle and decode it + handle = TensorIpcHandle(**arr) + return self._decode_ipc_queue_tensor(handle) + # Standard tensor decoding dtype, shape, data = arr is_aux = isinstance(data, int) From 8bf94c408018d49948a90d041817f6bb94ae10df Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 25 Jan 2026 22:19:39 +0000 Subject: [PATCH 22/33] remove VLLM_MULTIMODAL_TENSOR_IPC env variable Signed-off-by: Brandon Pelfrey --- vllm/config/multimodal.py | 5 ++--- vllm/engine/arg_utils.py | 5 ++--- vllm/envs.py | 9 --------- vllm/v1/engine/core_client.py | 5 +---- 4 files changed, 5 insertions(+), 19 deletions(-) diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 6cb81f37a49c..b0f624ad0154 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -140,13 +140,12 @@ class MultiModalConfig: Value sits in range [0;1) and determines fraction of media tokens from each video to be pruned. """ - multimodal_tensor_ipc: bool | None = None + multimodal_tensor_ipc: bool = False """Enable IPC (inter-process communication) for multimodal tensors. When enabled, all multimodal tensors (CUDA and CPU) are transferred via torch.multiprocessing shared memory for zero-copy IPC. When disabled, all tensors use standard serialization. - If None, defaults to the value of VLLM_MULTIMODAL_TENSOR_IPC environment - variable (default: True).""" + Defaults to False. 
""" @field_validator("limit_per_prompt", mode="before") @classmethod diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3e592e6d9703..8b6e0654e5dc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -994,13 +994,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--enable-multimodal-tensor-ipc", action=argparse.BooleanOptionalAction, - default=None, + default=False, help="Enable IPC (inter-process communication) for multimodal tensors. " "When enabled, all multimodal tensors (CUDA and CPU) are " "transferred via torch.multiprocessing shared memory for " "zero-copy IPC. When disabled, all tensors use standard " - "serialization. If not specified, defaults to " - "VLLM_MULTIMODAL_TENSOR_IPC env var (default: False).", + "serialization. If not specified, defaults to False.", dest="multimodal_tensor_ipc", ) diff --git a/vllm/envs.py b/vllm/envs.py index d22718ef67f5..4d12ea327833 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -74,8 +74,6 @@ VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MEDIA_CONNECTOR: str = "http" VLLM_MM_HASHER_ALGORITHM: str = "blake3" - VLLM_MULTIMODAL_TENSOR_IPC: bool = True - VLLM_MULTIMODAL_TENSOR_IPC: bool = False VLLM_TARGET_DEVICE: str = "cuda" VLLM_MAIN_CUDA_VERSION: str = "12.9" VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest" @@ -820,13 +818,6 @@ def get_vllm_port() -> int | None: ["blake3", "sha256", "sha512"], case_sensitive=False, ), - # Enable IPC (inter-process communication) for multimodal tensors. - # When enabled, all multimodal tensors (CUDA and CPU) are transferred - # via torch.multiprocessing shared memory for zero-copy IPC. - # When disabled, all tensors use standard serialization. - "VLLM_MULTIMODAL_TENSOR_IPC": lambda: bool( - int(os.getenv("VLLM_MULTIMODAL_TENSOR_IPC", "1")) - ), # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. 
"VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser( diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index d07f7e93fdb0..1c0611fb3d55 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -531,10 +531,7 @@ def __init__( if mm_ipc is not None: multimodal_tensor_ipc = mm_ipc else: - # Fall back to environment variable - from vllm import envs - - multimodal_tensor_ipc = envs.VLLM_MULTIMODAL_TENSOR_IPC + multimodal_tensor_ipc = False self.encoder = MsgpackEncoder( tensor_queues=tensor_queues, From d03c791b5b8e3b9fb03a6f781fbd17ec6dc9e6eb Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 25 Jan 2026 22:45:53 +0000 Subject: [PATCH 23/33] CR comments on request->tensor cleanup Signed-off-by: Brandon Pelfrey --- vllm/v1/serial_utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 1953fbb22a86..ea50d5788ad0 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -575,15 +575,12 @@ def _decode_ipc_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: handle.request_id is not None and handle.request_id in self._request_to_tensors ): - try: - self._request_to_tensors[handle.request_id].remove( - lookup_key - ) - if not self._request_to_tensors[handle.request_id]: + tensors = self._request_to_tensors.get(handle.request_id) + if tensors: + tensors.remove(lookup_key) + # Clean up if this is the last tensor for the request + if not tensors: del self._request_to_tensors[handle.request_id] - except ValueError: - # Tensor was already removed, ignore - pass return tensor From bd4b5ee9e611a2464d19d9ec795195fd7a409542 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Sun, 25 Jan 2026 22:58:04 +0000 Subject: [PATCH 24/33] Address precommit Signed-off-by: Brandon Pelfrey --- vllm/v1/engine/core_client.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 1c0611fb3d55..ba741882b078 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -528,10 +528,7 @@ def __init__( mm_ipc = ( vllm_config.model_config.multimodal_config.multimodal_tensor_ipc ) - if mm_ipc is not None: - multimodal_tensor_ipc = mm_ipc - else: - multimodal_tensor_ipc = False + multimodal_tensor_ipc = mm_ipc if mm_ipc is not None else False self.encoder = MsgpackEncoder( tensor_queues=tensor_queues, From e37a2c8df5b9776cc6528756cc562b73e9d3b1f5 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Wed, 28 Jan 2026 01:26:05 +0000 Subject: [PATCH 25/33] Change config to msgspec|torch instead of boolean Signed-off-by: Brandon Pelfrey --- tests/v1/test_serial_utils.py | 4 ++-- tests/v1/test_tensor_ipc_queue.py | 24 +++++++++++------------- vllm/config/model.py | 4 ++-- vllm/config/multimodal.py | 11 +++++------ vllm/engine/arg_utils.py | 15 +++++---------- vllm/v1/engine/core_client.py | 6 +++--- vllm/v1/serial_utils.py | 13 +++++++------ 7 files changed, 35 insertions(+), 42 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 78e3188d69db..743a779e533b 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -306,7 +306,7 @@ def test_non_multimodal_tensor_with_ipc(): tensor_queues = [torch_mp.Queue()] # Create encoder with IPC enabled - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc="torch") 
encoder.set_target_engine(0) encoder.set_request_context("test_request_123") @@ -344,7 +344,7 @@ def test_non_multimodal_tensor_with_ipc_none_value(): tensor_queues = [torch_mp.Queue()] # Create encoder with IPC enabled - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc="torch") encoder.set_target_engine(0) encoder.set_request_context("test_request_456") diff --git a/tests/v1/test_tensor_ipc_queue.py b/tests/v1/test_tensor_ipc_queue.py index c63e1c9dfa32..e4405edff48f 100644 --- a/tests/v1/test_tensor_ipc_queue.py +++ b/tests/v1/test_tensor_ipc_queue.py @@ -343,16 +343,10 @@ def mixed_tensor_encoder_process( ready_event: EventType, retrieval_done: EventType, ): - """Process that encodes mixed CPU/CUDA tensors. - - Old behavior: only CUDA via IPC. - """ + """Process that encodes mixed CPU/CUDA tensors.""" try: - # Use old behavior: multimodal_tensor_ipc defaults to True but only CUDA went - # through. For this test, we want to test the old behavior where only CUDA - # uses IPC. encoder = MsgpackEncoder( - tensor_queues=tensor_queues, multimodal_tensor_ipc=False + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch" ) encoder.set_target_engine(0) @@ -482,7 +476,7 @@ def cpu_tensor_ipc_encoder_process( try: # Create encoder with IPC enabled for all tensors encoder = MsgpackEncoder( - tensor_queues=tensor_queues, multimodal_tensor_ipc=True + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch" ) encoder.set_target_engine(target_engine) @@ -614,11 +608,13 @@ def test_cpu_tensor_ipc(): def test_ipc_disabled_mode(): - """Test that IPC is disabled when multimodal_tensor_ipc=False.""" + """Test that IPC is disabled when multimodal_tensor_ipc="msgspec".""" tensor_queues = [torch_mp.Queue()] # Create encoder with IPC disabled - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=False) + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="msgspec" + ) encoder.set_target_engine(0) # Create a CPU tensor @@ -652,11 +648,13 @@ def test_mixed_cpu_cuda_with_ipc_enabled(): tensor_queues = [torch_mp.Queue()] # Create encoder with IPC enabled for all tensors - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc=True) + encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc="torch") encoder.set_target_engine(0) # Verify encoder configuration - assert encoder.multimodal_tensor_ipc is True, "IPC should be enabled" + assert encoder.multimodal_tensor_ipc == "torch", ( + "Torch queue-based IPC should be enabled" + ) assert encoder.tensor_queues is not None, "Tensor queues should be set" assert encoder.target_engine_index == 0, "Target engine should be set" diff --git a/vllm/config/model.py b/vllm/config/model.py index afc35436d1d6..5c244661e2a3 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -311,7 +311,7 @@ class ModelConfig: interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None - multimodal_tensor_ipc: InitVar[bool | None] = None + multimodal_tensor_ipc: InitVar[Literal["msgspec", "torch"] | None] = None def compute_hash(self) -> str: """ @@ -426,7 +426,7 @@ def __post_init__( interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, - multimodal_tensor_ipc: bool | None, + multimodal_tensor_ipc: Literal["msgspec", "torch"] | None, ) -> None: # 
Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name( diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index b0f624ad0154..42ac1e4719f8 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -140,12 +140,11 @@ class MultiModalConfig: Value sits in range [0;1) and determines fraction of media tokens from each video to be pruned. """ - multimodal_tensor_ipc: bool = False - """Enable IPC (inter-process communication) for multimodal tensors. - When enabled, all multimodal tensors (CUDA and CPU) are transferred - via torch.multiprocessing shared memory for zero-copy IPC. - When disabled, all tensors use standard serialization. - Defaults to False. """ + multimodal_tensor_ipc: Literal["msgspec", "torch"] = "msgspec" + """IPC (inter-process communication) method for multimodal tensors. + - "msgspec": Use msgspec serialization + - "torch": Use torch.multiprocessing shared memory for zero-copy IPC + Defaults to "msgspec". """ @field_validator("limit_per_prompt", mode="before") @classmethod diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8b6e0654e5dc..499ef60e2df2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -474,7 +474,9 @@ class EngineArgs: io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling video_pruning_rate: float = MultiModalConfig.video_pruning_rate - multimodal_tensor_ipc: bool | None = MultiModalConfig.multimodal_tensor_ipc + multimodal_tensor_ipc: Literal["msgspec", "torch"] | None = ( + MultiModalConfig.multimodal_tensor_ipc + ) # LoRA fields enable_lora: bool = False max_loras: int = LoRAConfig.max_loras @@ -992,15 +994,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] ) multimodal_group.add_argument( - "--enable-multimodal-tensor-ipc", - action=argparse.BooleanOptionalAction, - default=False, - help="Enable IPC (inter-process communication) for multimodal tensors. " - "When enabled, all multimodal tensors (CUDA and CPU) are " - "transferred via torch.multiprocessing shared memory for " - "zero-copy IPC. When disabled, all tensors use standard " - "serialization. If not specified, defaults to False.", - dest="multimodal_tensor_ipc", + "--multimodal-tensor-ipc", + **multimodal_kwargs["multimodal_tensor_ipc"], ) # LoRA related configs diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index ba741882b078..d189001e4a4b 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -522,13 +522,13 @@ def __init__( ) # Serialization setup with tensor queues for multimodal tensor IPC. 
- # Get IPC config from multimodal_config, falling back to env var - multimodal_tensor_ipc = False # Default + # Get IPC config from multimodal_config, falling back to default + multimodal_tensor_ipc = "msgspec" # Default if vllm_config.model_config.multimodal_config is not None: mm_ipc = ( vllm_config.model_config.multimodal_config.multimodal_tensor_ipc ) - multimodal_tensor_ipc = mm_ipc if mm_ipc is not None else False + multimodal_tensor_ipc = mm_ipc if mm_ipc is not None else "msgspec" self.encoder = MsgpackEncoder( tensor_queues=tensor_queues, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index ea50d5788ad0..29a3b9d16fc0 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -154,16 +154,17 @@ class MsgpackEncoder: By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. - When multimodal_tensor_ipc is enabled and tensor_queues is provided, + When multimodal_tensor_ipc is "torch" and tensor_queues is provided, all multimodal tensors (CUDA and CPU) will be sent via torch.multiprocessing.Queue for zero-copy IPC instead of serialization. + When "msgspec", tensors use standard msgspec serialization. """ def __init__( self, size_threshold: int | None = None, tensor_queues: list[Any] | None = None, - multimodal_tensor_ipc: bool = True, + multimodal_tensor_ipc: str = "msgspec", ): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD @@ -175,7 +176,7 @@ def __init__( self.size_threshold = size_threshold # Tensor IPC queues for sharing multimodal tensors (one per engine core) self.tensor_queues = tensor_queues - # Enable IPC for all multimodal tensors (CUDA and CPU) + # IPC method for multimodal tensors self.multimodal_tensor_ipc = multimodal_tensor_ipc # Target engine index for routing tensors to the correct queue self.target_engine_index: int | None = None @@ -328,10 +329,10 @@ def _encode_tensor( assert self.aux_buffers is not None # Check if we should use IPC for this tensor - # IPC is used when: multimodal_tensor_ipc is enabled, queues are available, + # IPC is used when: multimodal_tensor_ipc is "torch", queues are available, # and we have a target engine if ( - self.multimodal_tensor_ipc + self.multimodal_tensor_ipc == "torch" and self.tensor_queues is not None and self.target_engine_index is not None ): @@ -348,7 +349,7 @@ def _encode_tensor( # Standard serialization fallback # For CUDA tensors without IPC support, we need to move to CPU first if obj.is_cuda: - if self.multimodal_tensor_ipc and self.tensor_queues is not None: + if self.multimodal_tensor_ipc == "torch" and self.tensor_queues is not None: # Only warn if IPC was expected but unavailable logger.warning( "CUDA tensor without IPC support encountered " From b1f6aa5efed93b9f3038b9571c353f964c2b41e2 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Wed, 28 Jan 2026 16:04:03 +0000 Subject: [PATCH 26/33] remove None typing for multimodal_tensor_ipc Signed-off-by: Brandon Pelfrey --- vllm/config/model.py | 2 +- vllm/engine/arg_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index 5c244661e2a3..4e6174f305c0 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -426,7 +426,7 @@ def __post_init__( interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, - multimodal_tensor_ipc: Literal["msgspec", "torch"] | None, + multimodal_tensor_ipc: Literal["msgspec", "torch"], ) 
-> None: # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 499ef60e2df2..d5c15d2329db 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -474,7 +474,7 @@ class EngineArgs: io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling video_pruning_rate: float = MultiModalConfig.video_pruning_rate - multimodal_tensor_ipc: Literal["msgspec", "torch"] | None = ( + multimodal_tensor_ipc: Literal["msgspec", "torch"] = ( MultiModalConfig.multimodal_tensor_ipc ) # LoRA fields From 6aa1e3d433cfaed1c2bedc8a6879337779d12700 Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Fri, 30 Jan 2026 21:27:02 +0000 Subject: [PATCH 27/33] Change to direct_rpc and torch_shm, dataclass -> NamedTuple+tuple datatype, cleanup on scheduler finished_req_ids Signed-off-by: Brandon Pelfrey --- tests/v1/test_serial_utils.py | 8 ++++++-- tests/v1/test_tensor_ipc_queue.py | 14 ++++++++------ vllm/config/model.py | 4 ++-- vllm/config/multimodal.py | 8 ++++---- vllm/engine/arg_utils.py | 2 +- vllm/v1/engine/core.py | 20 ++++++++++++++++++++ vllm/v1/engine/core_client.py | 5 ++--- vllm/v1/serial_utils.py | 22 ++++++++++++---------- 8 files changed, 55 insertions(+), 28 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 743a779e533b..bfa14193e5ab 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -306,7 +306,9 @@ def test_non_multimodal_tensor_with_ipc(): tensor_queues = [torch_mp.Queue()] # Create encoder with IPC enabled - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc="torch") + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" + ) encoder.set_target_engine(0) encoder.set_request_context("test_request_123") @@ -344,7 +346,9 @@ def test_non_multimodal_tensor_with_ipc_none_value(): tensor_queues = [torch_mp.Queue()] # Create encoder with IPC enabled - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc="torch") + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" + ) encoder.set_target_engine(0) encoder.set_request_context("test_request_456") diff --git a/tests/v1/test_tensor_ipc_queue.py b/tests/v1/test_tensor_ipc_queue.py index e4405edff48f..f6ce283937d0 100644 --- a/tests/v1/test_tensor_ipc_queue.py +++ b/tests/v1/test_tensor_ipc_queue.py @@ -346,7 +346,7 @@ def mixed_tensor_encoder_process( """Process that encodes mixed CPU/CUDA tensors.""" try: encoder = MsgpackEncoder( - tensor_queues=tensor_queues, multimodal_tensor_ipc="torch" + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" ) encoder.set_target_engine(0) @@ -476,7 +476,7 @@ def cpu_tensor_ipc_encoder_process( try: # Create encoder with IPC enabled for all tensors encoder = MsgpackEncoder( - tensor_queues=tensor_queues, multimodal_tensor_ipc="torch" + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" ) encoder.set_target_engine(target_engine) @@ -608,12 +608,12 @@ def test_cpu_tensor_ipc(): def test_ipc_disabled_mode(): - """Test that IPC is disabled when multimodal_tensor_ipc="msgspec".""" + """Test that IPC is disabled when multimodal_tensor_ipc="direct_rpc".""" tensor_queues = [torch_mp.Queue()] # Create encoder with IPC disabled encoder = MsgpackEncoder( - tensor_queues=tensor_queues, multimodal_tensor_ipc="msgspec" + 
tensor_queues=tensor_queues, multimodal_tensor_ipc="direct_rpc" ) encoder.set_target_engine(0) @@ -648,11 +648,13 @@ def test_mixed_cpu_cuda_with_ipc_enabled(): tensor_queues = [torch_mp.Queue()] # Create encoder with IPC enabled for all tensors - encoder = MsgpackEncoder(tensor_queues=tensor_queues, multimodal_tensor_ipc="torch") + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" + ) encoder.set_target_engine(0) # Verify encoder configuration - assert encoder.multimodal_tensor_ipc == "torch", ( + assert encoder.multimodal_tensor_ipc == "torch_shm", ( "Torch queue-based IPC should be enabled" ) assert encoder.tensor_queues is not None, "Tensor queues should be set" diff --git a/vllm/config/model.py b/vllm/config/model.py index 4e6174f305c0..f76fb571e237 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -311,7 +311,7 @@ class ModelConfig: interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None - multimodal_tensor_ipc: InitVar[Literal["msgspec", "torch"] | None] = None + multimodal_tensor_ipc: InitVar[Literal["direct_rpc", "torch_shm"]] = "direct_rpc" def compute_hash(self) -> str: """ @@ -426,7 +426,7 @@ def __post_init__( interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, - multimodal_tensor_ipc: Literal["msgspec", "torch"], + multimodal_tensor_ipc: Literal["direct_rpc", "torch_shm"], ) -> None: # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name( diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 42ac1e4719f8..4bc6d36789d9 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -140,11 +140,11 @@ class MultiModalConfig: Value sits in range [0;1) and determines fraction of media tokens from each video to be pruned. """ - multimodal_tensor_ipc: Literal["msgspec", "torch"] = "msgspec" + multimodal_tensor_ipc: Literal["direct_rpc", "torch_shm"] = "direct_rpc" """IPC (inter-process communication) method for multimodal tensors. - - "msgspec": Use msgspec serialization - - "torch": Use torch.multiprocessing shared memory for zero-copy IPC - Defaults to "msgspec". """ + - "direct_rpc": Use msgspec serialization via RPC + - "torch_shm": Use torch.multiprocessing shared memory for zero-copy IPC + Defaults to "direct_rpc". 
""" @field_validator("limit_per_prompt", mode="before") @classmethod diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d5c15d2329db..6921ddd9c1b4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -474,7 +474,7 @@ class EngineArgs: io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling video_pruning_rate: float = MultiModalConfig.video_pruning_rate - multimodal_tensor_ipc: Literal["msgspec", "torch"] = ( + multimodal_tensor_ipc: Literal["direct_rpc", "torch_shm"] = ( MultiModalConfig.multimodal_tensor_ipc ) # LoRA fields diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 8b5c20672a9d..ac8b637b6a9a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -395,8 +395,19 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: scheduler_output, model_output ) + # Cleanup tensors for finished requests + self._cleanup_finished_request_tensors(scheduler_output.finished_req_ids) + return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 + def _cleanup_finished_request_tensors(self, finished_req_ids: set[str]) -> None: + """Cleanup any orphaned tensors for finished requests. + + This is a no-op in the base class but can be overridden in subclasses + to perform actual cleanup (e.g., for IPC tensor queues). + """ + pass + def post_step(self, model_executed: bool) -> None: # When using async scheduling we can't get draft token ids in advance, # so we update draft token ids in the worker process and don't @@ -497,6 +508,9 @@ def step_with_batch_queue( scheduler_output, model_output ) + # Cleanup tensors for finished requests + self._cleanup_finished_request_tensors(scheduler_output.finished_req_ids) + # NOTE(nick): We can either handle the deferred tasks here or save # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. @@ -1037,6 +1051,12 @@ def abort_requests(self, request_ids: list[str]): for request_id in request_ids: self.tensor_decoder.cleanup_request_tensors(request_id) + def _cleanup_finished_request_tensors(self, finished_req_ids: set[str]) -> None: + """Cleanup any orphaned tensors for finished requests.""" + if self.tensor_decoder is not None: + for request_id in finished_req_ids: + self.tensor_decoder.cleanup_request_tensors(request_id) + def _handle_client_request( self, request_type: EngineCoreRequestType, request: Any ) -> None: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index d189001e4a4b..08ecec12a1e8 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -523,12 +523,11 @@ def __init__( # Serialization setup with tensor queues for multimodal tensor IPC. 
# Get IPC config from multimodal_config, falling back to default - multimodal_tensor_ipc = "msgspec" # Default + multimodal_tensor_ipc = "direct_rpc" # Default if vllm_config.model_config.multimodal_config is not None: - mm_ipc = ( + multimodal_tensor_ipc = ( vllm_config.model_config.multimodal_config.multimodal_tensor_ipc ) - multimodal_tensor_ipc = mm_ipc if mm_ipc is not None else "msgspec" self.encoder = MsgpackEncoder( tensor_queues=tensor_queues, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 29a3b9d16fc0..2fffc3db26c7 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -9,7 +9,7 @@ from functools import partial from inspect import isclass from types import FunctionType -from typing import Any, TypeAlias, get_type_hints +from typing import Any, NamedTuple, TypeAlias, get_type_hints import cloudpickle import msgspec @@ -57,8 +57,7 @@ class TensorIpcData: tensor: torch.Tensor -@dataclasses.dataclass -class TensorIpcHandle: +class TensorIpcHandle(NamedTuple): """ Handle for a tensor sent via IPC queue (zero-copy transfer). @@ -70,7 +69,7 @@ class TensorIpcHandle: request_id: str | None tensor_id: str - shape: list[int] + shape: tuple[int, ...] dtype: str device: str @@ -154,17 +153,17 @@ class MsgpackEncoder: By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. - When multimodal_tensor_ipc is "torch" and tensor_queues is provided, + When multimodal_tensor_ipc is "torch_shm" and tensor_queues is provided, all multimodal tensors (CUDA and CPU) will be sent via torch.multiprocessing.Queue for zero-copy IPC instead of serialization. - When "msgspec", tensors use standard msgspec serialization. + When "direct_rpc", tensors use standard msgspec serialization. 
""" def __init__( self, size_threshold: int | None = None, tensor_queues: list[Any] | None = None, - multimodal_tensor_ipc: str = "msgspec", + multimodal_tensor_ipc: str = "direct_rpc", ): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD @@ -329,10 +328,10 @@ def _encode_tensor( assert self.aux_buffers is not None # Check if we should use IPC for this tensor - # IPC is used when: multimodal_tensor_ipc is "torch", queues are available, + # IPC is used when: multimodal_tensor_ipc is "torch_shm", queues are available, # and we have a target engine if ( - self.multimodal_tensor_ipc == "torch" + self.multimodal_tensor_ipc == "torch_shm" and self.tensor_queues is not None and self.target_engine_index is not None ): @@ -349,7 +348,10 @@ def _encode_tensor( # Standard serialization fallback # For CUDA tensors without IPC support, we need to move to CPU first if obj.is_cuda: - if self.multimodal_tensor_ipc == "torch" and self.tensor_queues is not None: + if ( + self.multimodal_tensor_ipc == "torch_shm" + and self.tensor_queues is not None + ): # Only warn if IPC was expected but unavailable logger.warning( "CUDA tensor without IPC support encountered " From 8a8a7b8409cbc85fe99b5136df91bef0bde5e08f Mon Sep 17 00:00:00 2001 From: Brandon Pelfrey Date: Tue, 3 Feb 2026 19:03:45 +0000 Subject: [PATCH 28/33] precommit issues resolved Signed-off-by: Brandon Pelfrey --- vllm/v1/serial_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 2fffc3db26c7..9e09d2145686 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -286,6 +286,9 @@ def _encode_ipc_queue_tensor(self, obj: torch.Tensor) -> TensorIpcHandle: This works for both CUDA and CPU tensors. 
""" + assert self.target_engine_index is not None, "Target engine index is not set" + assert self.tensor_queues is not None, "Tensor queues are not set" + # Generate unique tensor ID (without request ID embedded) tensor_id = f"{id(self)}_{self._tensor_id_counter}" self._tensor_id_counter += 1 @@ -317,14 +320,16 @@ def _encode_ipc_queue_tensor(self, obj: torch.Tensor) -> TensorIpcHandle: return TensorIpcHandle( request_id=self._current_request_id, tensor_id=tensor_id, - shape=list(obj.shape), + shape=tuple(obj.shape), dtype=str(obj.dtype).removeprefix("torch."), device=str(obj.device), ) def _encode_tensor( self, obj: torch.Tensor - ) -> tuple[str, tuple[int, ...], int | memoryview] | dict[str, Any]: + ) -> ( + tuple[str, tuple[int, ...], int | memoryview] | dict[str, Any] | TensorIpcHandle + ): assert self.aux_buffers is not None # Check if we should use IPC for this tensor @@ -530,6 +535,12 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: # Convert dict to TensorIpcHandle and decode it handle = TensorIpcHandle(**arr) return self._decode_ipc_queue_tensor(handle) + # Check if this is a list/tuple with 5 elements (TensorIpcHandle) + # msgspec serializes NamedTuples as lists + if isinstance(arr, (list, tuple)) and len(arr) == 5: + # Convert list to TensorIpcHandle and decode it + handle = TensorIpcHandle(*arr) + return self._decode_ipc_queue_tensor(handle) # Standard tensor decoding dtype, shape, data = arr From 852e13fbb8c6ecb772b011cc70a033c9e066529d Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 4 Feb 2026 16:46:38 -0500 Subject: [PATCH 29/33] upd --- vllm/multimodal/media/image.py | 146 ++++++++------------------------- 1 file changed, 36 insertions(+), 110 deletions(-) diff --git a/vllm/multimodal/media/image.py b/vllm/multimodal/media/image.py index 309f8ea15447..9b427fc1f77f 100644 --- a/vllm/multimodal/media/image.py +++ b/vllm/multimodal/media/image.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading from io import BytesIO from pathlib import Path -from typing import TypeAlias import pybase64 import torch +from nvidia import nvimgcodec from PIL import Image from vllm.logger import init_logger @@ -16,16 +17,18 @@ logger = init_logger(__file__) -# Image output can be either PIL Image or Tensor (from nvJPEG) -ImageOutput: TypeAlias = Image.Image | torch.Tensor +# Thread-local storage for nvimgcodec decoder +_thread_local = threading.local() -class ImageMediaIO(MediaIO[ImageOutput]): - # Class-level counters for nvJPEG statistics - _nvjpeg_success_count: int = 0 - _nvjpeg_fallback_count: int = 0 - _nvjpeg_available: bool | None = None # Lazy initialization +def _get_decoder() -> nvimgcodec.Decoder: + """Get a per-thread nvimgcodec decoder instance.""" + if not hasattr(_thread_local, "decoder"): + _thread_local.decoder = nvimgcodec.Decoder() + return _thread_local.decoder + +class ImageMediaIO(MediaIO[Image.Image]): def __init__(self, image_mode: str = "RGB", **kwargs) -> None: super().__init__() @@ -56,87 +59,6 @@ def __init__(self, image_mode: str = "RGB", **kwargs) -> None: ) self.rgba_background_color = rgba_bg - # Check nvJPEG availability on first instantiation - if ImageMediaIO._nvjpeg_available is None: - ImageMediaIO._nvjpeg_available = self._check_nvjpeg_available() - - @staticmethod - def _check_nvjpeg_available() -> bool: - """Check if nvJPEG is available (CUDA + torchvision decode_jpeg).""" - try: - # torch.cuda.is_available() can raise RuntimeError if CUDA driver fails - 
if not torch.cuda.is_available(): - logger.debug("nvJPEG not available: CUDA not available") - return False - # Check if torchvision decode_jpeg is available - from torchvision.io import decode_jpeg # noqa: F401 - logger.info("nvJPEG available: using GPU-accelerated JPEG decoding") - return True - except ImportError: - logger.debug("nvJPEG not available: torchvision.io.decode_jpeg not found") - return False - except RuntimeError as e: - # CUDA driver initialization can fail with RuntimeError - logger.debug(f"nvJPEG not available: CUDA driver error - {e}") - return False - except Exception as e: - logger.debug(f"nvJPEG not available: {e}") - return False - - @staticmethod - def _is_jpeg(data: bytes) -> bool: - """Detect JPEG format from magic bytes.""" - return len(data) >= 3 and data[:3] == b'\xff\xd8\xff' - - def _decode_with_nvjpeg(self, data: bytes) -> torch.Tensor | None: - """ - Try to decode JPEG using nvJPEG (GPU-accelerated). - - Returns: - torch.Tensor in CHW format on CPU, or None on failure. - Note: Decoding happens on GPU for speed, then moved to CPU - for compatibility with vLLM's memory pinning. - """ - try: - from torchvision.io import decode_jpeg, ImageReadMode - - # Convert bytes to tensor - data_tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8) - - # Select mode based on image_mode - if self.image_mode == "RGB": - mode = ImageReadMode.RGB - elif self.image_mode == "L": - mode = ImageReadMode.GRAY - else: - mode = ImageReadMode.UNCHANGED - - # Decode on GPU using nvJPEG - tensor = decode_jpeg(data_tensor, mode=mode, device='cuda') - - # Move to CPU for compatibility with vLLM's memory pinning - tensor = tensor.cpu() - - # Update success counter and log periodically - ImageMediaIO._nvjpeg_success_count += 1 - self._log_stats_if_needed() - - return tensor # CHW tensor on CPU - - except Exception as e: - logger.debug(f"nvJPEG decode failed, falling back to PIL: {e}") - ImageMediaIO._nvjpeg_fallback_count += 1 - return None - - def _log_stats_if_needed(self) -> None: - """Log nvJPEG statistics periodically.""" - total = ImageMediaIO._nvjpeg_success_count + ImageMediaIO._nvjpeg_fallback_count - if total > 0 and total % 100 == 0: - logger.info( - f"nvJPEG decode stats: {ImageMediaIO._nvjpeg_success_count} successful, " - f"{ImageMediaIO._nvjpeg_fallback_count} fallback to PIL" - ) - def _convert_image_mode( self, image: Image.Image | MediaWithBytes[Image.Image] ) -> Image.Image: @@ -150,33 +72,37 @@ def _convert_image_mode( else: return convert_image_mode(image, self.image_mode) - def load_bytes(self, data: bytes) -> MediaWithBytes[ImageOutput]: - # Try nvJPEG for JPEG images when available - if ImageMediaIO._nvjpeg_available and self._is_jpeg(data): - tensor = self._decode_with_nvjpeg(data) - if tensor is not None: - return MediaWithBytes(tensor, data) - - # Fallback to PIL for non-JPEG or when nvJPEG fails + def load_bytes( + self, data: bytes + ) -> MediaWithBytes[Image.Image] | MediaWithBytes[torch.Tensor]: + # return self.load_pil_image(data) + return self.load_nvimgcodec_image(data) + + def load_base64( + self, media_type: str, data: str + ) -> MediaWithBytes[Image.Image] | MediaWithBytes[torch.Tensor]: + return self.load_bytes(pybase64.b64decode(data, validate=True)) + + def load_pil_image(self, data: bytes) -> MediaWithBytes[Image.Image]: image = Image.open(BytesIO(data)) return MediaWithBytes(self._convert_image_mode(image), data) - def load_base64(self, media_type: str, data: str) -> MediaWithBytes[ImageOutput]: - return 
self.load_bytes(pybase64.b64decode(data, validate=True)) + def load_nvimgcodec_image(self, data: bytes) -> MediaWithBytes[torch.Tensor]: + decoded = _get_decoder().decode(data) + + device = "cuda:0" + tensor = torch.as_tensor(decoded, device=device) + # HWC -> CHW + tensor = tensor.permute(2, 0, 1) + + return MediaWithBytes(tensor, data) - def load_file(self, filepath: Path) -> MediaWithBytes[ImageOutput]: + def load_file( + self, filepath: Path + ) -> MediaWithBytes[Image.Image] | MediaWithBytes[torch.Tensor]: with open(filepath, "rb") as f: data = f.read() - - # Try nvJPEG for JPEG images when available - if ImageMediaIO._nvjpeg_available and self._is_jpeg(data): - tensor = self._decode_with_nvjpeg(data) - if tensor is not None: - return MediaWithBytes(tensor, data) - - # Fallback to PIL for non-JPEG or when nvJPEG fails - image = Image.open(BytesIO(data)) - return MediaWithBytes(self._convert_image_mode(image), data) + return self.load_bytes(data) def encode_base64( self, From 7713eafffff5815669806838eb822d9e7f6a2e4c Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 4 Feb 2026 18:40:28 -0500 Subject: [PATCH 30/33] upd --- vllm/v1/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 666e5c5d313d..5c7b6bbd4d41 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -417,7 +417,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview: Returns: A memoryview of the tensor data as uint8. """ - return tensor.cpu().flatten().contiguous().view(torch.uint8).numpy().data + return tensor.flatten().contiguous().view(torch.uint8).numpy().data @dataclass From d5c1780b2d79f7ee65e1bf5bbd74dbad563768d7 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 4 Feb 2026 15:20:47 -0800 Subject: [PATCH 31/33] upd --- vllm/multimodal/media/image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/multimodal/media/image.py b/vllm/multimodal/media/image.py index 9b427fc1f77f..00c741ee663a 100644 --- a/vllm/multimodal/media/image.py +++ b/vllm/multimodal/media/image.py @@ -88,7 +88,8 @@ def load_pil_image(self, data: bytes) -> MediaWithBytes[Image.Image]: return MediaWithBytes(self._convert_image_mode(image), data) def load_nvimgcodec_image(self, data: bytes) -> MediaWithBytes[torch.Tensor]: - decoded = _get_decoder().decode(data) + code_stream = nvimgcodec.CodeStream(data) + decoded = _get_decoder().decode(code_stream) device = "cuda:0" tensor = torch.as_tensor(decoded, device=device) From 517e7e4b048f68975d4f9103f6f0516061bb3faa Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 4 Feb 2026 17:07:55 -0800 Subject: [PATCH 32/33] upd --- vllm/multimodal/inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 7b12158763c3..08d17c7df5b5 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -477,7 +477,7 @@ def reduce_data( pin_memory = False batch = [elem.data for elem in elems] - out = self._reduce_data(batch, pin_memory=pin_memory) + out = self._reduce_data(batch, pin_memory=pin_memory and device.type != 'cuda') return _nested_tensors_h2d(out, device=device) From c5e2df16211ce81ac76156637b9e30b1dc6c9e95 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Sun, 8 Feb 2026 21:48:54 -0500 Subject: [PATCH 33/33] fix --- vllm/multimodal/inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 08d17c7df5b5..a76caa1044f3 100644 --- a/vllm/multimodal/inputs.py +++ 
b/vllm/multimodal/inputs.py @@ -477,7 +477,7 @@ def reduce_data( pin_memory = False batch = [elem.data for elem in elems] - out = self._reduce_data(batch, pin_memory=pin_memory and device.type != 'cuda') + out = self._reduce_data(batch, pin_memory=pin_memory and device.type != "cuda") return _nested_tensors_h2d(out, device=device)
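
Taken together, the final commits make the image path GPU-resident end to end: nvImageCodec decodes the encoded bytes on the device, `tensor_data` no longer forces a `.cpu()` copy, and `reduce_data` skips `pin_memory` when the target device is CUDA (pinning only applies to CPU tensors, and the decoded images already live on the GPU). A condensed, hedged sketch of the decode path that mirrors the patched `load_nvimgcodec_image`; the device string and function name are illustrative:

    import torch
    from nvidia import nvimgcodec

    _decoder = nvimgcodec.Decoder()

    def decode_to_chw(data: bytes, device: str = "cuda:0") -> torch.Tensor:
        # Mirrors the patched load_nvimgcodec_image; not a replacement for it.
        code_stream = nvimgcodec.CodeStream(data)         # wrap the encoded bytes
        decoded = _decoder.decode(code_stream)            # GPU-side decode
        tensor = torch.as_tensor(decoded, device=device)  # hand the GPU buffer to torch
        return tensor.permute(2, 0, 1)                    # HWC -> CHW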