2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -495,7 +495,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Install FlashInfer from CentML fork
# https://github.com/CentML/flashinfer/tree/mlperf-inf-mm-q3vl-v6.0
ARG FLASHINFER_REPO=https://github.com/CentML/flashinfer.git
-ARG FLASHINFER_BRANCH=mlperf-inf-mm-q3vl-v6.0-rc1
+ARG FLASHINFER_BRANCH=mlperf-inf-mm-q3vl-v6.0
ARG FLASHINFER_CUBIN_VERSION=0.5.3
ARG FLASHINFER_JIT_CACHE_VERSION=0.5.3
RUN --mount=type=cache,target=/root/.cache/uv \
2 changes: 1 addition & 1 deletion docker/versions.json
@@ -71,7 +71,7 @@
"default": "https://github.com/CentML/flashinfer.git"
},
"FLASHINFER_BRANCH": {
"default": "mlperf-inf-mm-q3vl-v6.0-rc1"
"default": "mlperf-inf-mm-q3vl-v6.0"
},
"FLASHINFER_CUBIN_VERSION": {
"default": "0.5.3"
10 changes: 10 additions & 0 deletions vllm/envs.py
@@ -216,6 +216,8 @@
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_USE_TRITON_POS_EMBED: bool = False
VLLM_POS_EMBED_CACHE_SIZE: int = 100
VLLM_MM_ENCODER_FP8_ATTN: bool = False
VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH: str | None = None
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True
@@ -1455,6 +1457,14 @@ def get_vllm_port() -> int | None:
"VLLM_POS_EMBED_CACHE_SIZE": lambda: int(
os.getenv("VLLM_POS_EMBED_CACHE_SIZE", "100")
),
# Controls whether to use FP8 attention for multimodal encoder (e.g., ViT)
"VLLM_MM_ENCODER_FP8_ATTN": lambda: bool(
int(os.getenv("VLLM_MM_ENCODER_FP8_ATTN", "0"))
),
# Path to JSON file containing FP8 attention scales for multimodal encoder
"VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH": lambda: os.getenv(
"VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH", None
),
# If set to 1/True, use the TRTLLM attention backend in flashinfer.
# If set to 0/False, use the default attention backend in flashinfer.
# If not set, auto-detect the attention backend in flashinfer.
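
A minimal usage sketch for the two new environment variables (hedged: the scale-file path below is hypothetical; the accepted file format is handled by the loader added in mm_encoder_attention.py further down):

```python
import os

# Opt in to FP8 attention in the multimodal encoder (ViT). This requires the
# FlashInfer cuDNN ViT attention backend; any other backend raises ValueError.
os.environ["VLLM_MM_ENCODER_FP8_ATTN"] = "1"

# Optional per-layer Q/K/V scales. If unset, scale=1.0 is used for every
# layer and a warning about possible accuracy loss is logged.
os.environ["VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH"] = "/path/to/vit_fp8_scales.json"
```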
164 changes: 161 additions & 3 deletions vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -1,12 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import json
from pathlib import Path

import torch

import vllm.envs as envs
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
quantize_fp8_pad_head_dim_triton,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@@ -20,6 +29,36 @@
logger = init_logger(__name__)


@functools.cache
def _load_fp8_scales_file(path: str | None) -> dict[str, dict[str, float]]:
"""Load FP8 scales from file. Results are cached."""
if path is None:
return {}

path = str(Path(path).expanduser())
with open(path, encoding="utf-8") as f:
data = json.load(f)

# Handle nested "layers" format
if "layers" in data and isinstance(data["layers"], dict):
data = data["layers"]

scales: dict[str, dict[str, float]] = {}
for layer_name, layer_scales in data.items():
if not isinstance(layer_scales, dict):
continue
q = layer_scales.get("q", layer_scales.get("q_scale"))
k = layer_scales.get("k", layer_scales.get("k_scale"))
v = layer_scales.get("v", layer_scales.get("v_scale"))
if q is not None and k is not None and v is not None:
scales[layer_name] = {"q": float(q), "k": float(k), "v": float(v)}

logger.info_once(
"Loaded FP8 attention scales from %s (%d layers)", path, len(scales)
)
return scales
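
# Example of an accepted scales file (hypothetical layer names and values).
# The flat {"<layer>": {...}} layout and the nested {"layers": {...}} layout
# are both handled, and "q"/"k"/"v" may also be spelled
# "q_scale"/"k_scale"/"v_scale":
#
#   {
#     "layers": {
#       "visual.blocks.0.attn": {"q": 0.0312, "k": 0.0468, "v": 0.0215},
#       "visual.blocks.1.attn": {"q_scale": 0.0298, "k_scale": 0.0441, "v_scale": 0.0199}
#     }
#   }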


# --8<-- [start:mm_encoder_attn]
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
Expand All @@ -43,8 +82,7 @@ def __init__(
head_size: hidden_size per attention head.
scale: scale factor.
num_kv_heads: number of kv heads.
-prefix: This has no effect, it is only here to make it easier to
-swap between Attention and MultiHeadAttention
+prefix: layer name prefix, used to look up FP8 scales.
multimodal_config: configs for multi-modal.
"""
super().__init__()
@@ -64,6 +102,7 @@ def __init__(
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
self.dtype = dtype

# Try to get vision attention backend from multimodal_config.
attn_backend_override = None
@@ -90,6 +129,74 @@

logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

# FP8 attention support (currently only FlashInfer cuDNN backend)
self.fp8_enabled = False
self.fp8_scales: dict[str, float] | None = None
self.fp8_quant: QuantFP8 | None = None

if envs.VLLM_MM_ENCODER_FP8_ATTN:
if self.attn_backend != AttentionBackendEnum.FLASHINFER:
raise ValueError(
"VLLM_MM_ENCODER_FP8_ATTN requires the FlashInfer "
"cuDNN backend (FLASHINFER), but the current ViT "
f"attention backend is {self.attn_backend}."
)
self._init_fp8_attention(prefix)

def _init_fp8_attention(self, layer_name: str) -> None:
"""Initialize FP8 attention for this layer."""
scale_path = envs.VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH
all_scales = _load_fp8_scales_file(scale_path)

if scale_path is None:
# No scale path provided - use scale=1.0 and warn
logger.warning_once(
"VLLM_MM_ENCODER_FP8_ATTN enabled but "
"VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH not set. "
"Using scale=1.0 for all Q/K/V (may cause accuracy issues)."
)
self.fp8_scales = {"q": 1.0, "k": 1.0, "v": 1.0}
else:
# Scale path provided - layer must exist
layer_scales = all_scales.get(layer_name)
if layer_scales is None:
raise ValueError(
"FP8 attention enabled but scales not found for layer "
f"'{layer_name}' in {scale_path}. "
f"Available layers: {list(all_scales.keys())}"
)
self.fp8_scales = layer_scales

# Register scale tensors as buffers (auto-move to device with module)
# Shape (1, 1, 1, 1) as required by cuDNN
self.register_buffer(
"_fp8_q_scale",
torch.tensor([self.fp8_scales["q"]], dtype=torch.float32).view(1, 1, 1, 1),
)
self.register_buffer(
"_fp8_k_scale",
torch.tensor([self.fp8_scales["k"]], dtype=torch.float32).view(1, 1, 1, 1),
)
self.register_buffer(
"_fp8_v_scale",
torch.tensor([self.fp8_scales["v"]], dtype=torch.float32).view(1, 1, 1, 1),
)

# Create QuantFP8 for efficient quantization
self.fp8_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
self.fp8_enabled = True
self.skip_scale_q = self.fp8_scales["q"] == 1.0
self.skip_scale_k = self.fp8_scales["k"] == 1.0
self.skip_scale_v = self.fp8_scales["v"] == 1.0

logger.debug(
"FP8 attention enabled for %s: q=%.4f, k=%.4f, v=%.4f",
layer_name if layer_name else "MMEncoderAttention",
self.fp8_scales["q"],
self.fp8_scales["k"],
self.fp8_scales["v"],
)
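
# Lookup sketch (hedged; the layer name is hypothetical): a layer built as
# MMEncoderAttention(..., prefix="visual.blocks.0.attn") expects an entry
# named "visual.blocks.0.attn" in the scales file. With no scale path set,
# q = k = v = 1.0 is used and the skip_scale_* flags let _quantize_to_fp8
# skip the scale multiply entirely.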

@classmethod
def enabled(cls) -> bool:
return True
@@ -187,6 +294,34 @@ def _forward_fa(
output = output.reshape(bsz, q_len, -1)
return output

def _quantize_to_fp8(
self,
tensor: torch.Tensor,
scale: torch.Tensor,
skip_scale: bool = False,
) -> torch.Tensor:
"""Quantize a 3D (S, H, D) tensor to FP8.

Uses QuantFP8 CustomOp when head_dim is aligned to 16; otherwise
falls back to a stride-aware Triton kernel that pads head_dim to
a multiple of 16 — no extra copy even for non-contiguous inputs.
"""
assert self.fp8_quant is not None
orig_shape = tensor.shape
head_dim = orig_shape[-1]

if head_dim % 16 == 0:
if skip_scale:
return tensor.to(torch.float8_e4m3fn)

# QuantFP8 expects 2D input: (total_tokens, num_heads * head_dim)
tensor_2d = tensor.reshape(orig_shape[0], -1)
fp8_tensor, _ = self.fp8_quant.forward_cuda(tensor_2d, scale=scale)
return fp8_tensor.reshape(orig_shape)

# Fall back to Triton kernel for padding head_dim to a multiple of 16
return quantize_fp8_pad_head_dim_triton(tensor, scale, skip_scale=skip_scale)
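
# Shape sketch (hedged): for a query of shape (S, H, 40), 40 % 16 != 0, so
# the Triton fallback returns an fp8 tensor of shape (S, H, 48); the padded
# columns are sliced off again after attention in _forward_flashinfer below.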

def _forward_flashinfer(
self,
query: torch.Tensor,
@@ -197,7 +332,19 @@ def _forward_flashinfer(
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
-return vit_flashinfer_wrapper(
+if self.fp8_enabled:
+    assert self.fp8_quant is not None and self.fp8_scales is not None
+    query = self._quantize_to_fp8(
+        query, self._fp8_q_scale, skip_scale=self.skip_scale_q
+    )
+    key = self._quantize_to_fp8(
+        key, self._fp8_k_scale, skip_scale=self.skip_scale_k
+    )
+    value = self._quantize_to_fp8(
+        value, self._fp8_v_scale, skip_scale=self.skip_scale_v
+    )
+
+output = vit_flashinfer_wrapper(
q=query,
k=key,
v=value,
@@ -206,8 +353,19 @@
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
q_scale=self._fp8_q_scale if self.fp8_enabled else None,
k_scale=self._fp8_k_scale if self.fp8_enabled else None,
v_scale=self._fp8_v_scale if self.fp8_enabled else None,
o_data_type=self.dtype if self.fp8_enabled else None,
)

# Un-pad head dimension if it was padded during FP8 quantization
if self.fp8_enabled and output.shape[-1] != self.head_size:
output = output[..., : self.head_size]
output = output.contiguous()

return output

def _forward_fa4(
self,
query: torch.Tensor,