diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 380bbc30e3d1..fda957258600 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -12,4 +12,6 @@ torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytor # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.5.3 # FA4 -flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@2580b5a4882562640f3cfbffd2bb8d2de9268f9f#subdirectory=flash_attn/cute \ No newline at end of file +flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@2580b5a4882562640f3cfbffd2bb8d2de9268f9f#subdirectory=flash_attn/cute +# nvimgcodec +nvidia-nvimgcodec-cu13==0.7.0.11 \ No newline at end of file diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index dbbbfce97d28..bfa14193e5ab 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -283,3 +283,86 @@ def test_custom_class_serialization_disallowed_without_pickle(): with pytest.raises(TypeError): # Attempt to encode the custom class encoder.encode(obj) + + +@dataclass +class RequestWithTensor: + """Mock request with non-multimodal tensor field like EngineCoreRequest.""" + + prompt_embeds: torch.Tensor | None + data: str + + +def test_non_multimodal_tensor_with_ipc(): + """Test that non-multimodal tensor fields work correctly with IPC enabled. + + This reproduces the bug where fields like prompt_embeds: torch.Tensor | None + would fail to decode when IPC is enabled because _decode_tensor expected a tuple + but received a TensorIpcHandle dict. + """ + import torch.multiprocessing as torch_mp + + # Create tensor queues for IPC + tensor_queues = [torch_mp.Queue()] + + # Create encoder with IPC enabled + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" + ) + encoder.set_target_engine(0) + encoder.set_request_context("test_request_123") + + # Create decoder with IPC queue + decoder = MsgpackDecoder(RequestWithTensor, tensor_queue=tensor_queues[0]) + + # Create a request with a non-multimodal tensor + original_tensor = torch.randn(5, 10, dtype=torch.float32) + request = RequestWithTensor(prompt_embeds=original_tensor, data="test_data") + + # Encode the request - this should send the tensor via IPC + encoded = encoder.encode(request) + + # Verify encoding succeeded + assert len(encoded) > 0 + + # Decode the request - this should retrieve the tensor from IPC queue + # Previously this would fail with: TypeError: cannot unpack non-iterable dict object + decoded = decoder.decode(encoded) + + # Verify the decoded request matches the original + assert isinstance(decoded, RequestWithTensor) + assert decoded.data == "test_data" + assert decoded.prompt_embeds is not None + assert torch.allclose(decoded.prompt_embeds, original_tensor), ( + "Decoded tensor does not match the original tensor." 
+ ) + + +def test_non_multimodal_tensor_with_ipc_none_value(): + """Test that None values for tensor fields work correctly with IPC enabled.""" + import torch.multiprocessing as torch_mp + + # Create tensor queues for IPC + tensor_queues = [torch_mp.Queue()] + + # Create encoder with IPC enabled + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" + ) + encoder.set_target_engine(0) + encoder.set_request_context("test_request_456") + + # Create decoder with IPC queue + decoder = MsgpackDecoder(RequestWithTensor, tensor_queue=tensor_queues[0]) + + # Create a request with None for the tensor field + request = RequestWithTensor(prompt_embeds=None, data="test_data_with_none") + + # Encode and decode the request + encoded = encoder.encode(request) + decoded = decoder.decode(encoded) + + # Verify the decoded request matches the original + assert isinstance(decoded, RequestWithTensor) + assert decoded.data == "test_data_with_none" + assert decoded.prompt_embeds is None diff --git a/tests/v1/test_tensor_ipc_queue.py b/tests/v1/test_tensor_ipc_queue.py new file mode 100644 index 000000000000..f6ce283937d0 --- /dev/null +++ b/tests/v1/test_tensor_ipc_queue.py @@ -0,0 +1,787 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Tests for tensor IPC queue functionality.""" + +import contextlib +import multiprocessing as mp +from multiprocessing.synchronize import Barrier as BarrierType +from multiprocessing.synchronize import Event as EventType +from typing import Any + +import pytest +import torch +import torch.multiprocessing as torch_mp + +from vllm.v1.serial_utils import ( + MsgpackDecoder, + MsgpackEncoder, + TensorIpcData, + TensorIpcHandle, +) + + +@pytest.fixture(scope="module", autouse=True) +def setup_multiprocessing(): + """Set multiprocessing start method to 'spawn' for compatibility.""" + with contextlib.suppress(RuntimeError): + # Already set, which is fine + torch_mp.set_start_method("spawn", force=True) + yield + + +def encoder_process( + tensor_queues: list[torch_mp.Queue], + result_queue: mp.Queue, + target_engine: int, + tensor_data: dict[str, Any], + ready_event: EventType, +): + """Process that encodes and sends CUDA tensors via queue.""" + try: + # Create encoder with tensor queues + encoder = MsgpackEncoder(tensor_queues=tensor_queues) + encoder.set_target_engine(target_engine) + + # Create a CUDA tensor if available + if torch.cuda.is_available(): + device = "cuda:0" + tensor = torch.randn( + *tensor_data["shape"], dtype=tensor_data["dtype"], device=device + ) + else: + # Fall back to CPU for testing + device = "cpu" + tensor = torch.randn(*tensor_data["shape"], dtype=tensor_data["dtype"]) + + # Encode the tensor + encoded = encoder.encode({"test_tensor": tensor}) + + # Signal that encoding is complete before sending result + ready_event.set() + + result_queue.put( + { + "success": True, + "encoded_length": len(encoded), + "device": str(device), + "tensor_shape": tuple(tensor.shape), + } + ) + except Exception as e: + import traceback + + ready_event.set() # Signal even on failure + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) + + +def decoder_process( + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + expected_shape: tuple, + encoder_ready: EventType, +): + """Process that decodes and receives CUDA tensors from queue.""" + try: + # Wait for encoder to finish sending + if not encoder_ready.wait(timeout=10.0): + raise 
TimeoutError("Encoder did not signal ready") + + # Try to get tensor from queue directly for testing + ipc_data = tensor_queue.get(timeout=5.0) + + result_queue.put( + { + "success": True, + "tensor_id": ipc_data.tensor_id, + "tensor_shape": tuple(ipc_data.tensor.shape), + "device": str(ipc_data.tensor.device), + "matches_expected": tuple(ipc_data.tensor.shape) == expected_shape, + } + ) + except Exception as e: + import traceback + + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_tensor_queue_basic(): + """Test basic CUDA tensor sharing via queue.""" + # Set up queues and synchronization + num_engines = 2 + tensor_queues = [torch_mp.Queue() for _ in range(num_engines)] + result_queue: mp.Queue = mp.Queue() + encoder_ready = mp.Event() + + target_engine = 0 + tensor_shape = (4, 8, 16) + tensor_dtype = torch.float32 + + # Start encoder process + encoder_proc = mp.Process( + target=encoder_process, + args=( + tensor_queues, + result_queue, + target_engine, + {"shape": tensor_shape, "dtype": tensor_dtype}, + encoder_ready, + ), + ) + encoder_proc.start() + + # Start decoder process + decoder_proc = mp.Process( + target=decoder_process, + args=(tensor_queues[target_engine], result_queue, tensor_shape, encoder_ready), + ) + decoder_proc.start() + + # Wait for processes and collect results + encoder_result = result_queue.get(timeout=10.0) + decoder_result = result_queue.get(timeout=10.0) + + encoder_proc.join(timeout=5.0) + decoder_proc.join(timeout=5.0) + + # Verify results + assert encoder_result["success"], ( + f"Encoder failed: {encoder_result.get('error')}\n" + f"{encoder_result.get('traceback', '')}" + ) + assert decoder_result["success"], ( + f"Decoder failed: {decoder_result.get('error')}\n" + f"{decoder_result.get('traceback', '')}" + ) + assert decoder_result["matches_expected"], "Tensor shape mismatch" + assert "cuda" in decoder_result["device"], "Tensor not on CUDA device" + + +def test_cpu_tensor_fallback(): + """Test that CPU tensors use standard serialization path.""" + encoder = MsgpackEncoder(tensor_queues=None) + + # Create a CPU tensor + tensor = torch.randn(3, 4, dtype=torch.float32) + + # Encode the tensor (should use standard path, not queue) + encoded = encoder.encode({"test_tensor": tensor}) + + # Verify encoding succeeded + assert len(encoded) > 0 + assert isinstance(encoded, (list, tuple)) + + # Basic check: no queue should be used, so tensor goes through standard path + # This is mainly to ensure no exceptions are raised + + +def test_encoder_without_target_engine(): + """Test that encoder handles missing target engine gracefully.""" + tensor_queues = [torch_mp.Queue()] + encoder = MsgpackEncoder(tensor_queues=tensor_queues) + + # Don't set target engine + if torch.cuda.is_available(): + tensor = torch.randn(2, 3, device="cuda:0") + else: + tensor = torch.randn(2, 3) + + # Should fall back to standard serialization + encoded = encoder.encode({"test_tensor": tensor}) + assert len(encoded) > 0 + + +def test_decoder_buffer_management(): + """Test decoder's tensor buffer management when draining queue.""" + tensor_queue = torch_mp.Queue() + + # Put multiple tensors in queue using TensorIpcData + tensors = { + "tensor_1": torch.randn(2, 3), + "tensor_2": torch.randn(4, 5), + "tensor_3": torch.randn(6, 7), + } + + for tensor_id, tensor in tensors.items(): + ipc_data = TensorIpcData(request_id=None, tensor_id=tensor_id, 
tensor=tensor) + tensor_queue.put(ipc_data) + + # Create decoder + decoder = MsgpackDecoder(tensor_queue=tensor_queue) + + # Request tensor_3 (should buffer tensor_1 and tensor_2) + handle = TensorIpcHandle( + request_id=None, + tensor_id="tensor_3", + shape=[6, 7], + dtype="float32", + device="cpu", + ) + + result = decoder._decode_ipc_queue_tensor(handle) + assert result.shape == (6, 7) + + # Verify buffer has tensor_1 and tensor_2 using tuple keys + assert (None, "tensor_1") in decoder._tensor_buffer + assert (None, "tensor_2") in decoder._tensor_buffer + + # Request buffered tensor + handle2 = TensorIpcHandle( + request_id=None, + tensor_id="tensor_1", + shape=[2, 3], + dtype="float32", + device="cpu", + ) + + result2 = decoder._decode_ipc_queue_tensor(handle2) + assert result2.shape == (2, 3) + # tensor_1 should be removed from buffer + assert (None, "tensor_1") not in decoder._tensor_buffer + + +def api_server_worker( + server_id: int, + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + barrier: BarrierType, + retrieval_done: EventType, +): + """Worker simulating an API server sending tensors.""" + try: + # Each server sends a unique tensor + tensor = torch.ones(server_id + 1, server_id + 2) * server_id + tensor_id = f"server_{server_id}_tensor" + + # Wait for all servers to be ready + barrier.wait() + + # Send tensor using TensorIpcData + ipc_data = TensorIpcData(request_id=None, tensor_id=tensor_id, tensor=tensor) + tensor_queue.put(ipc_data) + + result_queue.put({"server_id": server_id, "success": True}) + + # Keep process alive until main process has retrieved all tensors + # This prevents shared memory handles from being invalidated + retrieval_done.wait(timeout=30.0) + except Exception as e: + import traceback + + result_queue.put( + { + "server_id": server_id, + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + } + ) + + +def test_multiple_api_servers_to_engine(): + """Test multiple API servers sending to one engine core via multiprocessing.""" + num_api_servers = 3 + tensor_queue = torch_mp.Queue() + result_queue: mp.Queue = mp.Queue() + barrier = mp.Barrier(num_api_servers) + retrieval_done = mp.Event() + + # Start multiple API server processes + processes = [] + for server_id in range(num_api_servers): + proc = mp.Process( + target=api_server_worker, + args=(server_id, tensor_queue, result_queue, barrier, retrieval_done), + ) + proc.start() + processes.append(proc) + + # Collect results from all servers + results = [] + for _ in range(num_api_servers): + result = result_queue.get(timeout=10.0) + results.append(result) + + # Verify all servers succeeded + for result in results: + assert result["success"], ( + f"Server {result['server_id']} failed: {result.get('error')}" + ) + + # Verify all tensors are in queue + received_tensors = [] + for _ in range(num_api_servers): + ipc_data = tensor_queue.get(timeout=1.0) + received_tensors.append((ipc_data.tensor_id, ipc_data.tensor)) + + assert len(received_tensors) == num_api_servers + + # Verify tensor content (order may vary with multiprocessing) + tensor_by_id = {tid: t for tid, t in received_tensors} + for server_id in range(num_api_servers): + expected_id = f"server_{server_id}_tensor" + assert expected_id in tensor_by_id, f"Missing tensor from server {server_id}" + expected_tensor = torch.ones(server_id + 1, server_id + 2) * server_id + assert torch.allclose(tensor_by_id[expected_id], expected_tensor) + + # Signal workers that retrieval is complete + retrieval_done.set() + + # Wait for all 
processes to complete + for proc in processes: + proc.join(timeout=5.0) + + +def mixed_tensor_encoder_process( + tensor_queues: list[torch_mp.Queue], + result_queue: mp.Queue, + ready_event: EventType, + retrieval_done: EventType, +): + """Process that encodes mixed CPU/CUDA tensors.""" + try: + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" + ) + encoder.set_target_engine(0) + + # Create only CUDA tensor for IPC (CPU will be serialized) + # But actually, let's just send CUDA tensor directly + cuda_tensor = torch.randn(4, 5, device="cuda:0") + + # Manually send via IPC to test the mechanism + tensor_id = "test_cuda_tensor" + cuda_tensor_shared = cuda_tensor.share_memory_() + from vllm.v1.serial_utils import TensorIpcData + + ipc_data = TensorIpcData( + request_id=None, tensor_id=tensor_id, tensor=cuda_tensor_shared + ) + tensor_queues[0].put(ipc_data, timeout=10.0) + + ready_event.set() + + result_queue.put( + { + "success": True, + "sent_cuda": True, + } + ) + + # Keep process alive until decoder has retrieved the tensor + retrieval_done.wait(timeout=30.0) + except Exception as e: + import traceback + + ready_event.set() + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) + + +def mixed_tensor_decoder_process( + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + encoder_ready: EventType, + retrieval_done: EventType, +): + """Process that retrieves mixed tensors from queue.""" + try: + # Wait for encoder to finish + if not encoder_ready.wait(timeout=10.0): + raise TimeoutError("Encoder did not signal ready") + + # Try to get CUDA tensor from queue + ipc_data = tensor_queue.get(timeout=5.0) + + result_queue.put( + { + "success": True, + "is_cuda": ipc_data.tensor.is_cuda, + "shape": tuple(ipc_data.tensor.shape), + } + ) + + # Signal that retrieval is complete + retrieval_done.set() + except Exception as e: + import traceback + + retrieval_done.set() # Signal even on failure + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_mixed_cpu_cuda_tensors(): + """Test encoding with mixed CPU and CUDA tensors using multiprocessing.""" + tensor_queues = [torch_mp.Queue()] + result_queue: mp.Queue = mp.Queue() + encoder_ready = mp.Event() + retrieval_done = mp.Event() + + # Start encoder process + encoder_proc = mp.Process( + target=mixed_tensor_encoder_process, + args=(tensor_queues, result_queue, encoder_ready, retrieval_done), + ) + encoder_proc.start() + + # Start decoder process + decoder_proc = mp.Process( + target=mixed_tensor_decoder_process, + args=(tensor_queues[0], result_queue, encoder_ready, retrieval_done), + ) + decoder_proc.start() + + # Get results + encoder_result = result_queue.get(timeout=10.0) + decoder_result = result_queue.get(timeout=10.0) + + encoder_proc.join(timeout=5.0) + decoder_proc.join(timeout=5.0) + + # Verify encoder succeeded + assert encoder_result["success"], ( + f"Encoder failed: {encoder_result.get('error')}\n" + f"{encoder_result.get('traceback', '')}" + ) + + # Verify decoder succeeded and got CUDA tensor + assert decoder_result["success"], ( + f"Decoder failed: {decoder_result.get('error')}\n" + f"{decoder_result.get('traceback', '')}" + ) + assert decoder_result["is_cuda"], "Retrieved tensor is not on CUDA" + assert decoder_result["shape"] == (4, 5), ( + f"Unexpected shape: {decoder_result['shape']}" + ) + + +def 
cpu_tensor_ipc_encoder_process( + tensor_queues: list[torch_mp.Queue], + result_queue: mp.Queue, + target_engine: int, + tensor_shape: tuple, + ready_event: EventType, + retrieval_done: EventType, +): + """Process that encodes and sends CPU tensors via IPC queue.""" + try: + # Create encoder with IPC enabled for all tensors + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" + ) + encoder.set_target_engine(target_engine) + + # Create a CPU tensor + tensor = torch.randn(*tensor_shape, dtype=torch.float32) + + # Encode the tensor (should use IPC queue, not standard serialization) + encoded = encoder.encode({"test_tensor": tensor}) + + # Signal that encoding is complete + ready_event.set() + + result_queue.put( + { + "success": True, + "encoded_length": len(encoded), + "device": str(tensor.device), + "tensor_shape": tuple(tensor.shape), + } + ) + + # Keep process alive until decoder has retrieved the tensor + # This is necessary for CPU tensor shared memory to remain valid + retrieval_done.wait(timeout=30.0) + except Exception as e: + import traceback + + ready_event.set() + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) + + +def cpu_tensor_ipc_decoder_process( + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + expected_shape: tuple, + encoder_ready: EventType, + retrieval_done: EventType, +): + """Process that decodes and receives CPU tensors from IPC queue.""" + try: + # Wait for encoder to finish sending + if not encoder_ready.wait(timeout=10.0): + raise TimeoutError("Encoder did not signal ready") + + # Get tensor from queue + ipc_data = tensor_queue.get(timeout=5.0) + + result_queue.put( + { + "success": True, + "tensor_id": ipc_data.tensor_id, + "tensor_shape": tuple(ipc_data.tensor.shape), + "device": str(ipc_data.tensor.device), + "matches_expected": tuple(ipc_data.tensor.shape) == expected_shape, + "is_cpu": ipc_data.tensor.device.type == "cpu", + } + ) + + # Signal that retrieval is complete + retrieval_done.set() + except Exception as e: + import traceback + + retrieval_done.set() # Signal even on failure + result_queue.put( + {"success": False, "error": str(e), "traceback": traceback.format_exc()} + ) + + +def test_cpu_tensor_ipc(): + """Test CPU tensor sharing via IPC queue when multimodal_tensor_ipc is enabled.""" + # Set up queues and synchronization + num_engines = 2 + tensor_queues = [torch_mp.Queue() for _ in range(num_engines)] + result_queue: mp.Queue = mp.Queue() + encoder_ready = mp.Event() + retrieval_done = mp.Event() + + target_engine = 0 + tensor_shape = (3, 5, 7) + + # Start encoder process + encoder_proc = mp.Process( + target=cpu_tensor_ipc_encoder_process, + args=( + tensor_queues, + result_queue, + target_engine, + tensor_shape, + encoder_ready, + retrieval_done, + ), + ) + encoder_proc.start() + + # Start decoder process + decoder_proc = mp.Process( + target=cpu_tensor_ipc_decoder_process, + args=( + tensor_queues[target_engine], + result_queue, + tensor_shape, + encoder_ready, + retrieval_done, + ), + ) + decoder_proc.start() + + # Wait for processes and collect results + encoder_result = result_queue.get(timeout=10.0) + decoder_result = result_queue.get(timeout=10.0) + + encoder_proc.join(timeout=5.0) + decoder_proc.join(timeout=5.0) + + # Verify results + assert encoder_result["success"], ( + f"Encoder failed: {encoder_result.get('error')}\n" + f"{encoder_result.get('traceback', '')}" + ) + assert decoder_result["success"], ( + f"Decoder failed: 
{decoder_result.get('error')}\n" + f"{decoder_result.get('traceback', '')}" + ) + assert decoder_result["matches_expected"], "Tensor shape mismatch" + assert decoder_result["is_cpu"], "Tensor not on CPU device" + + +def test_ipc_disabled_mode(): + """Test that IPC is disabled when multimodal_tensor_ipc="direct_rpc".""" + tensor_queues = [torch_mp.Queue()] + + # Create encoder with IPC disabled + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="direct_rpc" + ) + encoder.set_target_engine(0) + + # Create a CPU tensor + cpu_tensor = torch.randn(2, 3, dtype=torch.float32) + + # Encode the tensor (should use standard serialization, not IPC) + encoded = encoder.encode({"test_tensor": cpu_tensor}) + + # Verify encoding succeeded + assert len(encoded) > 0 + assert isinstance(encoded, (list, tuple)) + + # Verify queue is empty (no IPC was used) + assert tensor_queues[0].empty(), "Tensor queue should be empty when IPC is disabled" + + # If CUDA is available, test with CUDA tensor too + if torch.cuda.is_available(): + cuda_tensor = torch.randn(4, 5, device="cuda:0") + encoded_cuda = encoder.encode({"cuda_tensor": cuda_tensor}) + assert len(encoded_cuda) > 0 + assert tensor_queues[0].empty(), ( + "Tensor queue should be empty for CUDA tensor when IPC is disabled" + ) + + +def test_mixed_cpu_cuda_with_ipc_enabled(): + """Test that encoder is configured correctly for IPC with all tensor types.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + tensor_queues = [torch_mp.Queue()] + + # Create encoder with IPC enabled for all tensors + encoder = MsgpackEncoder( + tensor_queues=tensor_queues, multimodal_tensor_ipc="torch_shm" + ) + encoder.set_target_engine(0) + + # Verify encoder configuration + assert encoder.multimodal_tensor_ipc == "torch_shm", ( + "Torch queue-based IPC should be enabled" + ) + assert encoder.tensor_queues is not None, "Tensor queues should be set" + assert encoder.target_engine_index == 0, "Target engine should be set" + + # Note: Actual IPC transfer only works across processes + # (tested in test_cpu_tensor_ipc) + # This test just verifies the configuration is correct + + +def test_tensor_cleanup_on_abort(): + """Test that orphaned tensors are cleaned up when requests are aborted.""" + # Create a tensor queue (not actually used in this simplified test) + tensor_queue = torch_mp.Queue() + + # Create decoder + decoder = MsgpackDecoder(dict, tensor_queue=tensor_queue) + + # Simulate tensors in the buffer for multiple requests + request_ids = ["req1", "req2", "req3"] + + for request_id in request_ids: + # Simulate 2 tensors per request using tuple keys + for i in range(2): + tensor_id = f"encoder_{i}" + tensor_key = (request_id, tensor_id) + tensor = torch.randn(10, 10) + + # Manually add to buffer and tracking (simulating decode behavior) + decoder._tensor_buffer[tensor_key] = tensor + + if request_id not in decoder._request_to_tensors: + decoder._request_to_tensors[request_id] = [] + decoder._request_to_tensors[request_id].append(tensor_key) + + # Verify tensors are in the buffer + initial_buffer_size = len(decoder._tensor_buffer) + assert initial_buffer_size == 6, "Buffer should contain 6 tensors (2 per request)" + + # Verify request tracking + assert len(decoder._request_to_tensors) == 3, "Should track 3 requests" + assert len(decoder._request_to_tensors["req1"]) == 2, "req1 should have 2 tensors" + + # Cleanup tensors for req1 + removed_count_1 = decoder.cleanup_request_tensors("req1") + assert removed_count_1 == 2, "Should 
have removed 2 tensors for req1" + assert len(decoder._tensor_buffer) == 4, "Buffer should have 4 tensors left" + assert "req1" not in decoder._request_to_tensors, ( + "req1 should be removed from tracking" + ) + + # Cleanup tensors for req2 + removed_count_2 = decoder.cleanup_request_tensors("req2") + assert removed_count_2 == 2, "Should have removed 2 tensors for req2" + assert len(decoder._tensor_buffer) == 2, "Buffer should have 2 tensors left" + + # Cleanup req3 + removed_count_3 = decoder.cleanup_request_tensors("req3") + assert removed_count_3 == 2, "Should have removed 2 tensors for req3" + + # Verify all tensors are cleaned up + assert len(decoder._tensor_buffer) == 0, "Buffer should be empty" + assert len(decoder._request_to_tensors) == 0, "Request tracking should be empty" + + # Cleanup for non-existent request should return 0 + removed_count_4 = decoder.cleanup_request_tensors("nonexistent") + assert removed_count_4 == 0, "Should return 0 for non-existent request" + + +def test_tensor_cleanup_after_decode(): + """Test that tensors are removed from tracking after successful decode.""" + # Create a tensor queue + tensor_queue = torch_mp.Queue() + + # Create and encode a tensor + tensor = torch.randn(5, 5) + # Move to shared memory for IPC + if not tensor.is_shared(): + tensor.share_memory_() + + # Manually create a TensorIpcData and put it in the queue + request_id = "test_req" + tensor_id = "encoder_0" + ipc_data = TensorIpcData(request_id=request_id, tensor_id=tensor_id, tensor=tensor) + tensor_queue.put(ipc_data) + + # Create decoder + decoder = MsgpackDecoder(dict, tensor_queue=tensor_queue) + + # Create a TensorIpcHandle to decode + handle = TensorIpcHandle( + request_id=request_id, + tensor_id=tensor_id, + shape=list(tensor.shape), + dtype=str(tensor.dtype).removeprefix("torch."), + device=str(tensor.device), + ) + + # Decode the tensor - this should retrieve it from the queue + decoded_tensor = decoder._decode_ipc_queue_tensor(handle) + + # Verify the tensor was decoded + assert decoded_tensor.shape == tensor.shape, "Decoded tensor should match shape" + + # Verify the tensor was removed from buffer after decode + tensor_key = (request_id, tensor_id) + assert tensor_key not in decoder._tensor_buffer, ( + "Tensor should be removed from buffer" + ) + + # Verify the request tracking was cleaned up + assert request_id not in decoder._request_to_tensors, ( + "Request tracking should be cleaned up" + ) + + +def test_request_context_in_encoder(): + """Test that encoder properly sets and clears request context.""" + encoder = MsgpackEncoder() + + # Initially no request context + assert encoder._current_request_id is None + + # Set request context + encoder.set_request_context("req123") + assert encoder._current_request_id == "req123" + + # Clear request context + encoder.set_request_context(None) + assert encoder._current_request_id is None diff --git a/vllm/config/model.py b/vllm/config/model.py index df25e900c354..f76fb571e237 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -311,6 +311,7 @@ class ModelConfig: interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None + multimodal_tensor_ipc: InitVar[Literal["direct_rpc", "torch_shm"]] = "direct_rpc" def compute_hash(self) -> str: """ @@ -425,6 +426,7 @@ def __post_init__( interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, + multimodal_tensor_ipc: Literal["direct_rpc", 
"torch_shm"], ) -> None: # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name( @@ -588,6 +590,7 @@ def __post_init__( interleave_mm_strings=interleave_mm_strings, skip_mm_profiling=skip_mm_profiling, video_pruning_rate=video_pruning_rate, + multimodal_tensor_ipc=multimodal_tensor_ipc, ) mm_config_kwargs = { diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index ecb346af8f3c..4bc6d36789d9 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -140,6 +140,11 @@ class MultiModalConfig: Value sits in range [0;1) and determines fraction of media tokens from each video to be pruned. """ + multimodal_tensor_ipc: Literal["direct_rpc", "torch_shm"] = "direct_rpc" + """IPC (inter-process communication) method for multimodal tensors. + - "direct_rpc": Use msgspec serialization via RPC + - "torch_shm": Use torch.multiprocessing shared memory for zero-copy IPC + Defaults to "direct_rpc". """ @field_validator("limit_per_prompt", mode="before") @classmethod diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cb82be6b6b6f..6921ddd9c1b4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -474,6 +474,9 @@ class EngineArgs: io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling video_pruning_rate: float = MultiModalConfig.video_pruning_rate + multimodal_tensor_ipc: Literal["direct_rpc", "torch_shm"] = ( + MultiModalConfig.multimodal_tensor_ipc + ) # LoRA fields enable_lora: bool = False max_loras: int = LoRAConfig.max_loras @@ -990,6 +993,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] ) + multimodal_group.add_argument( + "--multimodal-tensor-ipc", + **multimodal_kwargs["multimodal_tensor_ipc"], + ) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -1267,6 +1274,7 @@ def create_model_config(self) -> ModelConfig: override_attention_dtype=self.override_attention_dtype, logits_processors=self.logits_processors, video_pruning_rate=self.video_pruning_rate, + multimodal_tensor_ipc=self.multimodal_tensor_ipc, io_processor_plugin=self.io_processor_plugin, ) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index f06a9391321f..96e59b08eb95 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -209,6 +209,7 @@ def run_multi_api_server(args: argparse.Namespace): stats_update_address=coordinator.get_stats_publish_address() if coordinator else None, + tensor_queues=addresses.tensor_queues, ) # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the diff --git a/vllm/envs.py b/vllm/envs.py index f82dae108f6a..4d12ea327833 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1758,6 +1758,7 @@ def compile_factors() -> dict[str, object]: "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "VLLM_VIDEO_LOADER_BACKEND", "VLLM_MEDIA_CONNECTOR", + "VLLM_MULTIMODAL_TENSOR_IPC", "VLLM_ASSETS_CACHE", "VLLM_ASSETS_CACHE_MODEL_CLEAN", "VLLM_WORKER_MULTIPROC_METHOD", diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 7b12158763c3..a76caa1044f3 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -477,7 +477,7 @@ def reduce_data( pin_memory = False batch = [elem.data for elem in elems] - out = self._reduce_data(batch, pin_memory=pin_memory) + out = self._reduce_data(batch, pin_memory=pin_memory and device.type != "cuda") return 
_nested_tensors_h2d(out, device=device) diff --git a/vllm/multimodal/media/image.py b/vllm/multimodal/media/image.py index 977a67007363..00c741ee663a 100644 --- a/vllm/multimodal/media/image.py +++ b/vllm/multimodal/media/image.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading from io import BytesIO from pathlib import Path import pybase64 import torch +from nvidia import nvimgcodec from PIL import Image from vllm.logger import init_logger @@ -15,6 +17,16 @@ logger = init_logger(__file__) +# Thread-local storage for nvimgcodec decoder +_thread_local = threading.local() + + +def _get_decoder() -> nvimgcodec.Decoder: + """Get a per-thread nvimgcodec decoder instance.""" + if not hasattr(_thread_local, "decoder"): + _thread_local.decoder = nvimgcodec.Decoder() + return _thread_local.decoder + class ImageMediaIO(MediaIO[Image.Image]): def __init__(self, image_mode: str = "RGB", **kwargs) -> None: @@ -60,18 +72,38 @@ def _convert_image_mode( else: return convert_image_mode(image, self.image_mode) - def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]: + def load_bytes( + self, data: bytes + ) -> MediaWithBytes[Image.Image] | MediaWithBytes[torch.Tensor]: + # return self.load_pil_image(data) + return self.load_nvimgcodec_image(data) + + def load_base64( + self, media_type: str, data: str + ) -> MediaWithBytes[Image.Image] | MediaWithBytes[torch.Tensor]: + return self.load_bytes(pybase64.b64decode(data, validate=True)) + + def load_pil_image(self, data: bytes) -> MediaWithBytes[Image.Image]: image = Image.open(BytesIO(data)) return MediaWithBytes(self._convert_image_mode(image), data) - def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]: - return self.load_bytes(pybase64.b64decode(data, validate=True)) + def load_nvimgcodec_image(self, data: bytes) -> MediaWithBytes[torch.Tensor]: + code_stream = nvimgcodec.CodeStream(data) + decoded = _get_decoder().decode(code_stream) + + device = "cuda:0" + tensor = torch.as_tensor(decoded, device=device) + # HWC -> CHW + tensor = tensor.permute(2, 0, 1) + + return MediaWithBytes(tensor, data) - def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]: + def load_file( + self, filepath: Path + ) -> MediaWithBytes[Image.Image] | MediaWithBytes[torch.Tensor]: with open(filepath, "rb") as f: data = f.read() - image = Image.open(BytesIO(data)) - return MediaWithBytes(self._convert_image_mode(image), data) + return self.load_bytes(data) def encode_base64( self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 141e5a459c5b..ac8b637b6a9a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -395,8 +395,19 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: scheduler_output, model_output ) + # Cleanup tensors for finished requests + self._cleanup_finished_request_tensors(scheduler_output.finished_req_ids) + return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 + def _cleanup_finished_request_tensors(self, finished_req_ids: set[str]) -> None: + """Cleanup any orphaned tensors for finished requests. + + This is a no-op in the base class but can be overridden in subclasses + to perform actual cleanup (e.g., for IPC tensor queues). 
+ """ + pass + def post_step(self, model_executed: bool) -> None: # When using async scheduling we can't get draft token ids in advance, # so we update draft token ids in the worker process and don't @@ -497,6 +508,9 @@ def step_with_batch_queue( scheduler_output, model_output ) + # Cleanup tensors for finished requests + self._cleanup_finished_request_tensors(scheduler_output.finished_req_ids) + # NOTE(nick): We can either handle the deferred tasks here or save # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. @@ -648,6 +662,7 @@ def __init__( client_handshake_address: str | None = None, *, engine_index: int = 0, + tensor_queues: list[Any] | None = None, ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]() @@ -659,6 +674,9 @@ def __init__( identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False + # Decoder for cleanup (set by process_input_sockets thread) + self.tensor_decoder: MsgpackDecoder | None = None + with self._perform_handshakes( handshake_address, identity, @@ -668,6 +686,16 @@ def __init__( ) as addresses: self.client_count = len(addresses.outputs) + # Get this engine's tensor IPC queue for receiving multimodal tensors + # Queues are passed directly via constructor since they can't be serialized + self.tensor_queue = None + if tensor_queues and addresses.tensor_queue_index is not None: + self.tensor_queue = tensor_queues[addresses.tensor_queue_index] + logger.info( + "Engine %d using tensor IPC queue for multimodal tensor sharing", + self.engine_index, + ) + # Set up data parallel environment. self.has_coordinator = addresses.coordinator_output is not None self.frontend_stats_publish_address = ( @@ -875,10 +903,20 @@ def startup_handshake( for key, value in init_message.parallel_config.items(): setattr(parallel_config, key, value) - return init_message.addresses + # Store tensor_queue_index for engine to access + addresses = init_message.addresses + addresses.tensor_queue_index = init_message.tensor_queue_index + + return addresses @staticmethod - def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): + def run_engine_core( + *args, + dp_rank: int = 0, + local_dp_rank: int = 0, + tensor_queues: list[Any] | None = None, + **kwargs, + ): """Launch EngineCore busy loop in background process.""" # Signal handler used for graceful termination. @@ -915,7 +953,7 @@ def signal_handler(signum, frame): if data_parallel and vllm_config.model_config.is_moe: # Set data parallel rank for this engine process. parallel_config.data_parallel_rank = dp_rank - engine_core = DPEngineCoreProc(*args, **kwargs) + engine_core = DPEngineCoreProc(**kwargs, tensor_queues=tensor_queues) else: # Non-MoE DP ranks are completely independent, so treat like DP=1. 
# Note that parallel_config.data_parallel_index will still reflect @@ -923,7 +961,9 @@ def signal_handler(signum, frame): parallel_config.data_parallel_size = 1 parallel_config.data_parallel_size_local = 1 parallel_config.data_parallel_rank = 0 - engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs) + engine_core = EngineCoreProc( + **kwargs, engine_index=dp_rank, tensor_queues=tensor_queues + ) engine_core.run_busy_loop() @@ -1001,6 +1041,22 @@ def _process_engine_step(self) -> bool: return model_executed + def abort_requests(self, request_ids: list[str]): + """Abort requests and cleanup any orphaned tensors.""" + # First, abort the requests in the scheduler + super().abort_requests(request_ids) + + # Then cleanup any orphaned tensors for these requests + if self.tensor_decoder is not None: + for request_id in request_ids: + self.tensor_decoder.cleanup_request_tensors(request_id) + + def _cleanup_finished_request_tensors(self, finished_req_ids: set[str]) -> None: + """Cleanup any orphaned tensors for finished requests.""" + if self.tensor_decoder is not None: + for request_id in finished_req_ids: + self.tensor_decoder.cleanup_request_tensors(request_id) + def _handle_client_request( self, request_type: EngineCoreRequestType, request: Any ) -> None: @@ -1073,9 +1129,14 @@ def process_input_sockets( ): """Input socket IO thread.""" - # Msgpack serialization decoding. - add_request_decoder = MsgpackDecoder(EngineCoreRequest) - generic_decoder = MsgpackDecoder() + # Msgpack serialization decoding with tensor queue for CUDA tensors. + add_request_decoder = MsgpackDecoder( + EngineCoreRequest, tensor_queue=self.tensor_queue + ) + generic_decoder = MsgpackDecoder(tensor_queue=self.tensor_queue) + + # Store decoder reference for tensor cleanup on abort + self.tensor_decoder = add_request_decoder with ExitStack() as stack, zmq.Context() as ctx: input_sockets = [ @@ -1252,6 +1313,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, client_handshake_address: str | None = None, + tensor_queues: list[Any] | None = None, ): assert vllm_config.model_config.is_moe, ( "DPEngineCoreProc should only be used for MoE models" @@ -1273,6 +1335,7 @@ def __init__( log_stats, client_handshake_address, engine_index=dp_rank, + tensor_queues=tensor_queues, ) def _init_data_parallel(self, vllm_config: VllmConfig): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c9a1d53c8fb7..08ecec12a1e8 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -59,6 +59,33 @@ EngineIdentity = bytes +@contextlib.contextmanager +def encoder_request_context( + encoder: MsgpackEncoder, + engine: EngineIdentity, + request_type: EngineCoreRequestType, + request: Any, +): + """Context manager for setting encoder state during request encoding. + + Sets the target engine and request context (for ADD requests) on entry, + and clears the request context on exit. 
+ """ + # Set target engine index for tensor routing + engine_index = int.from_bytes(engine, "little") + encoder.set_target_engine(engine_index) + + # Set request context if this is an ADD request with a request_id + if request_type == EngineCoreRequestType.ADD and hasattr(request, "request_id"): + encoder.set_request_context(request.request_id) + + try: + yield encoder + finally: + # Clear request context after encoding + encoder.set_request_context(None) + + class EngineCoreClient(ABC): """ EngineCoreClient: subclasses handle different methods for pushing @@ -360,6 +387,7 @@ class BackgroundResources: output_queue_task: asyncio.Task | None = None stats_update_task: asyncio.Task | None = None shutdown_path: str | None = None + tensor_queues: list[Any] | None = None # Set if any of the engines are dead. Here so that the output # processing threads can access it without holding a ref to the client. @@ -450,9 +478,6 @@ def __init__( client_addresses: dict[str, str] | None = None, ): self.vllm_config = vllm_config - # Serialization setup. - self.encoder = MsgpackEncoder() - self.decoder = MsgpackDecoder(EngineCoreOutputs) # ZMQ setup. sync_ctx = zmq.Context(io_threads=2) @@ -469,11 +494,14 @@ def __init__( self.engines_running = False self.stats_update_address: str | None = None + tensor_queues: list[Any] | None = None if client_addresses: # Engines are managed externally to this client. input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] self.stats_update_address = client_addresses.get("stats_update_address") + # Tensor queues passed via client_addresses for multi-API-server case + tensor_queues = client_addresses.get("tensor_queues") # type: ignore[assignment] else: # Engines are managed by this client. with launch_core_engines(vllm_config, executor_class, log_stats) as ( @@ -487,11 +515,28 @@ def __init__( (input_address,) = addresses.inputs (output_address,) = addresses.outputs self.stats_update_address = addresses.frontend_stats_publish_address + tensor_queues = addresses.tensor_queues if coordinator is not None: assert self.stats_update_address == ( coordinator.get_stats_publish_address() ) + # Serialization setup with tensor queues for multimodal tensor IPC. + # Get IPC config from multimodal_config, falling back to default + multimodal_tensor_ipc = "direct_rpc" # Default + if vllm_config.model_config.multimodal_config is not None: + multimodal_tensor_ipc = ( + vllm_config.model_config.multimodal_config.multimodal_tensor_ipc + ) + + self.encoder = MsgpackEncoder( + tensor_queues=tensor_queues, + multimodal_tensor_ipc=multimodal_tensor_ipc, + ) + self.decoder = MsgpackDecoder(EngineCoreOutputs) + # Store tensor queues for routing + self.resources.tensor_queues = tensor_queues + # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( self.ctx, input_address, zmq.ROUTER, bind=True @@ -721,8 +766,12 @@ def get_output(self) -> EngineCoreOutputs: def _send_input(self, request_type: EngineCoreRequestType, request: Any): self.ensure_alive() self.free_pending_messages() + # (Identity, RequestType, SerializedRequest) - msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) + with encoder_request_context( + self.encoder, self.core_engine, request_type, request + ): + msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) if len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. 
@@ -903,7 +952,9 @@ def _send_input( if engine is None: engine = self.core_engine - message = (request_type.value, *self.encoder.encode(request)) + with encoder_request_context(self.encoder, engine, request_type, request): + message = (request_type.value, *self.encoder.encode(request)) + return self._send_input_message(message, engine, request) def _send_input_message( diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 5db3a53266f0..eb24fe368499 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -9,10 +9,11 @@ from enum import Enum, auto from multiprocessing import Process, connection from multiprocessing.process import BaseProcess -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest.mock import patch import msgspec +import torch.multiprocessing as torch_mp import zmq from vllm import envs @@ -64,6 +65,11 @@ class EngineZmqAddresses: # Not used by engine, just relayed to front-end in handshake response. # Only required for external DP LB case. frontend_stats_publish_address: str | None = None + # Tensor IPC queues for sharing CUDA tensors between API servers and engines + # One queue per engine core for direct GPU tensor transfer + tensor_queues: list[Any] | None = None + # Index of this engine's tensor queue (set during handshake) + tensor_queue_index: int | None = None @dataclass @@ -75,6 +81,8 @@ class EngineHandshakeMetadata: addresses: EngineZmqAddresses parallel_config: dict[str, int | str | list[int]] + # Index of this engine's tensor queue in addresses.tensor_queues + tensor_queue_index: int | None = None class CoreEngineProcManager: @@ -95,6 +103,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, client_handshake_address: str | None = None, + tensor_queues: list[Any] | None = None, ): context = get_mp_context() common_kwargs = { @@ -108,6 +117,9 @@ def __init__( if client_handshake_address: common_kwargs["client_handshake_address"] = client_handshake_address + # Store tensor_queues for passing to engine processes + self.tensor_queues = tensor_queues + self.processes: list[BaseProcess] = [] local_dp_ranks = [] for index in range(local_engine_count): @@ -124,6 +136,7 @@ def __init__( | { "dp_rank": global_index, "local_dp_rank": local_index, + "tensor_queues": tensor_queues, }, ) ) @@ -800,6 +813,13 @@ def launch_core_engines( offline_mode or local_engines_only or (local_engine_count == dp_size) ) + # Create tensor IPC queues for sharing multimodal tensors between API servers + # and engine cores. One queue per engine core. + # Use torch.multiprocessing for tensor sharing via IPC/shared memory. + # Set start method to 'spawn' for compatibility with multiprocessing. + torch_mp.set_start_method("spawn", force=True) + tensor_queues: list[torch_mp.Queue] = [torch_mp.Queue() for _ in range(dp_size)] + # Set up input and output addresses. addresses = EngineZmqAddresses( inputs=[ @@ -810,6 +830,7 @@ def launch_core_engines( get_engine_client_zmq_addr(client_local_only, host) for _ in range(num_api_servers) ], + tensor_queues=tensor_queues, ) # Run the DP Coordinator process with rank 0 when in online DP mode. @@ -908,6 +929,7 @@ def launch_core_engines( local_engine_count=local_engine_count, start_index=dp_rank, local_start_index=local_start_index or 0, + tensor_queues=tensor_queues, ) else: local_engine_manager = None @@ -1015,9 +1037,21 @@ def wait_for_engine_startup( if status == "HELLO" and engine.state == CoreEngineState.NEW: # Send init message with DP config info. 
+ # Note: tensor_queues are excluded from serialization as they can't be + # serialized by msgspec. They are passed directly to engine processes + # when spawning them. + addresses_for_handshake = EngineZmqAddresses( + inputs=addresses.inputs, + outputs=addresses.outputs, + coordinator_input=addresses.coordinator_input, + coordinator_output=addresses.coordinator_output, + frontend_stats_publish_address=addresses.frontend_stats_publish_address, + tensor_queues=None, # Don't serialize queues + tensor_queue_index=None, # Will be set separately + ) init_message = msgspec.msgpack.encode( EngineHandshakeMetadata( - addresses=addresses, + addresses=addresses_for_handshake, parallel_config={ k: getattr(parallel_config, k) for k in ( @@ -1029,6 +1063,7 @@ def wait_for_engine_startup( } if coordinated_dp else {}, + tensor_queue_index=eng_index, ) ) handshake_socket.send_multipart((eng_identity, init_message), copy=False) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index a3c30e368b82..9e09d2145686 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -4,11 +4,12 @@ import dataclasses import importlib import pickle +import threading from collections.abc import Callable, Sequence from functools import partial from inspect import isclass from types import FunctionType -from typing import Any, TypeAlias, get_type_hints +from typing import Any, NamedTuple, TypeAlias, get_type_hints import cloudpickle import msgspec @@ -41,6 +42,38 @@ CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 + +@dataclasses.dataclass +class TensorIpcData: + """ + Data sent via torch.multiprocessing.Queue for zero-copy IPC. + + Contains the request_id, tensor_id and the actual tensor. The tensor is shared + in memory (GPU or CPU) for efficient inter-process communication. + """ + + request_id: str | None + tensor_id: str + tensor: torch.Tensor + + +class TensorIpcHandle(NamedTuple): + """ + Handle for a tensor sent via IPC queue (zero-copy transfer). + + Contains only metadata about the tensor. This is serialized via msgpack + and used by the decoder to retrieve the actual tensor from the queue. + The actual tensor is sent separately via torch.multiprocessing.Queue + as TensorIpcData. Works for both CUDA and CPU tensors. + """ + + request_id: str | None + tensor_id: str + shape: tuple[int, ...] + dtype: str + device: str + + # MultiModalField class serialization type map. # These need to list all possible field types and match them # to factory methods in `MultiModalFieldConfig`. @@ -119,9 +152,19 @@ class MsgpackEncoder: By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. + + When multimodal_tensor_ipc is "torch_shm" and tensor_queues is provided, + all multimodal tensors (CUDA and CPU) will be sent via + torch.multiprocessing.Queue for zero-copy IPC instead of serialization. + When "direct_rpc", tensors use standard msgspec serialization. """ - def __init__(self, size_threshold: int | None = None): + def __init__( + self, + size_threshold: int | None = None, + tensor_queues: list[Any] | None = None, + multimodal_tensor_ipc: str = "direct_rpc", + ): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) @@ -130,9 +173,27 @@ def __init__(self, size_threshold: int | None = None): # pass custom data to the hook otherwise. 
self.aux_buffers: list[bytestr] | None = None self.size_threshold = size_threshold + # Tensor IPC queues for sharing multimodal tensors (one per engine core) + self.tensor_queues = tensor_queues + # IPC method for multimodal tensors + self.multimodal_tensor_ipc = multimodal_tensor_ipc + # Target engine index for routing tensors to the correct queue + self.target_engine_index: int | None = None + # Counter for generating unique tensor IDs + self._tensor_id_counter = 0 + # Current request ID being encoded (for associating tensors with requests) + self._current_request_id: str | None = None if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() + def set_target_engine(self, engine_index: int | None) -> None: + """Set the target engine index for routing multimodal tensors to IPC queues.""" + self.target_engine_index = engine_index + + def set_request_context(self, request_id: str | None) -> None: + """Set the current request ID being encoded (for tensor association).""" + self._current_request_id = request_id + def encode(self, obj: Any) -> Sequence[bytestr]: try: self.aux_buffers = bufs = [b""] @@ -220,10 +281,90 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data + def _encode_ipc_queue_tensor(self, obj: torch.Tensor) -> TensorIpcHandle: + """Send tensor via torch.multiprocessing.Queue for zero-copy IPC. + + This works for both CUDA and CPU tensors. + """ + assert self.target_engine_index is not None, "Target engine index is not set" + assert self.tensor_queues is not None, "Tensor queues are not set" + + # Generate unique tensor ID (without request ID embedded) + tensor_id = f"{id(self)}_{self._tensor_id_counter}" + self._tensor_id_counter += 1 + + # Move tensor to shared memory for IPC + # This is required for proper inter-process communication + if not obj.is_shared(): + obj = obj.share_memory_() + + target_queue = self.tensor_queues[self.target_engine_index] + ipc_data = TensorIpcData( + request_id=self._current_request_id, + tensor_id=tensor_id, + tensor=obj, + ) + # Use a timeout to avoid blocking indefinitely + target_queue.put(ipc_data, timeout=10.0) + + logger.debug( + "Sent tensor %s for request %s (shape=%s, device=%s) to engine %d " + "via IPC queue (shared memory)", + tensor_id, + self._current_request_id, + obj.shape, + obj.device, + self.target_engine_index, + ) + + return TensorIpcHandle( + request_id=self._current_request_id, + tensor_id=tensor_id, + shape=tuple(obj.shape), + dtype=str(obj.dtype).removeprefix("torch."), + device=str(obj.device), + ) + def _encode_tensor( self, obj: torch.Tensor - ) -> tuple[str, tuple[int, ...], int | memoryview]: + ) -> ( + tuple[str, tuple[int, ...], int | memoryview] | dict[str, Any] | TensorIpcHandle + ): assert self.aux_buffers is not None + + # Check if we should use IPC for this tensor + # IPC is used when: multimodal_tensor_ipc is "torch_shm", queues are available, + # and we have a target engine + if ( + self.multimodal_tensor_ipc == "torch_shm" + and self.tensor_queues is not None + and self.target_engine_index is not None + ): + try: + return self._encode_ipc_queue_tensor(obj) + except Exception as e: + logger.warning( + "Failed to send tensor via IPC queue: %s. 
" + "Falling back to standard serialization.", + e, + ) + # Fall through to standard serialization + + # Standard serialization fallback + # For CUDA tensors without IPC support, we need to move to CPU first + if obj.is_cuda: + if ( + self.multimodal_tensor_ipc == "torch_shm" + and self.tensor_queues is not None + ): + # Only warn if IPC was expected but unavailable + logger.warning( + "CUDA tensor without IPC support encountered " + "(no target engine set). Moving to CPU for " + "serialization. This will be slow." + ) + obj = obj.cpu() + # view the tensor as a contiguous 1D array of bytes arr_data = tensor_data(obj) if obj.nbytes < self.size_threshold: @@ -281,9 +422,18 @@ class MsgpackDecoder: Note that unlike vanilla `msgspec` Decoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. + + For multimodal tensors sent via torch.multiprocessing.Queue (when IPC + is enabled), they will be retrieved from the queue during decoding. + Works for both CUDA and CPU tensors. """ - def __init__(self, t: Any | None = None, share_mem: bool = True): + def __init__( + self, + t: Any | None = None, + share_mem: bool = True, + tensor_queue: Any | None = None, + ): self.share_mem = share_mem self.pin_tensors = is_pin_memory_available() args = () if t is None else (t,) @@ -291,6 +441,15 @@ def __init__(self, t: Any | None = None, share_mem: bool = True): *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook ) self.aux_buffers: Sequence[bytestr] = () + # Tensor IPC queue for receiving multimodal tensors from API servers + self.tensor_queue = tensor_queue + # Buffer for temporarily storing tensors retrieved from queue + # Keys are tuples: (request_id, tensor_id) or (None, tensor_id) for legacy + self._tensor_buffer: dict[tuple[str | None, str], torch.Tensor] = {} + # Mapping from request_id to list of tensor keys for cleanup + self._request_to_tensors: dict[str, list[tuple[str | None, str]]] = {} + # Lock to synchronize access between cleanup and decode threads + self._buffer_lock = threading.Lock() if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -309,6 +468,12 @@ def dec_hook(self, t: type, obj: Any) -> Any: if isclass(t): if issubclass(t, np.ndarray): return self._decode_ndarray(obj) + if issubclass(t, TensorIpcHandle): + # msgspec deserializes dataclasses to dicts, so convert + # to TensorIpcHandle + if isinstance(obj, dict): + obj = TensorIpcHandle(**obj) + return self._decode_ipc_queue_tensor(obj) if issubclass(t, torch.Tensor): return self._decode_tensor(obj) if t is slice: @@ -354,6 +519,30 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: return arr.reshape(shape) def _decode_tensor(self, arr: Any) -> torch.Tensor: + # Check if this is a TensorIpcHandle (sent via IPC queue) + # This can happen when IPC is enabled for non-multimodal tensor fields + if isinstance(arr, TensorIpcHandle): + return self._decode_ipc_queue_tensor(arr) + # Check if this is a dict that represents a TensorIpcHandle + # (msgspec serializes dataclasses as dicts without type info) + if ( + isinstance(arr, dict) + and "tensor_id" in arr + and "shape" in arr + and "dtype" in arr + and "device" in arr + ): + # Convert dict to TensorIpcHandle and decode it + handle = TensorIpcHandle(**arr) + return self._decode_ipc_queue_tensor(handle) + # Check if this is a list/tuple with 5 elements (TensorIpcHandle) + # msgspec serializes NamedTuples as lists + if isinstance(arr, (list, tuple)) and len(arr) == 5: + # Convert list to TensorIpcHandle and decode it + handle = 
TensorIpcHandle(*arr) + return self._decode_ipc_queue_tensor(handle) + + # Standard tensor decoding dtype, shape, data = arr is_aux = isinstance(data, int) buffer = self.aux_buffers[data] if is_aux else data @@ -375,6 +564,55 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) + def _decode_ipc_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: + """Retrieve a tensor from torch.multiprocessing.Queue. + + Works for CUDA and CPU. + """ + + assert self.tensor_queue is not None, "Tensor queue is not set" + + # Create lookup key from handle + lookup_key = (handle.request_id, handle.tensor_id) + + # Drain all available tensors. We save them regardless if this is the one + # we're waiting for as they may arrive out of order from multiple producers. + while True: + # Check if tensor is already in buffer (with lock) + with self._buffer_lock: + if lookup_key in self._tensor_buffer: + # Retrieve and remove tensor from buffer + tensor = self._tensor_buffer.pop(lookup_key) + + # Remove from request tracking when consumed + if ( + handle.request_id is not None + and handle.request_id in self._request_to_tensors + ): + tensors = self._request_to_tensors.get(handle.request_id) + if tensors: + tensors.remove(lookup_key) + # Clean up if this is the last tensor for the request + if not tensors: + del self._request_to_tensors[handle.request_id] + + return tensor + + # Release lock while waiting on queue (important to avoid blocking cleanup) + ipc_data: TensorIpcData = self.tensor_queue.get(timeout=10.0) + + # Store the received tensor (with lock) + with self._buffer_lock: + # Store tensor with tuple key (request_id, tensor_id) + tensor_key = (ipc_data.request_id, ipc_data.tensor_id) + self._tensor_buffer[tensor_key] = ipc_data.tensor + + # Track which request this tensor belongs to for cleanup + if ipc_data.request_id is not None: + if ipc_data.request_id not in self._request_to_tensors: + self._request_to_tensors[ipc_data.request_id] = [] + self._request_to_tensors[ipc_data.request_id].append(tensor_key) + def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: return MultiModalKwargsItems( { @@ -409,6 +647,22 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: # Although it violates NestedTensors type, MultiModalKwargs # values are sometimes floats. return obj + if isinstance(obj, TensorIpcHandle): + return self._decode_ipc_queue_tensor(obj) + # Check if this is a dict that represents a TensorIpcHandle + # (msgspec serializes dataclasses as dicts without type info + # in nested structures) + if ( + isinstance(obj, dict) + and "tensor_id" in obj + and "shape" in obj + and "dtype" in obj + and "device" in obj + ): + # Convert dict to TensorIpcHandle and decode it + # Handle both new format (with request_id) and old format (without) + handle = TensorIpcHandle(**obj) + return self._decode_ipc_queue_tensor(handle) if not isinstance(obj, list): raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") if obj and isinstance(obj[0], str): @@ -421,6 +675,37 @@ def _decode_nested_slices(self, obj: Any) -> Any: return slice(*obj) return [self._decode_nested_slices(x) for x in obj] + def cleanup_request_tensors(self, request_id: str) -> int: + """Remove all orphaned tensors associated with a request. + + This should be called when a request is aborted, times out, or fails + to ensure tensors in the buffer don't accumulate indefinitely. 
+ + Args: + request_id: The request ID whose tensors should be cleaned up. + + Returns: + The number of tensors that were removed from the buffer. + """ + with self._buffer_lock: + if request_id not in self._request_to_tensors: + return 0 + + tensor_keys = self._request_to_tensors.pop(request_id) + removed_count = 0 + + for tensor_key in tensor_keys: + if tensor_key in self._tensor_buffer: + del self._tensor_buffer[tensor_key] + removed_count += 1 + logger.debug( + "Cleaned up orphaned tensor %s for request %s", + tensor_key[1], # Just log the tensor_id part + request_id, + ) + + return removed_count + def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 75ad304ddf1a..5c7b6bbd4d41 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -174,6 +174,7 @@ def __init__( input_addresses: list[str], output_addresses: list[str], stats_update_address: str | None = None, + tensor_queues: list[Any] | None = None, ): """Initialize and start API server worker processes. @@ -186,6 +187,7 @@ def __init__( input_addresses: Input addresses for each API server output_addresses: Output addresses for each API server stats_update_address: Optional stats update address + tensor_queues: Optional tensor IPC queues for CUDA tensor sharing """ self.listen_address = listen_address self.sock = sock @@ -206,6 +208,8 @@ def __init__( } if stats_update_address is not None: client_config["stats_update_address"] = stats_update_address + if tensor_queues is not None: + client_config["tensor_queues"] = tensor_queues proc = spawn_context.Process( target=target_server_fn,