diff --git a/vllm/envs.py b/vllm/envs.py index f82dae108f6a..0b1215483396 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -214,6 +214,8 @@ VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False + VLLM_USE_TRITON_POS_EMBED: bool = False + VLLM_POS_EMBED_CACHE_SIZE: int = 100 VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True @@ -1442,6 +1444,17 @@ def get_vllm_port() -> int | None: "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool( int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0")) ), + # If set, use a fused Triton kernel for bilinear position-embedding + # interpolation in Qwen3-VL (replaces ~25 small eager kernels with one). + "VLLM_USE_TRITON_POS_EMBED": lambda: bool( + int(os.getenv("VLLM_USE_TRITON_POS_EMBED", "0")) + ), + # Number of grid configurations to pre-warm in the Qwen3-VL position + # embedding cache (0 = disabled, max 100). Uses ~9 MB per entry on + # average at BF16; 100 entries ≈ 0.9 GB. + "VLLM_POS_EMBED_CACHE_SIZE": lambda: int( + os.getenv("VLLM_POS_EMBED_CACHE_SIZE", "100") + ), # If set to 1/True, use the TRTLLM attention backend in flashinfer. # If set to 0/False, use the default attention backend in flashinfer. # If not set, auto-detect the attention backend in flashinfer. diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1f1ee2f56219..e3ab33bfa5b7 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -48,6 +48,7 @@ ) from transformers.video_utils import VideoMetadata +import vllm.envs as envs from vllm.compilation.decorators import support_torch_compile from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions @@ -90,6 +91,7 @@ PromptUpdateDetails, ) from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton from vllm.utils.collection_utils import is_list_of from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -135,6 +137,243 @@ # This avoids creating a new graph for each unique batch size at runtime BATCH_BUCKETS = [8, 16, 32, 64] +# --------------------------------------------------------------------------- +# Triton kernel: fused bilinear position-embedding interpolation +# --------------------------------------------------------------------------- +# Replaces ~25 small eager-mode CUDA kernels (linspace, meshgrid, clamp, +# stack, embedding lookup, weighted sum, reshape, permute …) with a single +# kernel launch. The spatial-merge reorder is baked into the index math so +# the output is ready to be added to the patch embeddings directly. 
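+#
+# For each output patch the kernel derives four bilinear weights from the
+# fractional offsets (dh, dw) into the fixed NUM_GRID x NUM_GRID table:
+#   w00 = (1-dh)(1-dw), w01 = (1-dh)*dw, w10 = dh*(1-dw), w11 = dh*dw
+# e.g. dh=0.25, dw=0.5 -> w00=0.375, w01=0.375, w10=0.125, w11=0.125.
+# The weights always sum to 1, so interpolation preserves embedding scale.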
+# --------------------------------------------------------------------------- + + +@triton.jit +def _bilinear_pos_embed_kernel( + # Pointers + embed_ptr, # [num_embeddings, HIDDEN_DIM] - embedding table + output_ptr, # [total_out, HIDDEN_DIM] - result + # Grid geometry (runtime values) + H, # int - grid height (in patches) + W, # int - grid width (in patches) + h_scale, # float - (num_grid_per_side - 1) / (H - 1), or 0 + w_scale, # float - (num_grid_per_side - 1) / (W - 1), or 0 + # Model constants (constexpr -> compiled-in) + NUM_GRID: tl.constexpr, # num_grid_per_side + M_SIZE: tl.constexpr, # spatial_merge_size + HIDDEN_DIM: tl.constexpr, # embedding dim + BLOCK_D: tl.constexpr, # tile width over hidden dim +): + """Fused bilinear pos-embed interpolation with spatial-merge reorder.""" + pid = tl.program_id(0) # one program per output row + total_spatial = H * W + spatial_idx = pid % total_spatial # same interp for all t + + # --- undo spatial-merge reorder to recover original (row, col) --------- + num_blocks_w = W // M_SIZE + block_idx = spatial_idx // (M_SIZE * M_SIZE) + local_idx = spatial_idx % (M_SIZE * M_SIZE) + br = block_idx // num_blocks_w + bc = block_idx % num_blocks_w + lr = local_idx // M_SIZE + lc = local_idx % M_SIZE + row = br * M_SIZE + lr + col = bc * M_SIZE + lc + + # --- fractional grid coordinates (equivalent to torch.linspace) -------- + h_frac = row.to(tl.float32) * h_scale + w_frac = col.to(tl.float32) * w_scale + + # floor / ceil with clamp + hf = tl.math.floor(h_frac).to(tl.int32) + wf = tl.math.floor(w_frac).to(tl.int32) + hc = tl.minimum(hf + 1, NUM_GRID - 1) + wc = tl.minimum(wf + 1, NUM_GRID - 1) + + # --- bilinear weights (fp32) ------------------------------------------ + dh = h_frac - hf.to(tl.float32) + dw = w_frac - wf.to(tl.float32) + w11 = dh * dw + w10 = dh - w11 + w01 = dw - w11 + w00 = 1.0 - dh - w01 + + # --- embedding-table row offsets (flat index into [N, D] table) -------- + off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM + off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM + off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM + off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM + out_off = pid * HIDDEN_DIM + + # --- iterate over hidden dim in tiles of BLOCK_D ----------------------- + for d in tl.range(0, HIDDEN_DIM, BLOCK_D): + cols = d + tl.arange(0, BLOCK_D) + mask = cols < HIDDEN_DIM + + e00 = tl.load(embed_ptr + off00 + cols, mask=mask).to(tl.float32) + e01 = tl.load(embed_ptr + off01 + cols, mask=mask).to(tl.float32) + e10 = tl.load(embed_ptr + off10 + cols, mask=mask).to(tl.float32) + e11 = tl.load(embed_ptr + off11 + cols, mask=mask).to(tl.float32) + + val = w00 * e00 + w01 * e01 + w10 * e10 + w11 * e11 + + tl.store(output_ptr + out_off + cols, val, mask=mask) + + +def triton_pos_embed_interpolate( + embed_weight: torch.Tensor, # [num_embeddings, hidden_dim] + t: int, + h: int, + w: int, + num_grid_per_side: int, + m_size: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Python wrapper - launches the fused Triton kernel for one (t,h,w) grid. + + Returns a tensor of shape ``(t * h * w, hidden_dim)`` with the + bilinearly-interpolated position embeddings already in spatial-merge order. 
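+
+    Illustrative call (hypothetical sizes; in practice
+    ``fast_pos_embed_interpolate`` below supplies these arguments from the
+    vision tower's own attributes):
+
+        pos = triton_pos_embed_interpolate(
+            pos_embed.weight, t=1, h=32, w=32,
+            num_grid_per_side=48, m_size=2, dtype=torch.bfloat16,
+        )
+        # pos.shape == (1 * 32 * 32, pos_embed.weight.shape[1])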
+ """ + hidden_dim = embed_weight.shape[1] + total_out = t * h * w + output = torch.empty(total_out, hidden_dim, device=embed_weight.device, dtype=dtype) + + h_scale = float(num_grid_per_side - 1) / float(h - 1) if h > 1 else 0.0 + w_scale = float(num_grid_per_side - 1) / float(w - 1) if w > 1 else 0.0 + + BLOCK_D = triton.next_power_of_2(hidden_dim) + + _bilinear_pos_embed_kernel[(total_out,)]( + embed_weight, + output, + h, + w, + h_scale, + w_scale, + num_grid_per_side, # NUM_GRID (constexpr) + m_size, # M_SIZE (constexpr) + hidden_dim, # HIDDEN_DIM (constexpr) + BLOCK_D, # BLOCK_D (constexpr) + ) + return output + + +# --------------------------------------------------------------------------- +# Grid configurations for CUDA graph capture (T, H, W in patch units) +# +# Top 100 most common grids for embedding cache pre-warming. +# Pre-warming these grids at startup avoids cold-start embedding computation +# at runtime, eliminating ~20 small kernel launches per grid on first encounter. +# Based on MLPerf VLM dataset analysis (~71% coverage with top 100 grids). +_EMBEDDING_WARMUP_GRIDS: list[tuple[int, int, int]] = [ + # Top 50 grids (sorted by frequency) + (1, 62, 62), + (1, 32, 32), + (1, 50, 50), + (1, 38, 38), + (1, 76, 76), + (1, 94, 94), + (1, 64, 64), + (1, 124, 124), + (1, 68, 68), + (1, 100, 100), + (1, 16, 16), + (1, 24, 24), + (1, 46, 46), + (1, 44, 44), + (1, 42, 42), + (1, 40, 40), + (1, 56, 56), + (1, 128, 128), + (1, 18, 18), + (1, 28, 28), + (1, 34, 34), + (1, 80, 80), + (1, 30, 30), + (1, 38, 50), + (1, 22, 22), + (1, 112, 112), + (1, 36, 36), + (1, 34, 50), + (1, 188, 188), + (1, 14, 20), + (1, 90, 90), + (1, 44, 42), + (1, 16, 18), + (1, 54, 54), + (1, 48, 48), + (1, 40, 42), + (1, 60, 60), + (1, 88, 88), + (1, 26, 26), + (1, 156, 156), + (1, 94, 62), + (1, 30, 38), + (1, 24, 38), + (1, 20, 20), + (1, 24, 16), + (1, 18, 16), + (1, 120, 120), + (1, 60, 80), + (1, 52, 52), + (1, 66, 66), + # Next 50 grids + (1, 20, 14), + (1, 24, 32), + (1, 160, 160), + (1, 28, 38), + (1, 30, 40), + (1, 38, 42), + (1, 58, 58), + (1, 20, 32), + (1, 50, 38), + (1, 48, 64), + (1, 78, 78), + (1, 24, 20), + (1, 42, 62), + (1, 62, 94), + (1, 36, 42), + (1, 32, 20), + (1, 150, 150), + (1, 50, 42), + (1, 50, 76), + (1, 72, 72), + (1, 32, 24), + (1, 46, 42), + (1, 92, 94), + (1, 82, 82), + (1, 32, 38), + (1, 90, 94), + (1, 14, 22), + (1, 76, 100), + (1, 94, 92), + (1, 24, 18), + (1, 54, 42), + (1, 38, 32), + (1, 18, 24), + (1, 28, 32), + (1, 30, 42), + (1, 56, 76), + (1, 62, 42), + (1, 28, 50), + (1, 32, 42), + (1, 36, 50), + (1, 38, 24), + (1, 108, 82), + (1, 16, 20), + (1, 26, 38), + (1, 38, 36), + (1, 34, 42), + (1, 76, 50), + (1, 38, 56), + (1, 48, 42), + (1, 30, 32), +] +# Slice the list based on VLLM_POS_EMBED_CACHE_SIZE (0 = disabled, max 100). +_EMBEDDING_WARMUP_GRIDS = _EMBEDDING_WARMUP_GRIDS[ + : max(0, envs.VLLM_POS_EMBED_CACHE_SIZE) +] +_EMBEDDING_WARMUP_GRIDS_SET: set[tuple[int, int, int]] = set(_EMBEDDING_WARMUP_GRIDS) + class Qwen3_VisionPatchEmbed(nn.Module): def __init__( @@ -447,6 +686,29 @@ def __init__( ] ) + self._embedding_cache: dict[tuple[int, int, int], torch.Tensor] = {} + + def warmup_embedding_cache(self) -> None: + """Pre-compute and cache position embeddings for common grid sizes. + + Call this after model weights are loaded to pre-populate the cache, + eliminating ~25 small kernel launches per grid on first encounter. 
+ """ + logger.info( + "Pre-warming position embedding cache for %d grid configurations", + len(_EMBEDDING_WARMUP_GRIDS), + ) + for grid in _EMBEDDING_WARMUP_GRIDS: + t, h, w = grid + grid_key = (t, h, w) + if grid_key not in self._embedding_cache: + pos_embed = self.fast_pos_embed_interpolate([[t, h, w]]) + self._embedding_cache[grid_key] = pos_embed + logger.info( + "Position embedding cache warmed up: %d entries", + len(self._embedding_cache), + ) + @property def dtype(self) -> torch.dtype: return self.patch_embed.proj.weight.dtype @@ -500,6 +762,47 @@ def rot_pos_emb(self, grid_thw: list[list[int]]): return cos_combined, sin_combined + def _get_cached_pos_embeds(self, grid_thw_list: list[list[int]]) -> torch.Tensor: + """ + Get position embeddings with per-grid caching. + + This method caches pos embeddings only for grids in + _EMBEDDING_WARMUP_GRIDS_SET to avoid unbounded memory growth. + Grids not in the warmup set are computed on-the-fly without caching. + + Args: + grid_thw_list: List of [T, H, W] for each image + + Returns: + Concatenated position embeddings for all grids. + """ + pos_embeds_list: list[torch.Tensor] = [] + + for grid in grid_thw_list: + t, h, w = grid + grid_key = (t, h, w) + + if grid_key in self._embedding_cache: + # Cache hit - use cached embeddings + cached = self._embedding_cache[grid_key] + pos_embeds_list.append(cached) + else: + # Cache miss - compute embeddings + single_grid = [[t, h, w]] + pos_embed = self.fast_pos_embed_interpolate(single_grid) + + # Only cache if grid is in pre-warmed set to prevent OOM. + # Caching at runtime causes unbounded memory growth. + if grid_key in _EMBEDDING_WARMUP_GRIDS_SET: + self._embedding_cache[grid_key] = pos_embed + + pos_embeds_list.append(pos_embed) + + # Concatenate all embeddings + pos_embeds = torch.cat(pos_embeds_list, dim=0) + + return pos_embeds + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: num_grid_per_side = self.num_grid_per_side m_size = self.spatial_merge_size @@ -507,55 +810,81 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: outputs = [] for t, h, w in grid_thw: - h_idxs = torch.linspace( - 0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device - ) - w_idxs = torch.linspace( - 0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device - ) + if envs.VLLM_USE_TRITON_POS_EMBED: + repeated = triton_pos_embed_interpolate( + self.pos_embed.weight, + t, + h, + w, + num_grid_per_side, + m_size, + self.dtype, + ) + else: + h_idxs = torch.linspace( + 0, + num_grid_per_side - 1, + h, + dtype=torch.float32, + device=self.device, + ) + w_idxs = torch.linspace( + 0, + num_grid_per_side - 1, + w, + dtype=torch.float32, + device=self.device, + ) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = torch.meshgrid( + h_floor, w_floor, indexing="ij" + ) + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = 
+            else:
+                h_idxs = torch.linspace(
+                    0,
+                    num_grid_per_side - 1,
+                    h,
+                    dtype=torch.float32,
+                    device=self.device,
+                )
+                w_idxs = torch.linspace(
+                    0,
+                    num_grid_per_side - 1,
+                    w,
+                    dtype=torch.float32,
+                    device=self.device,
+                )
+
+                h_floor = h_idxs.to(torch.long)
+                w_floor = w_idxs.to(torch.long)
+                h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
+                w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
+
+                dh = h_idxs - h_floor
+                dw = w_idxs - w_floor
+
+                # Create meshgrid view for all h, w vars
+                dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
+                h_floor_grid, w_floor_grid = torch.meshgrid(
+                    h_floor, w_floor, indexing="ij"
+                )
+                h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
+
+                # original computation of weights
+                # w00 = (1 - dh_grid) * (1 - dw_grid)
+                # w01 = (1 - dh_grid) * dw_grid
+                # w10 = dh_grid * (1 - dw_grid)
+                # w11 = dh_grid * dw_grid
+                # we reuse w11 here to avoid duplicate
+                # dh_grid * dw_grid computation
+                w11 = dh_grid * dw_grid
+                w10 = dh_grid - w11
+                w01 = dw_grid - w11
+                w00 = 1 - dh_grid - w01
+
+                h_grid = torch.stack(
+                    [h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]
+                )
+                w_grid = torch.stack(
+                    [w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]
+                )
+                h_grid_idx = h_grid * num_grid_per_side
+
+                indices = (h_grid_idx + w_grid).reshape(4, -1)
+                weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
+                weights = weights.to(dtype=self.dtype)
+
+                embeds = self.pos_embed(indices)
+                embeds *= weights
+                combined = embeds.sum(dim=0)
+
+                combined = combined.reshape(
+                    h // m_size, m_size, w // m_size, m_size, hidden_dim
+                )
+                combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
+                repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
-            h_floor = h_idxs.to(torch.long)
-            w_floor = w_idxs.to(torch.long)
-            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
-            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
-
-            dh = h_idxs - h_floor
-            dw = w_idxs - w_floor
-
-            # Create meshgrid view for all h, w vars
-            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
-            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
-            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
-
-            # original computation of weights
-            # w00 = (1 - dh_grid) * (1 - dw_grid)
-            # w01 = (1 - dh_grid) * dw_grid
-            # w10 = dh_grid * (1 - dw_grid)
-            # w11 = dh_grid * dw_grid
-            # we reuse w11 here to avoid duplicate
-            # dh_grid * dw_grid computation
-            w11 = dh_grid * dw_grid
-            w10 = dh_grid - w11
-            w01 = dw_grid - w11
-            w00 = 1 - dh_grid - w01
-
-            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
-            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
-            h_grid_idx = h_grid * num_grid_per_side
-
-            indices = (h_grid_idx + w_grid).reshape(4, -1)
-            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-            weights = weights.to(dtype=self.dtype)
-
-            embeds = self.pos_embed(indices)
-            embeds *= weights
-            combined = embeds.sum(dim=0)
-
-            combined = combined.reshape(
-                h // m_size, m_size, w // m_size, m_size, hidden_dim
-            )
-            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
-            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
 
             outputs.append(repeated)
 
         return torch.cat(outputs, dim=0)
 
@@ -632,7 +961,7 @@ def forward(
         grid_thw_list = grid_thw.tolist()
         grid_thw = grid_thw.numpy()
 
-        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
+        pos_embeds = self._get_cached_pos_embeds(grid_thw_list)
         hidden_states = hidden_states + pos_embeds
 
         rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
@@ -2170,7 +2499,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         if self.visual is None:
             skip_prefixes.extend(["visual."])
         loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+        loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+        # Pre-warm the position embedding cache now that weights are loaded.
+        if self.visual is not None:
+            self.visual.warmup_embedding_cache()
+
+        return loaded
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """