From 82e5ff436c6304b00a2a4dcd637c050e28266789 Mon Sep 17 00:00:00 2001 From: Zhanda Date: Sat, 7 Feb 2026 19:47:33 -0800 Subject: [PATCH 1/9] feat: Add bilinear_pos_embed triton kernel and cache Signed-off-by: Zhanda --- vllm/envs.py | 6 + vllm/model_executor/models/qwen3_vl.py | 431 ++++++++++++++++++++++--- 2 files changed, 387 insertions(+), 50 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index f82dae108f6a..05897c52b052 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -214,6 +214,7 @@ VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False + VLLM_USE_TRITON_POS_EMBED: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True @@ -1442,6 +1443,11 @@ def get_vllm_port() -> int | None: "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool( int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0")) ), + # If set, use a fused Triton kernel for bilinear position-embedding + # interpolation in Qwen3-VL (replaces ~25 small eager kernels with one). + "VLLM_USE_TRITON_POS_EMBED": lambda: bool( + int(os.getenv("VLLM_USE_TRITON_POS_EMBED", "0")) + ), # If set to 1/True, use the TRTLLM attention backend in flashinfer. # If set to 0/False, use the default attention backend in flashinfer. # If not set, auto-detect the attention backend in flashinfer. diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1f1ee2f56219..effa67a9bf33 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -48,6 +48,7 @@ ) from transformers.video_utils import VideoMetadata +import vllm.envs as envs from vllm.compilation.decorators import support_torch_compile from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions @@ -90,6 +91,7 @@ PromptUpdateDetails, ) from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton from vllm.utils.collection_utils import is_list_of from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -135,6 +137,239 @@ # This avoids creating a new graph for each unique batch size at runtime BATCH_BUCKETS = [8, 16, 32, 64] +# --------------------------------------------------------------------------- +# Triton kernel: fused bilinear position-embedding interpolation +# --------------------------------------------------------------------------- +# Replaces ~25 small eager-mode CUDA kernels (linspace, meshgrid, clamp, +# stack, embedding lookup, weighted sum, reshape, permute …) with a single +# kernel launch. The spatial-merge reorder is baked into the index math so +# the output is ready to be added to the patch embeddings directly. 
+# --------------------------------------------------------------------------- + + +@triton.jit +def _bilinear_pos_embed_kernel( + # Pointers + embed_ptr, # [num_embeddings, HIDDEN_DIM] - embedding table + output_ptr, # [total_out, HIDDEN_DIM] - result + # Grid geometry (runtime values) + H, # int - grid height (in patches) + W, # int - grid width (in patches) + h_scale, # float - (num_grid_per_side - 1) / (H - 1), or 0 + w_scale, # float - (num_grid_per_side - 1) / (W - 1), or 0 + # Model constants (constexpr -> compiled-in) + NUM_GRID: tl.constexpr, # num_grid_per_side + M_SIZE: tl.constexpr, # spatial_merge_size + HIDDEN_DIM: tl.constexpr, # embedding dim + BLOCK_D: tl.constexpr, # tile width over hidden dim +): + """Fused bilinear pos-embed interpolation with spatial-merge reorder.""" + pid = tl.program_id(0) # one program per output row + total_spatial = H * W + spatial_idx = pid % total_spatial # same interp for all t + + # --- undo spatial-merge reorder to recover original (row, col) --------- + num_blocks_w = W // M_SIZE + block_idx = spatial_idx // (M_SIZE * M_SIZE) + local_idx = spatial_idx % (M_SIZE * M_SIZE) + br = block_idx // num_blocks_w + bc = block_idx % num_blocks_w + lr = local_idx // M_SIZE + lc = local_idx % M_SIZE + row = br * M_SIZE + lr + col = bc * M_SIZE + lc + + # --- fractional grid coordinates (equivalent to torch.linspace) -------- + h_frac = row.to(tl.float32) * h_scale + w_frac = col.to(tl.float32) * w_scale + + # floor / ceil with clamp + hf = tl.math.floor(h_frac).to(tl.int32) + wf = tl.math.floor(w_frac).to(tl.int32) + hc = tl.minimum(hf + 1, NUM_GRID - 1) + wc = tl.minimum(wf + 1, NUM_GRID - 1) + + # --- bilinear weights (fp32) ------------------------------------------ + dh = h_frac - hf.to(tl.float32) + dw = w_frac - wf.to(tl.float32) + w11 = dh * dw + w10 = dh - w11 + w01 = dw - w11 + w00 = 1.0 - dh - w01 + + # --- embedding-table row offsets (flat index into [N, D] table) -------- + off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM + off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM + off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM + off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM + out_off = pid * HIDDEN_DIM + + # --- iterate over hidden dim in tiles of BLOCK_D ----------------------- + for d in tl.range(0, HIDDEN_DIM, BLOCK_D): + cols = d + tl.arange(0, BLOCK_D) + mask = cols < HIDDEN_DIM + + e00 = tl.load(embed_ptr + off00 + cols, mask=mask).to(tl.float32) + e01 = tl.load(embed_ptr + off01 + cols, mask=mask).to(tl.float32) + e10 = tl.load(embed_ptr + off10 + cols, mask=mask).to(tl.float32) + e11 = tl.load(embed_ptr + off11 + cols, mask=mask).to(tl.float32) + + val = w00 * e00 + w01 * e01 + w10 * e10 + w11 * e11 + + tl.store(output_ptr + out_off + cols, val, mask=mask) + + +def triton_pos_embed_interpolate( + embed_weight: torch.Tensor, # [num_embeddings, hidden_dim] + t: int, + h: int, + w: int, + num_grid_per_side: int, + m_size: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Python wrapper - launches the fused Triton kernel for one (t,h,w) grid. + + Returns a tensor of shape ``(t * h * w, hidden_dim)`` with the + bilinearly-interpolated position embeddings already in spatial-merge order. 
+ """ + hidden_dim = embed_weight.shape[1] + total_out = t * h * w + output = torch.empty(total_out, hidden_dim, device=embed_weight.device, dtype=dtype) + + h_scale = float(num_grid_per_side - 1) / float(h - 1) if h > 1 else 0.0 + w_scale = float(num_grid_per_side - 1) / float(w - 1) if w > 1 else 0.0 + + BLOCK_D = triton.next_power_of_2(hidden_dim) + + _bilinear_pos_embed_kernel[(total_out,)]( + embed_weight, + output, + h, + w, + h_scale, + w_scale, + num_grid_per_side, # NUM_GRID (constexpr) + m_size, # M_SIZE (constexpr) + hidden_dim, # HIDDEN_DIM (constexpr) + BLOCK_D, # BLOCK_D (constexpr) + ) + return output + + +# --------------------------------------------------------------------------- +# Grid configurations for CUDA graph capture (T, H, W in patch units) +# +# Top 100 most common grids for embedding cache pre-warming. +# Pre-warming these grids at startup avoids cold-start embedding computation +# at runtime, eliminating ~20 small kernel launches per grid on first encounter. +# Based on MLPerf VLM dataset analysis (~71% coverage with top 100 grids). +_EMBEDDING_WARMUP_GRIDS: list[tuple[int, int, int]] = [ + # Top 50 grids (sorted by frequency) + (1, 62, 62), + (1, 32, 32), + (1, 50, 50), + (1, 38, 38), + (1, 76, 76), + (1, 94, 94), + (1, 64, 64), + (1, 124, 124), + (1, 68, 68), + (1, 100, 100), + (1, 16, 16), + (1, 24, 24), + (1, 46, 46), + (1, 44, 44), + (1, 42, 42), + (1, 40, 40), + (1, 56, 56), + (1, 128, 128), + (1, 18, 18), + (1, 28, 28), + (1, 34, 34), + (1, 80, 80), + (1, 30, 30), + (1, 38, 50), + (1, 22, 22), + (1, 112, 112), + (1, 36, 36), + (1, 34, 50), + (1, 188, 188), + (1, 14, 20), + (1, 90, 90), + (1, 44, 42), + (1, 16, 18), + (1, 54, 54), + (1, 48, 48), + (1, 40, 42), + (1, 60, 60), + (1, 88, 88), + (1, 26, 26), + (1, 156, 156), + (1, 94, 62), + (1, 30, 38), + (1, 24, 38), + (1, 20, 20), + (1, 24, 16), + (1, 18, 16), + (1, 120, 120), + (1, 60, 80), + (1, 52, 52), + (1, 66, 66), + # Next 50 grids + (1, 20, 14), + (1, 24, 32), + (1, 160, 160), + (1, 28, 38), + (1, 30, 40), + (1, 38, 42), + (1, 58, 58), + (1, 20, 32), + (1, 50, 38), + (1, 48, 64), + (1, 78, 78), + (1, 24, 20), + (1, 42, 62), + (1, 62, 94), + (1, 36, 42), + (1, 32, 20), + (1, 150, 150), + (1, 50, 42), + (1, 50, 76), + (1, 72, 72), + (1, 32, 24), + (1, 46, 42), + (1, 92, 94), + (1, 82, 82), + (1, 32, 38), + (1, 90, 94), + (1, 14, 22), + (1, 76, 100), + (1, 94, 92), + (1, 24, 18), + (1, 54, 42), + (1, 38, 32), + (1, 18, 24), + (1, 28, 32), + (1, 30, 42), + (1, 56, 76), + (1, 62, 42), + (1, 28, 50), + (1, 32, 42), + (1, 36, 50), + (1, 38, 24), + (1, 108, 82), + (1, 16, 20), + (1, 26, 38), + (1, 38, 36), + (1, 34, 42), + (1, 76, 50), + (1, 38, 56), + (1, 48, 42), + (1, 30, 32), +] +_EMBEDDING_WARMUP_GRIDS_SET: set[tuple[int, int, int]] = set(_EMBEDDING_WARMUP_GRIDS) + class Qwen3_VisionPatchEmbed(nn.Module): def __init__( @@ -447,6 +682,29 @@ def __init__( ] ) + self._embedding_cache: dict[tuple[int, int, int], torch.Tensor] = {} + + def warmup_embedding_cache(self) -> None: + """Pre-compute and cache position embeddings for common grid sizes. + + Call this after model weights are loaded to pre-populate the cache, + eliminating ~25 small kernel launches per grid on first encounter. 
+ """ + logger.info( + "Pre-warming position embedding cache for %d grid configurations", + len(_EMBEDDING_WARMUP_GRIDS), + ) + for grid in _EMBEDDING_WARMUP_GRIDS: + t, h, w = grid + grid_key = (t, h, w) + if grid_key not in self._embedding_cache: + pos_embed = self.fast_pos_embed_interpolate([[t, h, w]]) + self._embedding_cache[grid_key] = pos_embed + logger.info( + "Position embedding cache warmed up: %d entries", + len(self._embedding_cache), + ) + @property def dtype(self) -> torch.dtype: return self.patch_embed.proj.weight.dtype @@ -500,6 +758,47 @@ def rot_pos_emb(self, grid_thw: list[list[int]]): return cos_combined, sin_combined + def _get_cached_pos_embeds(self, grid_thw_list: list[list[int]]) -> torch.Tensor: + """ + Get position embeddings with per-grid caching. + + This method caches pos embeddings only for grids in + _EMBEDDING_WARMUP_GRIDS_SET to avoid unbounded memory growth. + Grids not in the warmup set are computed on-the-fly without caching. + + Args: + grid_thw_list: List of [T, H, W] for each image + + Returns: + Concatenated position embeddings for all grids. + """ + pos_embeds_list: list[torch.Tensor] = [] + + for grid in grid_thw_list: + t, h, w = grid + grid_key = (t, h, w) + + if grid_key in self._embedding_cache: + # Cache hit - use cached embeddings + cached = self._embedding_cache[grid_key] + pos_embeds_list.append(cached) + else: + # Cache miss - compute embeddings + single_grid = [[t, h, w]] + pos_embed = self.fast_pos_embed_interpolate(single_grid) + + # Only cache if grid is in pre-warmed set to prevent OOM. + # Caching at runtime causes unbounded memory growth. + if grid_key in _EMBEDDING_WARMUP_GRIDS_SET: + self._embedding_cache[grid_key] = pos_embed + + pos_embeds_list.append(pos_embed) + + # Concatenate all embeddings + pos_embeds = torch.cat(pos_embeds_list, dim=0) + + return pos_embeds + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: num_grid_per_side = self.num_grid_per_side m_size = self.spatial_merge_size @@ -507,55 +806,81 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: outputs = [] for t, h, w in grid_thw: - h_idxs = torch.linspace( - 0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device - ) - w_idxs = torch.linspace( - 0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device - ) + if envs.VLLM_USE_TRITON_POS_EMBED: + repeated = triton_pos_embed_interpolate( + self.pos_embed.weight, + t, + h, + w, + num_grid_per_side, + m_size, + self.dtype, + ) + else: + h_idxs = torch.linspace( + 0, + num_grid_per_side - 1, + h, + dtype=torch.float32, + device=self.device, + ) + w_idxs = torch.linspace( + 0, + num_grid_per_side - 1, + w, + dtype=torch.float32, + device=self.device, + ) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = torch.meshgrid( + h_floor, w_floor, indexing="ij" + ) + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = 
dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - w01 + + h_grid = torch.stack( + [h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid] + ) + w_grid = torch.stack( + [w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid] + ) + h_grid_idx = h_grid * num_grid_per_side + + indices = (h_grid_idx + w_grid).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.dtype) + + embeds = self.pos_embed(indices) + embeds *= weights + combined = embeds.sum(dim=0) + + combined = combined.reshape( + h // m_size, m_size, w // m_size, m_size, hidden_dim + ) + combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) + repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) - h_floor = h_idxs.to(torch.long) - w_floor = w_idxs.to(torch.long) - h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) - w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) - - dh = h_idxs - h_floor - dw = w_idxs - w_floor - - # Create meshgrid view for all h, w vars - dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") - h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij") - h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") - - # original computation of weights - # w00 = (1 - dh_grid) * (1 - dw_grid) - # w01 = (1 - dh_grid) * dw_grid - # w10 = dh_grid * (1 - dw_grid) - # w11 = dh_grid * dw_grid - # we reuse w11 here to avoid duplicate - # dh_grid * dw_grid computation - w11 = dh_grid * dw_grid - w10 = dh_grid - w11 - w01 = dw_grid - w11 - w00 = 1 - dh_grid - w01 - - h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) - w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) - h_grid_idx = h_grid * num_grid_per_side - - indices = (h_grid_idx + w_grid).reshape(4, -1) - weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) - weights = weights.to(dtype=self.dtype) - - embeds = self.pos_embed(indices) - embeds *= weights - combined = embeds.sum(dim=0) - - combined = combined.reshape( - h // m_size, m_size, w // m_size, m_size, hidden_dim - ) - combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) - repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) outputs.append(repeated) return torch.cat(outputs, dim=0) @@ -632,7 +957,7 @@ def forward( grid_thw_list = grid_thw.tolist() grid_thw = grid_thw.numpy() - pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) + pos_embeds = self._get_cached_pos_embeds(grid_thw_list) hidden_states = hidden_states + pos_embeds rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) @@ -2170,7 +2495,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if self.visual is None: skip_prefixes.extend(["visual."]) loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + # Pre-warm the position embedding cache now that weights are loaded. 
+ if self.visual is not None: + self.visual.warmup_embedding_cache() + + return loaded def get_mm_mapping(self) -> MultiModelKeys: """ From 1ad7f89603d31c618c8c1075d1ae477379c015c5 Mon Sep 17 00:00:00 2001 From: Zhanda Date: Sat, 7 Feb 2026 20:01:34 -0800 Subject: [PATCH 2/9] Set env to control the cache size Signed-off-by: Zhanda --- vllm/envs.py | 7 +++++++ vllm/model_executor/models/qwen3_vl.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index 05897c52b052..0b1215483396 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -215,6 +215,7 @@ VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_USE_TRITON_POS_EMBED: bool = False + VLLM_POS_EMBED_CACHE_SIZE: int = 100 VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True @@ -1448,6 +1449,12 @@ def get_vllm_port() -> int | None: "VLLM_USE_TRITON_POS_EMBED": lambda: bool( int(os.getenv("VLLM_USE_TRITON_POS_EMBED", "0")) ), + # Number of grid configurations to pre-warm in the Qwen3-VL position + # embedding cache (0 = disabled, max 100). Uses ~9 MB per entry on + # average at BF16; 100 entries ≈ 0.9 GB. + "VLLM_POS_EMBED_CACHE_SIZE": lambda: int( + os.getenv("VLLM_POS_EMBED_CACHE_SIZE", "100") + ), # If set to 1/True, use the TRTLLM attention backend in flashinfer. # If set to 0/False, use the default attention backend in flashinfer. # If not set, auto-detect the attention backend in flashinfer. diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index effa67a9bf33..e3ab33bfa5b7 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -368,6 +368,10 @@ def triton_pos_embed_interpolate( (1, 48, 42), (1, 30, 32), ] +# Slice the list based on VLLM_POS_EMBED_CACHE_SIZE (0 = disabled, max 100). +_EMBEDDING_WARMUP_GRIDS = _EMBEDDING_WARMUP_GRIDS[ + : max(0, envs.VLLM_POS_EMBED_CACHE_SIZE) +] _EMBEDDING_WARMUP_GRIDS_SET: set[tuple[int, int, int]] = set(_EMBEDDING_WARMUP_GRIDS) From c4f829d279aebf5d54a4f14f1d750897742ed68f Mon Sep 17 00:00:00 2001 From: Zhanda Date: Wed, 4 Feb 2026 21:29:03 -0800 Subject: [PATCH 3/9] Initial implementation for fp8 vit flashinfer attn --- vllm/envs.py | 9 ++ .../layers/attention/mm_encoder_attention.py | 120 +++++++++++++++++- vllm/v1/attention/ops/vit_attn_wrappers.py | 24 +++- 3 files changed, 150 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 0b1215483396..1c4e9511d406 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -216,6 +216,8 @@ VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_USE_TRITON_POS_EMBED: bool = False VLLM_POS_EMBED_CACHE_SIZE: int = 100 + VLLM_MM_ENCODER_FP8_ATTN: bool = False + VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH: str | None = None VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True @@ -1454,6 +1456,13 @@ def get_vllm_port() -> int | None: # average at BF16; 100 entries ≈ 0.9 GB. 
"VLLM_POS_EMBED_CACHE_SIZE": lambda: int( os.getenv("VLLM_POS_EMBED_CACHE_SIZE", "100") + # Controls whether to use FP8 attention for multimodal encoder (e.g., ViT) + "VLLM_MM_ENCODER_FP8_ATTN": lambda: bool( + int(os.getenv("VLLM_MM_ENCODER_FP8_ATTN", "0")) + ), + # Path to JSON file containing FP8 attention scales for multimodal encoder + "VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH": lambda: os.getenv( + "VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH", None ), # If set to 1/True, use the TRTLLM attention backend in flashinfer. # If set to 0/False, use the default attention backend in flashinfer. diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 28d83776ebe5..35e59ffca698 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +import json +from pathlib import Path import torch +import vllm.envs as envs from vllm.config import MultiModalConfig from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -20,6 +26,34 @@ logger = init_logger(__name__) +@functools.cache +def _load_fp8_scales_file(path: str | None) -> dict[str, dict[str, float]]: + """Load FP8 scales from file. Results are cached.""" + if path is None: + return {} + + path = str(Path(path).expanduser()) + with open(path, encoding="utf-8") as f: + data = json.load(f) + + # Handle nested "layers" format + if "layers" in data and isinstance(data["layers"], dict): + data = data["layers"] + + scales: dict[str, dict[str, float]] = {} + for layer_name, layer_scales in data.items(): + if not isinstance(layer_scales, dict): + continue + q = layer_scales.get("q", layer_scales.get("q_scale")) + k = layer_scales.get("k", layer_scales.get("k_scale")) + v = layer_scales.get("v", layer_scales.get("v_scale")) + if q is not None and k is not None and v is not None: + scales[layer_name] = {"q": float(q), "k": float(k), "v": float(v)} + + logger.info(f"Loaded FP8 attention scales from {path} ({len(scales)} layers)") + return scales + + # --8<-- [start:mm_encoder_attn] @CustomOp.register("mm_encoder_attn") class MMEncoderAttention(CustomOp): @@ -43,8 +77,7 @@ def __init__( head_size: hidden_size per attention head. scale: scale factor. num_kv_heads: number of kv heads. - prefix: This has no effect, it is only here to make it easier to - swap between Attention and MultiHeadAttention + prefix: layer name prefix, used to look up FP8 scales. multimodal_config: configs for multi-modal. 
""" super().__init__() @@ -90,6 +123,64 @@ def __init__( logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") + # FP8 attention support + self.fp8_enabled = False + self.fp8_scales: dict[str, float] | None = None + self.fp8_quant: QuantFP8 | None = None + + if envs.VLLM_MM_ENCODER_FP8_ATTN: + self._init_fp8_attention(prefix) + + def _init_fp8_attention(self, layer_name: str) -> None: + """Initialize FP8 attention for this layer.""" + scale_path = envs.VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH + all_scales = _load_fp8_scales_file(scale_path) + + if scale_path is None: + # No scale path provided - use scale=1.0 and warn + logger.warning_once( + "VLLM_MM_ENCODER_FP8_ATTN enabled but " + "VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH not set. " + "Using scale=1.0 for all Q/K/V (may cause accuracy issues)." + ) + self.fp8_scales = {"q": 1.0, "k": 1.0, "v": 1.0} + else: + # Scale path provided - layer must exist + layer_scales = all_scales.get(layer_name) + if layer_scales is None: + raise ValueError( + "FP8 attention enabled but scales not found for layer " + f"'{layer_name}' in {scale_path}. " + f"Available layers: {list(all_scales.keys())}" + ) + self.fp8_scales = layer_scales + + # Register scale tensors as buffers (auto-move to device with module) + # Shape (1, 1, 1, 1) as required by cuDNN + self.register_buffer( + "_fp8_q_scale", + torch.tensor([self.fp8_scales["q"]], dtype=torch.float32).view(1, 1, 1, 1), + ) + self.register_buffer( + "_fp8_k_scale", + torch.tensor([self.fp8_scales["k"]], dtype=torch.float32).view(1, 1, 1, 1), + ) + self.register_buffer( + "_fp8_v_scale", + torch.tensor([self.fp8_scales["v"]], dtype=torch.float32).view(1, 1, 1, 1), + ) + + # Create QuantFP8 for efficient quantization + self.fp8_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + self.fp8_enabled = True + + logger.debug( + f"FP8 attention enabled for {layer_name}: " + f"q={self.fp8_scales['q']:.4f}, " + f"k={self.fp8_scales['k']:.4f}, " + f"v={self.fp8_scales['v']:.4f}" + ) + @classmethod def enabled(cls) -> bool: return True @@ -187,6 +278,22 @@ def _forward_fa( output = output.reshape(bsz, q_len, -1) return output + def _quantize_to_fp8( + self, + tensor: torch.Tensor, + scale: torch.Tensor, + ) -> torch.Tensor: + """Quantize a 3D tensor (total_tokens, num_heads, head_dim) to FP8. + + Uses QuantFP8 CustomOp for backend-aware quantization. 
+ """ + assert self.fp8_quant is not None + orig_shape = tensor.shape + # QuantFP8 expects 2D input: (total_tokens, num_heads * head_dim) + tensor_2d = tensor.view(orig_shape[0], -1) + fp8_tensor, _ = self.fp8_quant(tensor_2d, scale=scale) + return fp8_tensor.view(orig_shape) + def _forward_flashinfer( self, query: torch.Tensor, @@ -197,6 +304,12 @@ def _forward_flashinfer( sequence_lengths: torch.Tensor | None = None, # Only used for FlashInfer CuDNN backend ) -> torch.Tensor: + if self.fp8_enabled: + assert self.fp8_quant is not None and self.fp8_scales is not None + query = self._quantize_to_fp8(query, self._fp8_q_scale) + key = self._quantize_to_fp8(key, self._fp8_k_scale) + value = self._quantize_to_fp8(value, self._fp8_v_scale) + return vit_flashinfer_wrapper( q=query, k=key, @@ -206,6 +319,9 @@ def _forward_flashinfer( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, sequence_lengths=sequence_lengths, + q_scale=self._fp8_q_scale if self.fp8_enabled else None, + k_scale=self._fp8_k_scale if self.fp8_enabled else None, + v_scale=self._fp8_v_scale if self.fp8_enabled else None, ) def _forward_fa4( diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index 84b1438fb1b0..081b245364bc 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -273,6 +273,9 @@ def flashinfer_wrapper( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, sequence_lengths: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + k_scale: torch.Tensor | None = None, + v_scale: torch.Tensor | None = None, ) -> torch.Tensor: from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache @@ -306,6 +309,9 @@ def flashinfer_wrapper( batch_offsets_k=batch_offsets_qk, batch_offsets_v=batch_offsets_v, batch_offsets_o=batch_offsets_o, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, ) if is_reshaped: @@ -323,6 +329,9 @@ def vit_flashinfer_wrapper_fake( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, sequence_lengths: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + k_scale: torch.Tensor | None = None, + v_scale: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(q) @@ -343,7 +352,20 @@ def vit_flashinfer_wrapper( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, sequence_lengths: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + k_scale: torch.Tensor | None = None, + v_scale: torch.Tensor | None = None, ) -> torch.Tensor: return torch.ops.vllm.flashinfer_wrapper( - q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths + q, + k, + v, + scale, + workspace_buffer, + cu_seqlens, + max_seqlen, + sequence_lengths, + q_scale, + k_scale, + v_scale, ) From 3641873798b9572aa2ac52cea875630e74bf84eb Mon Sep 17 00:00:00 2001 From: Zhanda Date: Fri, 6 Feb 2026 17:20:22 -0800 Subject: [PATCH 4/9] Support fp8 flashinfer vit attn --- .../layers/attention/mm_encoder_attention.py | 8 +++++--- vllm/v1/attention/ops/vit_attn_wrappers.py | 7 +++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 35e59ffca698..e4ee54ffda37 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -97,6 +97,7 @@ def __init__( # During model 
initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() + self.dtype = dtype # Try to get vision attention backend from multimodal_config. attn_backend_override = None @@ -290,9 +291,9 @@ def _quantize_to_fp8( assert self.fp8_quant is not None orig_shape = tensor.shape # QuantFP8 expects 2D input: (total_tokens, num_heads * head_dim) - tensor_2d = tensor.view(orig_shape[0], -1) - fp8_tensor, _ = self.fp8_quant(tensor_2d, scale=scale) - return fp8_tensor.view(orig_shape) + tensor_2d = tensor.reshape(orig_shape[0], -1) + fp8_tensor, _ = self.fp8_quant.forward_cuda(tensor_2d, scale=scale) + return fp8_tensor.reshape(orig_shape) def _forward_flashinfer( self, @@ -322,6 +323,7 @@ def _forward_flashinfer( q_scale=self._fp8_q_scale if self.fp8_enabled else None, k_scale=self._fp8_k_scale if self.fp8_enabled else None, v_scale=self._fp8_v_scale if self.fp8_enabled else None, + o_data_type=self.dtype if self.fp8_enabled else None, ) def _forward_fa4( diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index 081b245364bc..fb9bab7d0fff 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -276,6 +276,7 @@ def flashinfer_wrapper( q_scale: torch.Tensor | None = None, k_scale: torch.Tensor | None = None, v_scale: torch.Tensor | None = None, + o_data_type: torch.dtype | None = None, ) -> torch.Tensor: from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache @@ -312,6 +313,7 @@ def flashinfer_wrapper( q_scale=q_scale, k_scale=k_scale, v_scale=v_scale, + o_data_type=o_data_type, ) if is_reshaped: @@ -332,7 +334,10 @@ def vit_flashinfer_wrapper_fake( q_scale: torch.Tensor | None = None, k_scale: torch.Tensor | None = None, v_scale: torch.Tensor | None = None, + o_data_type: torch.dtype | None = None, ) -> torch.Tensor: + if o_data_type is not None and o_data_type != q.dtype: + return torch.empty(q.shape, device=q.device, dtype=o_data_type) return torch.empty_like(q) @@ -355,6 +360,7 @@ def vit_flashinfer_wrapper( q_scale: torch.Tensor | None = None, k_scale: torch.Tensor | None = None, v_scale: torch.Tensor | None = None, + o_data_type: torch.dtype | None = None, ) -> torch.Tensor: return torch.ops.vllm.flashinfer_wrapper( q, @@ -368,4 +374,5 @@ def vit_flashinfer_wrapper( q_scale, k_scale, v_scale, + o_data_type, ) From c52319571997f90ec482e8bf5f5274a821a2a3d4 Mon Sep 17 00:00:00 2001 From: Zhanda Date: Sat, 7 Feb 2026 20:18:21 -0800 Subject: [PATCH 5/9] Fix fp8 quant FI interface issues (padding + strides) Signed-off-by: Zhanda --- .../layers/attention/mm_encoder_attention.py | 53 +++++-- .../layers/quantization/input_quant_fp8.py | 146 ++++++++++++++++++ vllm/model_executor/models/qwen3_vl.py | 33 ++++ 3 files changed, 219 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index e4ee54ffda37..1796cc894739 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -11,7 +11,10 @@ from vllm.config import MultiModalConfig from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.input_quant_fp8 import ( + QuantFP8, + quantize_fp8_pad_head_dim_triton, +) from 
vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.v1.attention.backends.fa_utils import get_flash_attn_version @@ -174,7 +177,10 @@ def _init_fp8_attention(self, layer_name: str) -> None: # Create QuantFP8 for efficient quantization self.fp8_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) self.fp8_enabled = True - + self.skip_scale_q = self.fp8_scales["q"] == 1.0 + self.skip_scale_k = self.fp8_scales["k"] == 1.0 + self.skip_scale_v = self.fp8_scales["v"] == 1.0 + logger.debug( f"FP8 attention enabled for {layer_name}: " f"q={self.fp8_scales['q']:.4f}, " @@ -283,17 +289,31 @@ def _quantize_to_fp8( self, tensor: torch.Tensor, scale: torch.Tensor, + skip_scale: bool = False, ) -> torch.Tensor: - """Quantize a 3D tensor (total_tokens, num_heads, head_dim) to FP8. - - Uses QuantFP8 CustomOp for backend-aware quantization. + """Quantize a 3D (S, H, D) tensor to FP8. + + Uses QuantFP8 CustomOp when head_dim is aligned to 16; otherwise + falls back to a stride-aware Triton kernel that pads head_dim to + a multiple of 16 — no extra copy even for non-contiguous inputs. """ assert self.fp8_quant is not None orig_shape = tensor.shape - # QuantFP8 expects 2D input: (total_tokens, num_heads * head_dim) - tensor_2d = tensor.reshape(orig_shape[0], -1) - fp8_tensor, _ = self.fp8_quant.forward_cuda(tensor_2d, scale=scale) - return fp8_tensor.reshape(orig_shape) + head_dim = orig_shape[-1] + + if head_dim % 16 == 0: + if skip_scale: + return tensor.to(torch.float8_e4m3fn) + + # QuantFP8 expects 2D input: (total_tokens, num_heads * head_dim) + tensor_2d = tensor.reshape(orig_shape[0], -1) + fp8_tensor, _ = self.fp8_quant.forward_cuda( + tensor_2d, scale=scale + ) + return fp8_tensor.reshape(orig_shape) + + # Fall back to Triton kernel for padding head_dim to a multiple of 16 + return quantize_fp8_pad_head_dim_triton(tensor, scale, skip_scale=skip_scale) def _forward_flashinfer( self, @@ -307,11 +327,11 @@ def _forward_flashinfer( ) -> torch.Tensor: if self.fp8_enabled: assert self.fp8_quant is not None and self.fp8_scales is not None - query = self._quantize_to_fp8(query, self._fp8_q_scale) - key = self._quantize_to_fp8(key, self._fp8_k_scale) - value = self._quantize_to_fp8(value, self._fp8_v_scale) + query = self._quantize_to_fp8(query, self._fp8_q_scale, skip_scale=self.skip_scale_q) + key = self._quantize_to_fp8(key, self._fp8_k_scale, skip_scale=self.skip_scale_k) + value = self._quantize_to_fp8(value, self._fp8_v_scale, skip_scale=self.skip_scale_v) - return vit_flashinfer_wrapper( + output = vit_flashinfer_wrapper( q=query, k=key, v=value, @@ -326,6 +346,13 @@ def _forward_flashinfer( o_data_type=self.dtype if self.fp8_enabled else None, ) + # Un-pad head dimension if it was padded during FP8 quantization + if self.fp8_enabled and output.shape[-1] != self.head_size: + output = output[..., :self.head_size] + output = output.contiguous() + + return output + def _forward_fa4( self, query: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index c1a901c37a0b..dadc8833d920 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os import torch import torch.nn.functional as F @@ -13,12 
+14,157 @@ group_broadcast, ) from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON, tl, triton _FP8_DTYPE = current_platform.fp8_dtype() _FP8_MIN, _FP8_MAX = get_fp8_min_max() _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) +@triton.jit +def _quantize_pad_fp8_kernel( + x_ptr, + y_ptr, + scale_ptr, + stride_xs, # input stride along token (seq) dim — may be non-contiguous + stride_xh, # input stride along head dim + stride_xd, # input stride along head_dim dim (usually 1) + stride_ys, # output stride along token dim (contiguous) + stride_yh, # output stride along head dim + stride_yd, # output stride along head_dim dim (usually 1) + num_heads, + n_rows, # total rows = S * H + n_cols, + n_cols_padded, + fp8_min, + fp8_max, + SKIP_SCALE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < n_rows + mask_out = mask_m[:, None] & (offs_n[None, :] < n_cols_padded) + mask_in = mask_m[:, None] & (offs_n[None, :] < n_cols) + + # Decompose flattened row into (token, head) for 3D stride indexing. + # This lets the kernel read directly from non-contiguous QKV views. + s = offs_m // num_heads + h = offs_m % num_heads + + x_ptrs = (x_ptr + + s[:, None] * stride_xs + + h[:, None] * stride_xh + + offs_n[None, :] * stride_xd) + x = tl.load(x_ptrs, mask=mask_in, other=0.0).to(tl.float32) + if SKIP_SCALE: + x_q = x + else: + scale = tl.load(scale_ptr) + x_q = x / scale + x_q = tl.where(mask_in, x_q, 0.0) + x_q = tl.clamp(x_q, fp8_min, fp8_max).to(y_ptr.dtype.element_ty) + + y_ptrs = (y_ptr + + s[:, None] * stride_ys + + h[:, None] * stride_yh + + offs_n[None, :] * stride_yd) + tl.store(y_ptrs, x_q, mask=mask_out) + + +def _get_fp8_pad_quant_config(padded_head_dim: int) -> tuple[int, int, int]: + # Blackwell: use a single static config to avoid recompiles. + if current_platform.is_device_capability_family(100): + block_n, num_warps, block_m = 128, 4, 16 + else: + block_n = triton.next_power_of_2(padded_head_dim) + block_n = max(16, min(block_n, 256)) + num_warps = 4 if block_n >= 128 else 2 + block_m = 16 + + env_block_n = os.getenv("VLLM_FP8_PAD_QUANT_BLOCK_N") + env_num_warps = os.getenv("VLLM_FP8_PAD_QUANT_NUM_WARPS") + env_block_m = os.getenv("VLLM_FP8_PAD_QUANT_BLOCK_M") + if env_block_n is not None: + block_n = max(16, min(int(env_block_n), 256)) + if env_num_warps is not None: + num_warps = int(env_num_warps) + if env_block_m is not None: + block_m = max(1, int(env_block_m)) + + return block_n, num_warps, block_m + + +def quantize_fp8_pad_head_dim_triton( + tensor: torch.Tensor, + scale: torch.Tensor, + skip_scale: bool = False, + block_n: int | None = None, + num_warps: int | None = None, + block_m: int | None = None, +) -> torch.Tensor: + """Quantize a 4D (B, S, H, D) or 3D (S, H, D) tensor to FP8 while padding D to a multiple of 16. + + Reads directly from the input using its 3D strides, so non-contiguous + views (e.g. Q/K/V slices from an interleaved QKV buffer) are handled + without an extra copy. Output is always a fresh contiguous tensor + with shape (S, H, padded_D). + """ + if not HAS_TRITON: + raise RuntimeError( + "Triton is required to quantize with head_dim padding." 
+ ) + + original_shape = tensor.shape + if tensor.dim() == 4: + tensor = tensor.view(-1, tensor.shape[-2], tensor.shape[-1]) + assert tensor.dim() == 3, ( + f"Expected 3D input (S, H, D), got {tensor.dim()}D" + ) + S, H, D = tensor.shape + padded_head_dim = (D + 15) // 16 * 16 + out_dtype = current_platform.fp8_dtype() + output = torch.empty( + (S, H, padded_head_dim), + device=tensor.device, + dtype=out_dtype, + ) + + scale_1d = scale.reshape(-1) + fp8_min, fp8_max = get_fp8_min_max() + n_rows = S * H + + if block_n is None or num_warps is None or block_m is None: + block_n, num_warps, block_m = _get_fp8_pad_quant_config(padded_head_dim) + + grid = (triton.cdiv(n_rows, block_m), + triton.cdiv(padded_head_dim, block_n)) + + _quantize_pad_fp8_kernel[grid]( + tensor, + output, + scale_1d, + tensor.stride(0), tensor.stride(1), tensor.stride(2), + output.stride(0), output.stride(1), output.stride(2), + H, + n_rows, + D, + padded_head_dim, + fp8_min, + fp8_max, + SKIP_SCALE=skip_scale, + BLOCK_M=block_m, + BLOCK_N=block_n, + num_warps=num_warps, + ) + + return output.view((*original_shape[:-1], padded_head_dim)) + + # --8<-- [start:quant_fp8] @CustomOp.register("quant_fp8") class QuantFP8(CustomOp): diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index e3ab33bfa5b7..54f335200996 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -610,6 +610,18 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads + + # When FP8 attention is enabled and head_dim is not a multiple of 16, + # the quantization kernel pads head_dim (e.g. 72 -> 80). Store the + # padded hidden_size so compute_flashinfer_cu_seqlens can produce + # element offsets that match the contiguous FP8 tensor strides. + self.fp8_vit_attn = envs.VLLM_MM_ENCODER_FP8_ATTN + if self.fp8_vit_attn and head_dim % 16 != 0: + padded_head_dim = ((head_dim + 15) // 16) * 16 + self.fp8_padded_hidden_size = self.num_heads * padded_head_dim + else: + self.fp8_padded_hidden_size = None + self.rotary_pos_emb = get_rope( head_size=head_dim, max_position=8192, @@ -927,6 +939,27 @@ def compute_flashinfer_cu_seqlens( rotary_pos_emb_sin: torch.Tensor | None = None, ) -> np.ndarray: batch_size = len(cu_seqlens) - 1 + + if self.fp8_padded_hidden_size is not None: + # FP8 path: after quantization Q/K/V are each independent + # contiguous tensors with stride H * padded_D per token. + # All sections (QK, V, O) use the same element stride. + # The wrapper overrides QK/V batch_offsets to use O offsets. + scale = self.fp8_padded_hidden_size // self.tp_size + cu_seqlens = cu_seqlens * scale + cu_seqlens_padded = self.add_padding_to_fi_seqlens( + cu_seqlens, batch_size, cu_seqlens[-1] + ) + return np.concatenate( + [cu_seqlens_padded, cu_seqlens_padded, cu_seqlens_padded] + ) + + # BF16 path: Q/K/V are non-contiguous views into shared buffers. 
+ # Element stride per token differs by tensor: + # After rotary: Q,K in [Q,K] buffer -> stride 2×H×D + # No rotary: Q,K in [Q,K,V] buffer -> stride 3×H×D + # V always in [Q,K,V] buffer -> stride 3×H×D + # O is contiguous -> stride H×D scale = self.hidden_size // self.tp_size cu_seqlens = cu_seqlens * scale if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: From 1ca34213a82ea8ec77759fc9f36c4f4c474e4766 Mon Sep 17 00:00:00 2001 From: Zhanda Date: Sat, 7 Feb 2026 20:34:27 -0800 Subject: [PATCH 6/9] Fix & linting Signed-off-by: Zhanda --- vllm/envs.py | 1 + .../layers/attention/mm_encoder_attention.py | 33 ++++++++------ .../layers/quantization/input_quant_fp8.py | 44 ++++++++++--------- 3 files changed, 44 insertions(+), 34 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 1c4e9511d406..0be0a9003a3a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1456,6 +1456,7 @@ def get_vllm_port() -> int | None: # average at BF16; 100 entries ≈ 0.9 GB. "VLLM_POS_EMBED_CACHE_SIZE": lambda: int( os.getenv("VLLM_POS_EMBED_CACHE_SIZE", "100") + ), # Controls whether to use FP8 attention for multimodal encoder (e.g., ViT) "VLLM_MM_ENCODER_FP8_ATTN": lambda: bool( int(os.getenv("VLLM_MM_ENCODER_FP8_ATTN", "0")) diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 1796cc894739..1f57e103d0cb 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -53,7 +53,7 @@ def _load_fp8_scales_file(path: str | None) -> dict[str, dict[str, float]]: if q is not None and k is not None and v is not None: scales[layer_name] = {"q": float(q), "k": float(k), "v": float(v)} - logger.info(f"Loaded FP8 attention scales from {path} ({len(scales)} layers)") + logger.info("Loaded FP8 attention scales from %s (%d layers)", path, len(scales)) return scales @@ -180,12 +180,13 @@ def _init_fp8_attention(self, layer_name: str) -> None: self.skip_scale_q = self.fp8_scales["q"] == 1.0 self.skip_scale_k = self.fp8_scales["k"] == 1.0 self.skip_scale_v = self.fp8_scales["v"] == 1.0 - + logger.debug( - f"FP8 attention enabled for {layer_name}: " - f"q={self.fp8_scales['q']:.4f}, " - f"k={self.fp8_scales['k']:.4f}, " - f"v={self.fp8_scales['v']:.4f}" + "FP8 attention enabled for %s: q=%.4f, k=%.4f, v=%.4f", + layer_name, + self.fp8_scales["q"], + self.fp8_scales["k"], + self.fp8_scales["v"], ) @classmethod @@ -292,7 +293,7 @@ def _quantize_to_fp8( skip_scale: bool = False, ) -> torch.Tensor: """Quantize a 3D (S, H, D) tensor to FP8. - + Uses QuantFP8 CustomOp when head_dim is aligned to 16; otherwise falls back to a stride-aware Triton kernel that pads head_dim to a multiple of 16 — no extra copy even for non-contiguous inputs. 
@@ -307,9 +308,7 @@ def _quantize_to_fp8( # QuantFP8 expects 2D input: (total_tokens, num_heads * head_dim) tensor_2d = tensor.reshape(orig_shape[0], -1) - fp8_tensor, _ = self.fp8_quant.forward_cuda( - tensor_2d, scale=scale - ) + fp8_tensor, _ = self.fp8_quant.forward_cuda(tensor_2d, scale=scale) return fp8_tensor.reshape(orig_shape) # Fall back to Triton kernel for padding head_dim to a multiple of 16 @@ -327,9 +326,15 @@ def _forward_flashinfer( ) -> torch.Tensor: if self.fp8_enabled: assert self.fp8_quant is not None and self.fp8_scales is not None - query = self._quantize_to_fp8(query, self._fp8_q_scale, skip_scale=self.skip_scale_q) - key = self._quantize_to_fp8(key, self._fp8_k_scale, skip_scale=self.skip_scale_k) - value = self._quantize_to_fp8(value, self._fp8_v_scale, skip_scale=self.skip_scale_v) + query = self._quantize_to_fp8( + query, self._fp8_q_scale, skip_scale=self.skip_scale_q + ) + key = self._quantize_to_fp8( + key, self._fp8_k_scale, skip_scale=self.skip_scale_k + ) + value = self._quantize_to_fp8( + value, self._fp8_v_scale, skip_scale=self.skip_scale_v + ) output = vit_flashinfer_wrapper( q=query, @@ -348,7 +353,7 @@ def _forward_flashinfer( # Un-pad head dimension if it was padded during FP8 quantization if self.fp8_enabled and output.shape[-1] != self.head_size: - output = output[..., :self.head_size] + output = output[..., : self.head_size] output = output.contiguous() return output diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index dadc8833d920..97aa7cb2dca5 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os + import torch import torch.nn.functional as F @@ -33,7 +34,7 @@ def _quantize_pad_fp8_kernel( stride_yh, # output stride along head dim stride_yd, # output stride along head_dim dim (usually 1) num_heads, - n_rows, # total rows = S * H + n_rows, # total rows = S * H n_cols, n_cols_padded, fp8_min, @@ -56,10 +57,12 @@ def _quantize_pad_fp8_kernel( s = offs_m // num_heads h = offs_m % num_heads - x_ptrs = (x_ptr - + s[:, None] * stride_xs - + h[:, None] * stride_xh - + offs_n[None, :] * stride_xd) + x_ptrs = ( + x_ptr + + s[:, None] * stride_xs + + h[:, None] * stride_xh + + offs_n[None, :] * stride_xd + ) x = tl.load(x_ptrs, mask=mask_in, other=0.0).to(tl.float32) if SKIP_SCALE: x_q = x @@ -69,10 +72,12 @@ def _quantize_pad_fp8_kernel( x_q = tl.where(mask_in, x_q, 0.0) x_q = tl.clamp(x_q, fp8_min, fp8_max).to(y_ptr.dtype.element_ty) - y_ptrs = (y_ptr - + s[:, None] * stride_ys - + h[:, None] * stride_yh - + offs_n[None, :] * stride_yd) + y_ptrs = ( + y_ptr + + s[:, None] * stride_ys + + h[:, None] * stride_yh + + offs_n[None, :] * stride_yd + ) tl.store(y_ptrs, x_q, mask=mask_out) @@ -107,7 +112,7 @@ def quantize_fp8_pad_head_dim_triton( num_warps: int | None = None, block_m: int | None = None, ) -> torch.Tensor: - """Quantize a 4D (B, S, H, D) or 3D (S, H, D) tensor to FP8 while padding D to a multiple of 16. + """Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16. Reads directly from the input using its 3D strides, so non-contiguous views (e.g. Q/K/V slices from an interleaved QKV buffer) are handled @@ -115,16 +120,12 @@ def quantize_fp8_pad_head_dim_triton( with shape (S, H, padded_D). 
""" if not HAS_TRITON: - raise RuntimeError( - "Triton is required to quantize with head_dim padding." - ) + raise RuntimeError("Triton is required to quantize with head_dim padding.") original_shape = tensor.shape if tensor.dim() == 4: tensor = tensor.view(-1, tensor.shape[-2], tensor.shape[-1]) - assert tensor.dim() == 3, ( - f"Expected 3D input (S, H, D), got {tensor.dim()}D" - ) + assert tensor.dim() == 3, f"Expected 3D input (S, H, D), got {tensor.dim()}D" S, H, D = tensor.shape padded_head_dim = (D + 15) // 16 * 16 out_dtype = current_platform.fp8_dtype() @@ -141,15 +142,18 @@ def quantize_fp8_pad_head_dim_triton( if block_n is None or num_warps is None or block_m is None: block_n, num_warps, block_m = _get_fp8_pad_quant_config(padded_head_dim) - grid = (triton.cdiv(n_rows, block_m), - triton.cdiv(padded_head_dim, block_n)) + grid = (triton.cdiv(n_rows, block_m), triton.cdiv(padded_head_dim, block_n)) _quantize_pad_fp8_kernel[grid]( tensor, output, scale_1d, - tensor.stride(0), tensor.stride(1), tensor.stride(2), - output.stride(0), output.stride(1), output.stride(2), + tensor.stride(0), + tensor.stride(1), + tensor.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), H, n_rows, D, From f4bf0ac2a5133bcc18667c81bb2efefaabda0786 Mon Sep 17 00:00:00 2001 From: Zhanda Date: Sat, 7 Feb 2026 20:51:33 -0800 Subject: [PATCH 7/9] Temporarily change the flashinfer branch Signed-off-by: Zhanda --- docker/Dockerfile | 2 +- docker/versions.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9a866c786d90..3fad6a33f0c6 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -495,7 +495,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Install FlashInfer from CentML fork # https://github.com/CentML/flashinfer/tree/mlperf-inf-mm-q3vl-v6.0 ARG FLASHINFER_REPO=https://github.com/CentML/flashinfer.git -ARG FLASHINFER_BRANCH=mlperf-inf-mm-q3vl-v6.0-rc1 +ARG FLASHINFER_BRANCH=zhanda-fp8-vit-attn ARG FLASHINFER_CUBIN_VERSION=0.5.3 ARG FLASHINFER_JIT_CACHE_VERSION=0.5.3 RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docker/versions.json b/docker/versions.json index 71dbfad57846..3113db4cce53 100644 --- a/docker/versions.json +++ b/docker/versions.json @@ -71,7 +71,7 @@ "default": "https://github.com/CentML/flashinfer.git" }, "FLASHINFER_BRANCH": { - "default": "mlperf-inf-mm-q3vl-v6.0-rc1" + "default": "zhanda-fp8-vit-attn" }, "FLASHINFER_CUBIN_VERSION": { "default": "0.5.3" From 3abd55bbf90bdfe57c5449e2512347b4e7aa133e Mon Sep 17 00:00:00 2001 From: root Date: Sat, 7 Feb 2026 22:24:07 -0800 Subject: [PATCH 8/9] Add guard for fp8 attn backend & update logging Signed-off-by: <> --- .../layers/attention/mm_encoder_attention.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 1f57e103d0cb..c35cf42090a7 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -53,7 +53,9 @@ def _load_fp8_scales_file(path: str | None) -> dict[str, dict[str, float]]: if q is not None and k is not None and v is not None: scales[layer_name] = {"q": float(q), "k": float(k), "v": float(v)} - logger.info("Loaded FP8 attention scales from %s (%d layers)", path, len(scales)) + logger.info_once( + "Loaded FP8 attention scales from %s (%d layers)", path, len(scales) + ) return scales @@ 
-127,12 +129,18 @@ def __init__( logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") - # FP8 attention support + # FP8 attention support (currently only FlashInfer cuDNN backend) self.fp8_enabled = False self.fp8_scales: dict[str, float] | None = None self.fp8_quant: QuantFP8 | None = None if envs.VLLM_MM_ENCODER_FP8_ATTN: + if self.attn_backend != AttentionBackendEnum.FLASHINFER: + raise ValueError( + "VLLM_MM_ENCODER_FP8_ATTN requires the FlashInfer " + "cuDNN backend (FLASHINFER), but the current ViT " + f"attention backend is {self.attn_backend}." + ) self._init_fp8_attention(prefix) def _init_fp8_attention(self, layer_name: str) -> None: @@ -183,7 +191,7 @@ def _init_fp8_attention(self, layer_name: str) -> None: logger.debug( "FP8 attention enabled for %s: q=%.4f, k=%.4f, v=%.4f", - layer_name, + layer_name if layer_name else "MMEncoderAttention", self.fp8_scales["q"], self.fp8_scales["k"], self.fp8_scales["v"], From ca80913184f1934ac1c34dc906e740145a251e32 Mon Sep 17 00:00:00 2001 From: Zhanda Date: Sat, 7 Feb 2026 23:03:51 -0800 Subject: [PATCH 9/9] Update the flashinfer branch Signed-off-by: Zhanda --- docker/Dockerfile | 2 +- docker/versions.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 3fad6a33f0c6..9064ea51632b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -495,7 +495,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Install FlashInfer from CentML fork # https://github.com/CentML/flashinfer/tree/mlperf-inf-mm-q3vl-v6.0 ARG FLASHINFER_REPO=https://github.com/CentML/flashinfer.git -ARG FLASHINFER_BRANCH=zhanda-fp8-vit-attn +ARG FLASHINFER_BRANCH=mlperf-inf-mm-q3vl-v6.0 ARG FLASHINFER_CUBIN_VERSION=0.5.3 ARG FLASHINFER_JIT_CACHE_VERSION=0.5.3 RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docker/versions.json b/docker/versions.json index 3113db4cce53..3bb174eea948 100644 --- a/docker/versions.json +++ b/docker/versions.json @@ -71,7 +71,7 @@ "default": "https://github.com/CentML/flashinfer.git" }, "FLASHINFER_BRANCH": { - "default": "zhanda-fp8-vit-attn" + "default": "mlperf-inf-mm-q3vl-v6.0" }, "FLASHINFER_CUBIN_VERSION": { "default": "0.5.3"
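
Usage sketch for the knobs this series introduces. This is a hedged example, not part of any patch: the model id and the layer keys in the scales file are placeholders (real keys must match the `prefix` of each `MMEncoderAttention` layer), and `VLLM_MM_ENCODER_FP8_ATTN` additionally requires the FlashInfer cuDNN ViT attention backend per the guard added in patch 8.

import json
import os

# Hypothetical per-layer scales in the flat format read by _load_fp8_scales_file:
# every entry must provide all of "q", "k", "v" (or "q_scale"/"k_scale"/"v_scale").
# Layer names below are placeholders; real keys must match each layer's prefix.
scales = {
    "visual.blocks.0.attn": {"q": 0.92, "k": 0.88, "v": 1.31},
    "visual.blocks.1.attn": {"q": 0.95, "k": 0.90, "v": 1.27},
}
with open("/tmp/vit_fp8_scales.json", "w", encoding="utf-8") as f:
    json.dump(scales, f)

# Enable the new knobs before the engine is constructed; vllm/envs.py reads
# them lazily from the environment.
os.environ["VLLM_USE_TRITON_POS_EMBED"] = "1"    # fused bilinear pos-embed kernel
os.environ["VLLM_POS_EMBED_CACHE_SIZE"] = "100"  # pre-warm up to 100 grids (0 disables)
os.environ["VLLM_MM_ENCODER_FP8_ATTN"] = "1"     # FP8 ViT attention (FlashInfer/cuDNN only)
os.environ["VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH"] = "/tmp/vit_fp8_scales.json"

from vllm import LLM

llm = LLM(model="Qwen/Qwen3-VL-8B-Instruct")  # placeholder model id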