From f3420643a0bd43f193629823f6613e265b3cec1e Mon Sep 17 00:00:00 2001 From: Yu Bo Gao Date: Thu, 5 Feb 2026 16:35:44 -0500 Subject: [PATCH 1/7] wip --- .../layers/attention/mm_encoder_attention.py | 312 ++++++++++++++++++ vllm/model_executor/models/qwen3_vl.py | 1 + vllm/platforms/cuda.py | 1 + vllm/v1/attention/backends/registry.py | 1 + vllm/v1/worker/gpu_model_runner.py | 94 +++++- 5 files changed, 396 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 28d83776ebe5..338f31af0751 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -19,6 +19,22 @@ logger = init_logger(__name__) +# Batch buckets for cuDNN graph caching - graphs are cached per bucket size +# This avoids creating a new graph for each unique batch size at runtime +TE_BATCH_BUCKETS = [1, 2, 4, 8, 16, 32] + +# Fixed max_seqlen to avoid cuDNN recompilation when sequence lengths vary +TE_FIXED_MAX_SEQLEN = 128 * 1024 + +try: + from transformer_engine.common.recipe import DelayedScaling + from transformer_engine.pytorch import DotProductAttention, fp8_autocast +except ImportError: + DotProductAttention = None + fp8_autocast = None + DelayedScaling = None + logger.warning("TransformerEngine is not installed.") + # --8<-- [start:mm_encoder_attn] @CustomOp.register("mm_encoder_attn") @@ -88,6 +104,24 @@ def __init__( get_flash_attn_version() if self.is_flash_attn_backend else None ) + # Initialize Transformer Engine FP8 attention if backend is TE + # for each batch size + self.te_attn_op = None + self.te_fp8_recipe = None + self.is_te_fp8_backend = ( + self.attn_backend == AttentionBackendEnum.TE_FP8 + if hasattr(AttentionBackendEnum, 'TE_FP8') + else False + ) + + if self.is_te_fp8_backend: + if DotProductAttention is None: + raise ImportError( + "TransformerEngine is not installed but TE_FP8 backend was selected" + ) + self.te_fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=True) + logger.info_once("Initialized FP8 Transformer Engine for MMEncoderAttention.") + logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") @classmethod @@ -118,6 +152,49 @@ def maybe_reshape_qkv_to_4d( return query, key, value + def _pad_cu_seqlens_to_bucket( + self, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + """Pad cu_seqlens to the next batch bucket size to avoid cuDNN recompilation. + + cu_seqlens has length (batch_size + 1), so we pad to (bucket_size + 1). + Padding value is the last value (total tokens), so padded sequences have 0 length. 
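+
+        Example (illustrative values): with TE_BATCH_BUCKETS = [1, 2, 4, ...],
+        cu_seqlens = [0, 100, 250, 300] (batch of 3) is padded to the bucket
+        of 4 as [0, 100, 250, 300, 300]; the appended sequence has length 0.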
+ """ + batch_size = cu_seqlens.size(0) - 1 + # Find the next bucket size >= batch_size + bucket_size = next( + (b for b in TE_BATCH_BUCKETS if b >= batch_size), TE_BATCH_BUCKETS[-1] + ) + if bucket_size == batch_size: + return cu_seqlens + + # Pad cu_seqlens: add entries with the same value as the last entry + # This means padded sequences have 0 length + padding_size = bucket_size - batch_size + padding = cu_seqlens[-1:].expand(padding_size) + return torch.cat([cu_seqlens, padding]) + + def _lazy_init_te_attn( + self, + num_attention_heads: int, + kv_channels: int, + num_gqa_groups: int | None, + attn_mask_type: str, + softmax_scale: float | None, + qkv_format: str = "bshd", + ) -> None: + """Lazily initialize Transformer Engine attention operator.""" + if self.te_attn_op is None: + self.te_attn_op = DotProductAttention( + num_attention_heads, + kv_channels, + num_gqa_groups=num_gqa_groups, + attn_mask_type=attn_mask_type, + softmax_scale=softmax_scale, + qkv_format=qkv_format, + ) + def _forward_sdpa( self, query: torch.Tensor, @@ -187,6 +264,239 @@ def _forward_fa( output = output.reshape(bsz, q_len, -1) return output + def _forward_te_fp8( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass using Transformer Engine FP8 attention with BSHD format. + + Input shape: + (batch_size x seq_len x hidden_size) where hidden_size = num_heads * head_size + or (batch_size x seq_len x num_heads x head_size) + + Uses BSHD format: (batch, seq, heads, dim) + + Note: TE natively supports GQA, so we don't expand KV heads like other backends. + Note: Head dimension is padded to multiple of 16 for optimal performance. + """ + bsz, q_len = query.size()[:2] + kv_len = key.size(1) + is_3d_input = query.dim() == 3 + + # Transform to BSHD format: (batch, seq, heads, dim) + if is_3d_input: + # Input is (batch, seq, hidden_size) - reshape to (batch, seq, heads, dim) + query = query.view(bsz, q_len, self.num_heads, self.head_size) + key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) + # else: already in (batch, seq, heads, dim) format + + # Pad head dimension to multiple of 16 for optimal performance + original_head_size = self.head_size + padded_head_size = ((self.head_size + 15) // 16) * 16 + needs_padding = padded_head_size != original_head_size + + if needs_padding: + pad_size = padded_head_size - original_head_size + query = torch.nn.functional.pad(query, (0, pad_size)) + key = torch.nn.functional.pad(key, (0, pad_size)) + value = torch.nn.functional.pad(value, (0, pad_size)) + + # Determine if we have variable sequence lengths + # cu_seqlens indicates variable lengths when provided + attention_mask = None + if cu_seqlens is not None: + # Variable sequence lengths - need padding mask + attn_mask_type = "padding" + else: + # Uniform sequence lengths - no mask needed + attn_mask_type = "no_mask" + + # Determine GQA groups - TE will handle the GQA logic internally + num_gqa_groups = self.num_kv_heads if self.num_kv_heads != self.num_heads else None + + # Lazy initialization of TE attention operator + self._lazy_init_te_attn( + num_attention_heads=self.num_heads, + kv_channels=padded_head_size, + num_gqa_groups=num_gqa_groups, + attn_mask_type=attn_mask_type, + softmax_scale=self.scale, + qkv_format="bshd", + ) + + max_seqlen = TE_FIXED_MAX_SEQLEN + + # NVTX annotation with all parameters 
for lazy_init and te_attn_op + nvtx_msg = ( + f"TE_FP8_BSHD: " + f"Q={tuple(query.shape)}, K={tuple(key.shape)}, V={tuple(value.shape)}, " + f"num_heads={self.num_heads}, kv_channels={padded_head_size}, " + f"num_gqa_groups={num_gqa_groups}, attn_mask_type={attn_mask_type}, " + f"softmax_scale={self.scale}, qkv_format=bshd, " + f"cu_seqlens={cu_seqlens.shape if cu_seqlens is not None else None}, " + f"max_seqlen={max_seqlen}" + ) + with torch.cuda.nvtx.range(nvtx_msg): + with fp8_autocast(enabled=True, fp8_recipe=self.te_fp8_recipe): + output = self.te_attn_op( + query, + key, + value, + attention_mask=None, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + ) + + # Output is (batch, seq, heads, padded_dim) or (batch, seq, heads*padded_dim) + # Handle both cases + if output.dim() == 3: + # Output is (batch, seq, heads*dim) flattened + output = output.reshape(bsz, q_len, self.num_heads, padded_head_size) + + # Remove head padding if needed + if needs_padding: + output = output[..., :original_head_size] + + # Reshape back to original format + if is_3d_input: + # Back to (batch, seq, hidden_size) where hidden_size = H * D + output = output.reshape(bsz, q_len, self.num_heads * original_head_size) + else: + # Already in (batch, seq, num_heads, head_size) format + pass + + + return output + + def _forward_te_fp8_thd( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass using Transformer Engine FP8 attention with THD format. + + Input shape: + (batch_size x seq_len x hidden_size) where hidden_size = num_heads * head_size + or (batch_size x seq_len x num_heads x head_size) + + Uses THD format: (T, H, D) where T = batch*seq, H = num_heads, D = head_size + + Note: TE natively supports GQA, so we don't expand KV heads like other backends. + Note: Head dimension is padded to multiple of 16 for optimal performance. 
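+        Note: cu_seqlens is padded to the next TE_BATCH_BUCKETS size and
+        max_seqlen is fixed to TE_FIXED_MAX_SEQLEN, so cuDNN caches one
+        graph per bucket instead of recompiling for every batch shape.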
+ """ + bsz, q_len = query.size()[:2] + kv_len = key.size(1) + is_3d_input = query.dim() == 3 + + # For THD format, we need cu_seqlens + # Generate if not provided (assumes uniform sequence lengths) + if cu_seqlens is None: + cu_seqlens = torch.arange( + 0, (bsz + 1) * q_len, q_len, + dtype=torch.int32, device=query.device + ) + + # Pad cu_seqlens to the next batch bucket size to avoid cuDNN recompilation + cu_seqlens = self._pad_cu_seqlens_to_bucket(cu_seqlens) + + # Use fixed max_seqlen to avoid cuDNN recompilation when sequence lengths vary + max_seqlen = TE_FIXED_MAX_SEQLEN + + # Transform to THD format: (T, H, D) where T = batch*seq + total_tokens_q = bsz * q_len + total_tokens_kv = bsz * kv_len + + if is_3d_input: + # Input is (batch, seq, hidden_size) - need to split hidden into (H, D) + query = query.view(total_tokens_q, self.num_heads, self.head_size) + key = key.view(total_tokens_kv, self.num_kv_heads, self.head_size) + value = value.view(total_tokens_kv, self.num_kv_heads, self.head_size) + else: + # Input is (batch, seq, num_heads, head_size) - just flatten batch and seq + query = query.view(total_tokens_q, self.num_heads, self.head_size) + key = key.view(total_tokens_kv, self.num_kv_heads, self.head_size) + value = value.view(total_tokens_kv, self.num_kv_heads, self.head_size) + + # Pad head dimension to multiple of 16 for optimal performance + original_head_size = self.head_size + padded_head_size = ((self.head_size + 15) // 16) * 16 + needs_padding = padded_head_size != original_head_size + + if needs_padding: + pad_size = padded_head_size - original_head_size + query = torch.nn.functional.pad(query, (0, pad_size)) + key = torch.nn.functional.pad(key, (0, pad_size)) + value = torch.nn.functional.pad(value, (0, pad_size)) + + # For THD format, attn_mask_type must be "padding" or "padding_causal" + attn_mask_type = "padding" + + # Determine GQA groups - TE will handle the GQA logic internally + num_gqa_groups = self.num_kv_heads if self.num_kv_heads != self.num_heads else None + + # Lazy initialization of TE attention operator + + self._lazy_init_te_attn( + num_attention_heads=self.num_heads, + kv_channels=padded_head_size, + num_gqa_groups=num_gqa_groups, + attn_mask_type=attn_mask_type, + softmax_scale=self.scale, + qkv_format="thd", + ) + + # NVTX annotation with all parameters for lazy_init and te_attn_op + nvtx_msg = ( + f"TE_FP8_THD: " + f"Q={tuple(query.shape)}, K={tuple(key.shape)}, V={tuple(value.shape)}, " + f"batch_size={bsz}, num_heads={self.num_heads}, kv_channels={padded_head_size}, " + f"num_gqa_groups={num_gqa_groups}, attn_mask_type={attn_mask_type}, " + f"softmax_scale={self.scale}, qkv_format=thd, " + f"cu_seqlens={cu_seqlens.shape}, max_seqlen={max_seqlen}" + ) + with torch.cuda.nvtx.range(nvtx_msg): + with fp8_autocast(enabled=True, fp8_recipe=self.te_fp8_recipe): + output = self.te_attn_op( + query, + key, + value, + attention_mask=None, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + ) + + # TE returns (T, H*D_padded) flattened, need to reshape to (T, H, D_padded) first + if output.dim() == 2: + # Output is (T, H*D_padded) - reshape to (T, H, D_padded) + output = output.reshape(total_tokens_q, self.num_heads, padded_head_size) + + # Remove head padding if needed: (T, H, D_padded) -> (T, H, D_original) + if needs_padding: + output = output[..., :original_head_size] + + # Reshape back to original format + if is_3d_input: + # Back to (batch, seq, hidden_size) where hidden_size = H * D + output 
= output.reshape(bsz, q_len, self.num_heads * original_head_size) + else: + # Back to (batch, seq, num_heads, head_size) + output = output.reshape(bsz, q_len, self.num_heads, original_head_size) + + + return output + def _forward_flashinfer( self, query: torch.Tensor, @@ -274,6 +584,8 @@ def forward_cuda( ) elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: return self._forward_sdpa(query, key, value, cu_seqlens) + elif self.is_te_fp8_backend: + return self._forward_te_fp8(query, key, value, cu_seqlens, max_seqlen) else: raise ValueError( f"Unsupported multi-modal encoder attention backend for CUDA: " diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 09972ca7fb4c..db077236cd84 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -448,6 +448,7 @@ def __init__( AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TE_FP8, }: raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now." diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 020e948a4a40..f55a98cce393 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -364,6 +364,7 @@ def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: AttentionBackendEnum.FLASH_ATTN_CUTE, AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TE_FP8, ] @classmethod diff --git a/vllm/v1/attention/backends/registry.py b/vllm/v1/attention/backends/registry.py index 6bdf9691b402..c681d97efd49 100644 --- a/vllm/v1/attention/backends/registry.py +++ b/vllm/v1/attention/backends/registry.py @@ -62,6 +62,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend" ) TORCH_SDPA = "" # this tag is only used for ViT + TE_FP8 = "transformer_engine.pytorch.DotProductAttention" # this tag is only used for MMEncoderAttention FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" FLASHINFER_MLA = ( "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a0284184891f..15db21868f11 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2651,19 +2651,11 @@ def _execute_mm_encoder( if piecewise_result is not None: curr_group_outputs = piecewise_result else: - # Fall back to non-padded execution. - # Run the encoder. - # `curr_group_outputs` is either of the following: - # 1. A tensor of shape - # (num_items, feature_size, hidden_size) - # in case feature_size is fixed across all - # multimodal items. - # 2. A list or tuple (length: num_items) of tensors, - # each of shape (feature_size, hidden_size) in - # case the feature size is dynamic depending on - # the input multimodal items. - curr_group_outputs = model.embed_multimodal( - **mm_kwargs_group + # Fall back to eager execution, one image at a time. + # This ensures consistent behavior and reduces peak + # memory usage compared to batch processing. 
+ curr_group_outputs = self._execute_encoder_one_by_one_eager( + model, mm_kwargs_group, modality, num_items ) sanity_check_mm_encoder_outputs( @@ -2825,6 +2817,82 @@ def _execute_with_encoder_cudagraph( ) return None + def _execute_encoder_one_by_one_eager( + self, + model: "SupportsMultiModal", + mm_kwargs_group: dict, + modality: str, + num_items: int, + ) -> list[torch.Tensor]: + """ + Execute encoder in eager mode, processing one image at a time. + + This ensures consistent behavior and reduces peak memory usage + compared to batch processing all images together. + + Args: + model: The multimodal model + mm_kwargs_group: Batched multimodal kwargs + modality: The modality type ("image" or "video") + num_items: Number of items in the batch + + Returns: + List of encoder outputs, one per image + """ + # For single item, just process directly + if num_items == 1: + return list(model.embed_multimodal(**mm_kwargs_group)) + + # Only process image/video modalities one-by-one + if modality not in ("image", "video"): + return list(model.embed_multimodal(**mm_kwargs_group)) + + # Extract batched data + if modality == "image": + batched_pixel_values = mm_kwargs_group.get("pixel_values") + grid_thw_list = mm_kwargs_group.get("image_grid_thw") + grid_key = "image_grid_thw" + pixel_key = "pixel_values" + else: # video + batched_pixel_values = mm_kwargs_group.get("pixel_values_videos") + grid_thw_list = mm_kwargs_group.get("video_grid_thw") + grid_key = "video_grid_thw" + pixel_key = "pixel_values_videos" + + # If we can't extract the data, fall back to batch processing + if batched_pixel_values is None or grid_thw_list is None: + return list(model.embed_multimodal(**mm_kwargs_group)) + + # Convert grid_thw to list if tensor + if isinstance(grid_thw_list, torch.Tensor): + grid_thw_list = grid_thw_list.tolist() + + # Process each image one at a time + outputs: list[torch.Tensor] = [] + patch_offset = 0 + + for grid_thw in grid_thw_list: + t, h, w = grid_thw + num_patches = t * h * w + + # Slice pixel_values for this image + single_pixel_values = batched_pixel_values[ + patch_offset : patch_offset + num_patches + ] + patch_offset += num_patches + + # Build single-image kwargs + single_mm_inputs = { + pixel_key: single_pixel_values, + grid_key: torch.tensor([grid_thw], dtype=torch.int64), + } + + # Process this single image + single_output = model.embed_multimodal(**single_mm_inputs) + outputs.append(single_output[0]) + + return outputs + def _execute_grouped_batched_encoder( self, model: "SupportsMultiModal", From 0fdaa97f1f3f3104fe40a0576b5aa71fb6bc1a75 Mon Sep 17 00:00:00 2001 From: Yu Bo Gao Date: Thu, 5 Feb 2026 18:07:37 -0500 Subject: [PATCH 2/7] pad to avoid recompilation during eager --- .../layers/attention/mm_encoder_attention.py | 45 ++++++++++++++++--- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 338f31af0751..88c43786d617 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -23,6 +23,10 @@ # This avoids creating a new graph for each unique batch size at runtime TE_BATCH_BUCKETS = [1, 2, 4, 8, 16, 32] +# Seqlen buckets for BSHD format - Q/K/V tensors are padded to these sizes +# so cuDNN sees a fixed set of tensor shapes and avoids recompilation +TE_SEQLEN_BUCKETS = [2048, 4096, 6144, 10240, 16384, 25600, 35840, 65536] + # Fixed max_seqlen to avoid cuDNN 
recompilation when sequence lengths vary TE_FIXED_MAX_SEQLEN = 128 * 1024 @@ -152,6 +156,17 @@ def maybe_reshape_qkv_to_4d( return query, key, value + @staticmethod + def _find_seqlen_bucket(seqlen: int) -> int | None: + """Find the smallest seqlen bucket that can fit the given seqlen. + + Returns None if seqlen exceeds the largest bucket. + """ + for bucket in TE_SEQLEN_BUCKETS: + if bucket >= seqlen: + return bucket + return None + def _pad_cu_seqlens_to_bucket( self, cu_seqlens: torch.Tensor, @@ -305,7 +320,18 @@ def _forward_te_fp8( query = torch.nn.functional.pad(query, (0, pad_size)) key = torch.nn.functional.pad(key, (0, pad_size)) value = torch.nn.functional.pad(value, (0, pad_size)) - + + # Pad Q/K/V seqlen dimension to a bucket size to avoid cuDNN + # recompilation when different images have different resolutions. + # cu_seqlens already tracks the real sequence boundaries. + bucket_seqlen = self._find_seqlen_bucket(q_len) + if bucket_seqlen is not None and bucket_seqlen > q_len: + seq_pad = bucket_seqlen - q_len + # Pad S dimension: shape is (B, S, H, D), so pad dim=1 + query = torch.nn.functional.pad(query, (0, 0, 0, 0, 0, seq_pad)) + key = torch.nn.functional.pad(key, (0, 0, 0, 0, 0, seq_pad)) + value = torch.nn.functional.pad(value, (0, 0, 0, 0, 0, seq_pad)) + # Determine if we have variable sequence lengths # cu_seqlens indicates variable lengths when provided attention_mask = None @@ -354,12 +380,18 @@ def _forward_te_fp8( max_seqlen_kv=max_seqlen, ) - # Output is (batch, seq, heads, padded_dim) or (batch, seq, heads*padded_dim) - # Handle both cases + # Output is (batch, padded_seq, heads, padded_dim) or + # (batch, padded_seq, heads*padded_dim). + # Handle both cases. if output.dim() == 3: - # Output is (batch, seq, heads*dim) flattened - output = output.reshape(bsz, q_len, self.num_heads, padded_head_size) - + # Output is (batch, padded_seq, heads*dim) flattened + output = output.reshape( + bsz, output.size(1), self.num_heads, padded_head_size + ) + + # Slice back to original seqlen (remove S-dimension padding) + output = output[:, :q_len, :, :] + # Remove head padding if needed if needs_padding: output = output[..., :original_head_size] @@ -372,7 +404,6 @@ def _forward_te_fp8( # Already in (batch, seq, num_heads, head_size) format pass - return output def _forward_te_fp8_thd( From dd1abce58c2a61a3f6f05f6351c5b2dd0d3ca072 Mon Sep 17 00:00:00 2001 From: Yu Bo Gao Date: Mon, 9 Feb 2026 12:23:01 -0800 Subject: [PATCH 3/7] cleanup --- .../layers/attention/mm_encoder_attention.py | 169 ++---------------- 1 file changed, 19 insertions(+), 150 deletions(-) diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 88c43786d617..37dcc7d24bb2 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -19,10 +19,6 @@ logger = init_logger(__name__) -# Batch buckets for cuDNN graph caching - graphs are cached per bucket size -# This avoids creating a new graph for each unique batch size at runtime -TE_BATCH_BUCKETS = [1, 2, 4, 8, 16, 32] - # Seqlen buckets for BSHD format - Q/K/V tensors are padded to these sizes # so cuDNN sees a fixed set of tensor shapes and avoids recompilation TE_SEQLEN_BUCKETS = [2048, 4096, 6144, 10240, 16384, 25600, 35840, 65536] @@ -167,29 +163,6 @@ def _find_seqlen_bucket(seqlen: int) -> int | None: return bucket return None - def _pad_cu_seqlens_to_bucket( - self, - cu_seqlens: 
torch.Tensor, - ) -> torch.Tensor: - """Pad cu_seqlens to the next batch bucket size to avoid cuDNN recompilation. - - cu_seqlens has length (batch_size + 1), so we pad to (bucket_size + 1). - Padding value is the last value (total tokens), so padded sequences have 0 length. - """ - batch_size = cu_seqlens.size(0) - 1 - # Find the next bucket size >= batch_size - bucket_size = next( - (b for b in TE_BATCH_BUCKETS if b >= batch_size), TE_BATCH_BUCKETS[-1] - ) - if bucket_size == batch_size: - return cu_seqlens - - # Pad cu_seqlens: add entries with the same value as the last entry - # This means padded sequences have 0 length - padding_size = bucket_size - batch_size - padding = cu_seqlens[-1:].expand(padding_size) - return torch.cat([cu_seqlens, padding]) - def _lazy_init_te_attn( self, num_attention_heads: int, @@ -289,15 +262,33 @@ def _forward_te_fp8( ) -> torch.Tensor: """Forward pass using Transformer Engine FP8 attention with BSHD format. + IMPORTANT: This function processes ONE sample at a time. When cu_seqlens + is provided, it must have length 2 (i.e., [0, seq_len] for a single + sequence). + + This batch-1 restriction is a limitation of Transformer Engine, not + cuDNN. TE does not support THD format for FP8 attention, and converting + the upstream THD tensor into a proper multi-batch BSHD tensor would be + too expensive. Instead, we manually reinterpret a batch-1 THD tensor as + BSHD with B=1 and S=T (the total token count), then call the BSHD + kernel. This is semantically consistent because a single sequence in + THD is equivalent to B=1 BSHD. + Input shape: (batch_size x seq_len x hidden_size) where hidden_size = num_heads * head_size or (batch_size x seq_len x num_heads x head_size) Uses BSHD format: (batch, seq, heads, dim) - Note: TE natively supports GQA, so we don't expand KV heads like other backends. Note: Head dimension is padded to multiple of 16 for optimal performance. """ + # Validate single-sample constraint + if cu_seqlens is not None: + assert len(cu_seqlens) == 2, ( + f"_forward_te_fp8 (BSHD format) requires exactly one sample at a time. " + f"cu_seqlens must have length 2 (got {len(cu_seqlens)}). " + ) + bsz, q_len = query.size()[:2] kv_len = key.size(1) is_3d_input = query.dim() == 3 @@ -406,128 +397,6 @@ def _forward_te_fp8( return output - def _forward_te_fp8_thd( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: torch.Tensor | None = None, - ) -> torch.Tensor: - """Forward pass using Transformer Engine FP8 attention with THD format. - - Input shape: - (batch_size x seq_len x hidden_size) where hidden_size = num_heads * head_size - or (batch_size x seq_len x num_heads x head_size) - - Uses THD format: (T, H, D) where T = batch*seq, H = num_heads, D = head_size - - Note: TE natively supports GQA, so we don't expand KV heads like other backends. - Note: Head dimension is padded to multiple of 16 for optimal performance. 
- """ - bsz, q_len = query.size()[:2] - kv_len = key.size(1) - is_3d_input = query.dim() == 3 - - # For THD format, we need cu_seqlens - # Generate if not provided (assumes uniform sequence lengths) - if cu_seqlens is None: - cu_seqlens = torch.arange( - 0, (bsz + 1) * q_len, q_len, - dtype=torch.int32, device=query.device - ) - - # Pad cu_seqlens to the next batch bucket size to avoid cuDNN recompilation - cu_seqlens = self._pad_cu_seqlens_to_bucket(cu_seqlens) - - # Use fixed max_seqlen to avoid cuDNN recompilation when sequence lengths vary - max_seqlen = TE_FIXED_MAX_SEQLEN - - # Transform to THD format: (T, H, D) where T = batch*seq - total_tokens_q = bsz * q_len - total_tokens_kv = bsz * kv_len - - if is_3d_input: - # Input is (batch, seq, hidden_size) - need to split hidden into (H, D) - query = query.view(total_tokens_q, self.num_heads, self.head_size) - key = key.view(total_tokens_kv, self.num_kv_heads, self.head_size) - value = value.view(total_tokens_kv, self.num_kv_heads, self.head_size) - else: - # Input is (batch, seq, num_heads, head_size) - just flatten batch and seq - query = query.view(total_tokens_q, self.num_heads, self.head_size) - key = key.view(total_tokens_kv, self.num_kv_heads, self.head_size) - value = value.view(total_tokens_kv, self.num_kv_heads, self.head_size) - - # Pad head dimension to multiple of 16 for optimal performance - original_head_size = self.head_size - padded_head_size = ((self.head_size + 15) // 16) * 16 - needs_padding = padded_head_size != original_head_size - - if needs_padding: - pad_size = padded_head_size - original_head_size - query = torch.nn.functional.pad(query, (0, pad_size)) - key = torch.nn.functional.pad(key, (0, pad_size)) - value = torch.nn.functional.pad(value, (0, pad_size)) - - # For THD format, attn_mask_type must be "padding" or "padding_causal" - attn_mask_type = "padding" - - # Determine GQA groups - TE will handle the GQA logic internally - num_gqa_groups = self.num_kv_heads if self.num_kv_heads != self.num_heads else None - - # Lazy initialization of TE attention operator - - self._lazy_init_te_attn( - num_attention_heads=self.num_heads, - kv_channels=padded_head_size, - num_gqa_groups=num_gqa_groups, - attn_mask_type=attn_mask_type, - softmax_scale=self.scale, - qkv_format="thd", - ) - - # NVTX annotation with all parameters for lazy_init and te_attn_op - nvtx_msg = ( - f"TE_FP8_THD: " - f"Q={tuple(query.shape)}, K={tuple(key.shape)}, V={tuple(value.shape)}, " - f"batch_size={bsz}, num_heads={self.num_heads}, kv_channels={padded_head_size}, " - f"num_gqa_groups={num_gqa_groups}, attn_mask_type={attn_mask_type}, " - f"softmax_scale={self.scale}, qkv_format=thd, " - f"cu_seqlens={cu_seqlens.shape}, max_seqlen={max_seqlen}" - ) - with torch.cuda.nvtx.range(nvtx_msg): - with fp8_autocast(enabled=True, fp8_recipe=self.te_fp8_recipe): - output = self.te_attn_op( - query, - key, - value, - attention_mask=None, - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - ) - - # TE returns (T, H*D_padded) flattened, need to reshape to (T, H, D_padded) first - if output.dim() == 2: - # Output is (T, H*D_padded) - reshape to (T, H, D_padded) - output = output.reshape(total_tokens_q, self.num_heads, padded_head_size) - - # Remove head padding if needed: (T, H, D_padded) -> (T, H, D_original) - if needs_padding: - output = output[..., :original_head_size] - - # Reshape back to original format - if is_3d_input: - # Back to (batch, seq, hidden_size) where hidden_size = H * D - output 
= output.reshape(bsz, q_len, self.num_heads * original_head_size) - else: - # Back to (batch, seq, num_heads, head_size) - output = output.reshape(bsz, q_len, self.num_heads, original_head_size) - - - return output - def _forward_flashinfer( self, query: torch.Tensor, From 413d6c3df137604fc0c0891f10976ec952686556 Mon Sep 17 00:00:00 2001 From: Yu Bo Gao Date: Mon, 9 Feb 2026 12:35:20 -0800 Subject: [PATCH 4/7] update docs --- vllm/v1/worker/gpu_model_runner.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 15db21868f11..269bc09b70d3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2487,10 +2487,12 @@ def _execute_mm_encoder( and num_items > 1 and modality in ("image", "video") ): - # Fall back to one-by-one processing for remaining images + # Fall back to one-by-one processing for remaining images. # Process each image individually for CUDA graph support + # and for TE FP8 compatibility (TE does not support THD + # format for FP8; see MMEncoderAttention._forward_te_fp8). # Extract batched data and slice per-image to avoid - # re-calling group_mm_kwargs_by_modality overhead + # re-calling group_mm_kwargs_by_modality overhead. # Note: list may contain None for unprocessed images; # these will be filled in by one-by-one processing below if has_partial_results and grouped_batched_result is not None: @@ -2652,8 +2654,8 @@ def _execute_mm_encoder( curr_group_outputs = piecewise_result else: # Fall back to eager execution, one image at a time. - # This ensures consistent behavior and reduces peak - # memory usage compared to batch processing. + # This is required by the TE FP8 attention backend + # which only supports batch-1 BSHD (see _forward_te_fp8). curr_group_outputs = self._execute_encoder_one_by_one_eager( model, mm_kwargs_group, modality, num_items ) @@ -2827,8 +2829,14 @@ def _execute_encoder_one_by_one_eager( """ Execute encoder in eager mode, processing one image at a time. - This ensures consistent behavior and reduces peak memory usage - compared to batch processing all images together. + One-at-a-time processing is required by the TE FP8 attention + backend (see MMEncoderAttention._forward_te_fp8 in + mm_encoder_attention.py). TE does not support THD format for FP8 + attention, and converting the upstream THD tensor into a proper + multi-batch BSHD tensor would be too expensive. Instead, we process + one image at a time so that the single-sequence THD tensor can be + reinterpreted as BSHD with B=1 and S=T, which is semantically + equivalent and avoids any data layout conversion. Args: model: The multimodal model From 7075fb8da48301ccc8ef0bdeb80a3bf1d964b010 Mon Sep 17 00:00:00 2001 From: Yu Bo Gao Date: Mon, 9 Feb 2026 13:16:41 -0800 Subject: [PATCH 5/7] install dependencies --- docker/Dockerfile | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9064ea51632b..f63631828ad8 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -588,6 +588,15 @@ RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm # Until then, add /usr/local/nvidia/lib64 before the image cuda path to allow override. 
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH} +# Install Transformer Engine for FP8 attention support in multimodal encoder +RUN --mount=type=cache,target=/root/.cache/uv \ + apt-get update -y && \ + apt-get install -y --no-install-recommends cuda-toolkit-13-0 libcudnn9-dev-cuda-13 && \ + rm -rf /var/lib/apt/lists/* && \ + uv pip install --system pybind11 && \ + NVTE_FRAMEWORK=pytorch uv pip install --system --no-build-isolation \ + git+https://github.com/NVIDIA/TransformerEngine.git@stable + # Copy examples and benchmarks at the end to minimize cache invalidation COPY examples examples COPY benchmarks benchmarks From 53bf4d5ce702d695dd52c96c9ac20604ab876ab7 Mon Sep 17 00:00:00 2001 From: Yu Bo Gao Date: Mon, 9 Feb 2026 15:02:08 -0800 Subject: [PATCH 6/7] add cmake --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index f63631828ad8..602ed0f70e69 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -591,7 +591,7 @@ ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH} # Install Transformer Engine for FP8 attention support in multimodal encoder RUN --mount=type=cache,target=/root/.cache/uv \ apt-get update -y && \ - apt-get install -y --no-install-recommends cuda-toolkit-13-0 libcudnn9-dev-cuda-13 && \ + apt-get install -y --no-install-recommends cmake cuda-toolkit-13-0 libcudnn9-dev-cuda-13 && \ rm -rf /var/lib/apt/lists/* && \ uv pip install --system pybind11 && \ NVTE_FRAMEWORK=pytorch uv pip install --system --no-build-isolation \ From fb9336ed5d027d21a4d172530071d8b16ab61b2f Mon Sep 17 00:00:00 2001 From: Yu Bo Gao Date: Tue, 10 Feb 2026 15:03:38 -0800 Subject: [PATCH 7/7] use 16 buckets --- vllm/model_executor/layers/attention/mm_encoder_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 37dcc7d24bb2..0c4c5524e094 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -21,7 +21,7 @@ # Seqlen buckets for BSHD format - Q/K/V tensors are padded to these sizes # so cuDNN sees a fixed set of tensor shapes and avoids recompilation -TE_SEQLEN_BUCKETS = [2048, 4096, 6144, 10240, 16384, 25600, 35840, 65536] +TE_SEQLEN_BUCKETS = [1024, 2048, 3072, 4096, 5120, 6144, 7168, 9216, 10240, 13312, 16384, 20480, 25600, 35840, 49152, 65536] # Fixed max_seqlen to avoid cuDNN recompilation when sequence lengths vary TE_FIXED_MAX_SEQLEN = 128 * 1024
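
Reviewer note, not part of the patch series: the sketch below reproduces, in plain
PyTorch, the shape handling that _forward_te_fp8 performs before calling
TransformerEngine -- the batch-1 THD-as-BSHD reinterpretation described in the
PATCH 3/7 docstring plus the head-dim and seqlen padding from PATCH 2/7 and 7/7.
All tensor sizes are arbitrary example values, and it runs without TE installed.

import torch

# Seqlen buckets as set in PATCH 7/7.
TE_SEQLEN_BUCKETS = [1024, 2048, 3072, 4096, 5120, 6144, 7168, 9216,
                     10240, 13312, 16384, 20480, 25600, 35840, 49152, 65536]

def find_seqlen_bucket(seqlen: int) -> int | None:
    # Smallest bucket that fits seqlen; None if it exceeds the largest bucket.
    return next((b for b in TE_SEQLEN_BUCKETS if b >= seqlen), None)

# One image's vision tokens in THD layout: (T, H, D).
T, H, D = 5000, 16, 72
q_thd = torch.randn(T, H, D)

# A single sequence in THD is equivalent to BSHD with B=1 and S=T,
# so the reinterpretation is just an unsqueeze (no data movement).
q_bshd = q_thd.unsqueeze(0)                            # (1, 5000, 16, 72)
cu_seqlens = torch.tensor([0, T], dtype=torch.int32)   # length 2: one sample

# Pad the head dim to a multiple of 16, then the seq dim to its bucket;
# the padded tail carries no real tokens and is masked via cu_seqlens.
padded_d = ((D + 15) // 16) * 16                       # 72 -> 80
q_bshd = torch.nn.functional.pad(q_bshd, (0, padded_d - D))
bucket = find_seqlen_bucket(T)                         # 5000 -> 5120
if bucket is not None and bucket > T:
    q_bshd = torch.nn.functional.pad(q_bshd, (0, 0, 0, 0, 0, bucket - T))

print(q_bshd.shape)         # torch.Size([1, 5120, 16, 80])
print(cu_seqlens.tolist())  # [0, 5000] -- only the first 5000 tokens are real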