9 changes: 9 additions & 0 deletions docker/Dockerfile
@@ -588,6 +588,15 @@ RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm
# Until then, add /usr/local/nvidia/lib64 before the image cuda path to allow override.
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}

# Install Transformer Engine for FP8 attention support in multimodal encoder
RUN --mount=type=cache,target=/root/.cache/uv \
apt-get update -y && \
apt-get install -y --no-install-recommends cmake cuda-toolkit-13-0 libcudnn9-dev-cuda-13 && \
rm -rf /var/lib/apt/lists/* && \
uv pip install --system pybind11 && \
NVTE_FRAMEWORK=pytorch uv pip install --system --no-build-isolation \
git+https://github.com/NVIDIA/TransformerEngine.git@stable
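# Optional sanity check (a suggested step, not part of this change): confirm the
# TE build imports against the image's CUDA/cuDNN, e.g.
#   RUN python3 -c "import transformer_engine.pytorch as te; assert te.DotProductAttention"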

# Copy examples and benchmarks at the end to minimize cache invalidation
COPY examples examples
COPY benchmarks benchmarks
212 changes: 212 additions & 0 deletions vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -19,6 +19,22 @@

logger = init_logger(__name__)

# Seqlen buckets for BSHD format - Q/K/V tensors are padded to these sizes
# so cuDNN sees a fixed set of tensor shapes and avoids recompilation
TE_SEQLEN_BUCKETS = [1024, 2048, 3072, 4096, 5120, 6144, 7168, 9216, 10240, 13312, 16384, 20480, 25600, 35840, 49152, 65536]
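# Example: an image yielding 4,500 vision tokens is padded up to the 5120
# bucket, so inputs with nearby token counts reuse the same compiled cuDNN
# graph instead of each new shape triggering a fresh compilation.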

# Fixed max_seqlen to avoid cuDNN recompilation when sequence lengths vary
TE_FIXED_MAX_SEQLEN = 128 * 1024

try:
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch import DotProductAttention, fp8_autocast
except ImportError:
DotProductAttention = None
fp8_autocast = None
DelayedScaling = None
logger.warning("TransformerEngine is not installed.")


# --8<-- [start:mm_encoder_attn]
@CustomOp.register("mm_encoder_attn")
@@ -88,6 +104,24 @@ def __init__(
get_flash_attn_version() if self.is_flash_attn_backend else None
)

# Transformer Engine FP8 attention state; the operator itself is created
# lazily on the first forward pass.
self.te_attn_op = None
self.te_fp8_recipe = None
self.is_te_fp8_backend = self.attn_backend == AttentionBackendEnum.TE_FP8

if self.is_te_fp8_backend:
if DotProductAttention is None:
raise ImportError(
"TransformerEngine is not installed but TE_FP8 backend was selected"
)
self.te_fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=True)
logger.info_once("Initialized FP8 Transformer Engine for MMEncoderAttention.")

logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

@classmethod
@@ -118,6 +152,37 @@ def maybe_reshape_qkv_to_4d(

return query, key, value

@staticmethod
def _find_seqlen_bucket(seqlen: int) -> int | None:
"""Find the smallest seqlen bucket that can fit the given seqlen.

Returns None if seqlen exceeds the largest bucket.
"""
for bucket in TE_SEQLEN_BUCKETS:
if bucket >= seqlen:
return bucket
return None

def _lazy_init_te_attn(
self,
num_attention_heads: int,
kv_channels: int,
num_gqa_groups: int | None,
attn_mask_type: str,
softmax_scale: float | None,
qkv_format: str = "bshd",
) -> None:
"""Lazily initialize Transformer Engine attention operator."""
if self.te_attn_op is None:
self.te_attn_op = DotProductAttention(
num_attention_heads,
kv_channels,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
softmax_scale=softmax_scale,
qkv_format=qkv_format,
)

def _forward_sdpa(
self,
query: torch.Tensor,
@@ -187,6 +252,151 @@ def _forward_fa(
output = output.reshape(bsz, q_len, -1)
return output

def _forward_te_fp8(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass using Transformer Engine FP8 attention with BSHD format.

IMPORTANT: This function processes ONE sample at a time. When cu_seqlens
is provided, it must have length 2 (i.e., [0, seq_len] for a single
sequence).

This batch-1 restriction is a limitation of Transformer Engine, not
cuDNN. TE does not support THD format for FP8 attention, and converting
the upstream THD tensor into a proper multi-batch BSHD tensor would be
too expensive. Instead, we manually reinterpret a batch-1 THD tensor as
BSHD with B=1 and S=T (the total token count), then call the BSHD
kernel. This is semantically consistent because a single sequence in
THD is equivalent to B=1 BSHD.

Input shape:
(batch_size x seq_len x hidden_size) where hidden_size = num_heads * head_size
or (batch_size x seq_len x num_heads x head_size)

Uses BSHD format: (batch, seq, heads, dim)

Note: The head dimension is padded to a multiple of 16 for optimal performance.
"""
# Validate single-sample constraint
if cu_seqlens is not None:
assert len(cu_seqlens) == 2, (
f"_forward_te_fp8 (BSHD format) requires exactly one sample at a time. "
f"cu_seqlens must have length 2 (got {len(cu_seqlens)}). "
)

bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_3d_input = query.dim() == 3

# Transform to BSHD format: (batch, seq, heads, dim)
if is_3d_input:
# Input is (batch, seq, hidden_size) - reshape to (batch, seq, heads, dim)
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
# else: already in (batch, seq, heads, dim) format

# Pad head dimension to multiple of 16 for optimal performance
original_head_size = self.head_size
padded_head_size = ((self.head_size + 15) // 16) * 16
needs_padding = padded_head_size != original_head_size

if needs_padding:
pad_size = padded_head_size - original_head_size
query = torch.nn.functional.pad(query, (0, pad_size))
key = torch.nn.functional.pad(key, (0, pad_size))
value = torch.nn.functional.pad(value, (0, pad_size))
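# E.g. a head_size of 72 would be padded to 80; the zero-filled dims
# contribute nothing to Q.K^T and the corresponding output slice is
# dropped below, so results are unchanged.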

# Pad Q/K/V seqlen dimension to a bucket size to avoid cuDNN
# recompilation when different images have different resolutions.
# cu_seqlens already tracks the real sequence boundaries.
bucket_seqlen = self._find_seqlen_bucket(q_len)
if bucket_seqlen is not None and bucket_seqlen > q_len:
seq_pad = bucket_seqlen - q_len
# Pad S dimension: shape is (B, S, H, D), so pad dim=1
query = torch.nn.functional.pad(query, (0, 0, 0, 0, 0, seq_pad))
key = torch.nn.functional.pad(key, (0, 0, 0, 0, 0, seq_pad))
value = torch.nn.functional.pad(value, (0, 0, 0, 0, 0, seq_pad))
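# If q_len exceeds the largest bucket, the tensors are left unpadded and
# cuDNN may compile a new graph for that shape. The same seq_pad is applied
# to K/V, which assumes kv_len == q_len (the encoder self-attention case).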

# cu_seqlens indicates variable sequence lengths when provided; the real
# sequence boundaries are carried by cu_seqlens, so only the mask type changes.
if cu_seqlens is not None:
# Variable sequence lengths - need a padding mask
attn_mask_type = "padding"
else:
# Uniform sequence lengths - no mask needed
attn_mask_type = "no_mask"

# Determine GQA groups - TE will handle the GQA logic internally
num_gqa_groups = self.num_kv_heads if self.num_kv_heads != self.num_heads else None

# Lazy initialization of TE attention operator
self._lazy_init_te_attn(
num_attention_heads=self.num_heads,
kv_channels=padded_head_size,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
softmax_scale=self.scale,
qkv_format="bshd",
)

max_seqlen = TE_FIXED_MAX_SEQLEN
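# Note: the max_seqlen argument passed by the caller is intentionally not
# used here; a single fixed value keeps the cuDNN graph shape stable
# across images of different sizes.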

# NVTX annotation with all parameters for lazy_init and te_attn_op
nvtx_msg = (
f"TE_FP8_BSHD: "
f"Q={tuple(query.shape)}, K={tuple(key.shape)}, V={tuple(value.shape)}, "
f"num_heads={self.num_heads}, kv_channels={padded_head_size}, "
f"num_gqa_groups={num_gqa_groups}, attn_mask_type={attn_mask_type}, "
f"softmax_scale={self.scale}, qkv_format=bshd, "
f"cu_seqlens={cu_seqlens.shape if cu_seqlens is not None else None}, "
f"max_seqlen={max_seqlen}"
)
with torch.cuda.nvtx.range(nvtx_msg):
with fp8_autocast(enabled=True, fp8_recipe=self.te_fp8_recipe):
output = self.te_attn_op(
query,
key,
value,
attention_mask=None,
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_kv=max_seqlen,
)

# Output is (batch, padded_seq, heads, padded_dim) or
# (batch, padded_seq, heads*padded_dim).
# Handle both cases.
if output.dim() == 3:
# Output is (batch, padded_seq, heads*dim) flattened
output = output.reshape(
bsz, output.size(1), self.num_heads, padded_head_size
)

# Slice back to original seqlen (remove S-dimension padding)
output = output[:, :q_len, :, :]

# Remove head padding if needed
if needs_padding:
output = output[..., :original_head_size]

# Reshape back to original format
if is_3d_input:
# Back to (batch, seq, hidden_size) where hidden_size = H * D
output = output.reshape(bsz, q_len, self.num_heads * original_head_size)
else:
# Already in (batch, seq, num_heads, head_size) format
pass

return output

def _forward_flashinfer(
self,
query: torch.Tensor,
@@ -274,6 +484,8 @@ def forward_cuda(
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
elif self.is_te_fp8_backend:
return self._forward_te_fp8(query, key, value, cu_seqlens, max_seqlen)
else:
raise ValueError(
f"Unsupported multi-modal encoder attention backend for CUDA: "
1 change: 1 addition & 0 deletions vllm/model_executor/models/qwen3_vl.py
@@ -448,6 +448,7 @@ def __init__(
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TE_FP8,
}:
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now."
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
@@ -364,6 +364,7 @@ def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
AttentionBackendEnum.FLASH_ATTN_CUTE,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TE_FP8,
]

@classmethod
1 change: 1 addition & 0 deletions vllm/v1/attention/backends/registry.py
@@ -62,6 +62,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
)
TORCH_SDPA = "" # this tag is only used for ViT
TE_FP8 = "transformer_engine.pytorch.DotProductAttention" # this tag is only used for MMEncoderAttention
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"