diff --git a/docker/Dockerfile b/docker/Dockerfile index 9a866c786d90..9064ea51632b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -495,7 +495,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Install FlashInfer from CentML fork # https://github.com/CentML/flashinfer/tree/mlperf-inf-mm-q3vl-v6.0 ARG FLASHINFER_REPO=https://github.com/CentML/flashinfer.git -ARG FLASHINFER_BRANCH=mlperf-inf-mm-q3vl-v6.0-rc1 +ARG FLASHINFER_BRANCH=mlperf-inf-mm-q3vl-v6.0 ARG FLASHINFER_CUBIN_VERSION=0.5.3 ARG FLASHINFER_JIT_CACHE_VERSION=0.5.3 RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docker/versions.json b/docker/versions.json index 71dbfad57846..3bb174eea948 100644 --- a/docker/versions.json +++ b/docker/versions.json @@ -71,7 +71,7 @@ "default": "https://github.com/CentML/flashinfer.git" }, "FLASHINFER_BRANCH": { - "default": "mlperf-inf-mm-q3vl-v6.0-rc1" + "default": "mlperf-inf-mm-q3vl-v6.0" }, "FLASHINFER_CUBIN_VERSION": { "default": "0.5.3" diff --git a/vllm/envs.py b/vllm/envs.py index 0b1215483396..0be0a9003a3a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -216,6 +216,8 @@ VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_USE_TRITON_POS_EMBED: bool = False VLLM_POS_EMBED_CACHE_SIZE: int = 100 + VLLM_MM_ENCODER_FP8_ATTN: bool = False + VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH: str | None = None VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True @@ -1455,6 +1457,14 @@ def get_vllm_port() -> int | None: "VLLM_POS_EMBED_CACHE_SIZE": lambda: int( os.getenv("VLLM_POS_EMBED_CACHE_SIZE", "100") ), + # Controls whether to use FP8 attention for multimodal encoder (e.g., ViT) + "VLLM_MM_ENCODER_FP8_ATTN": lambda: bool( + int(os.getenv("VLLM_MM_ENCODER_FP8_ATTN", "0")) + ), + # Path to JSON file containing FP8 attention scales for multimodal encoder + "VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH": lambda: os.getenv( + "VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH", None + ), # If set to 1/True, use the TRTLLM attention backend in flashinfer. # If set to 0/False, use the default attention backend in flashinfer. # If not set, auto-detect the attention backend in flashinfer. diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 28d83776ebe5..c35cf42090a7 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -1,12 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +import json +from pathlib import Path import torch +import vllm.envs as envs from vllm.config import MultiModalConfig from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.input_quant_fp8 import ( + QuantFP8, + quantize_fp8_pad_head_dim_triton, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -20,6 +29,36 @@ logger = init_logger(__name__) +@functools.cache +def _load_fp8_scales_file(path: str | None) -> dict[str, dict[str, float]]: + """Load FP8 scales from file. 
Results are cached.""" + if path is None: + return {} + + path = str(Path(path).expanduser()) + with open(path, encoding="utf-8") as f: + data = json.load(f) + + # Handle nested "layers" format + if "layers" in data and isinstance(data["layers"], dict): + data = data["layers"] + + scales: dict[str, dict[str, float]] = {} + for layer_name, layer_scales in data.items(): + if not isinstance(layer_scales, dict): + continue + q = layer_scales.get("q", layer_scales.get("q_scale")) + k = layer_scales.get("k", layer_scales.get("k_scale")) + v = layer_scales.get("v", layer_scales.get("v_scale")) + if q is not None and k is not None and v is not None: + scales[layer_name] = {"q": float(q), "k": float(k), "v": float(v)} + + logger.info_once( + "Loaded FP8 attention scales from %s (%d layers)", path, len(scales) + ) + return scales + + # --8<-- [start:mm_encoder_attn] @CustomOp.register("mm_encoder_attn") class MMEncoderAttention(CustomOp): @@ -43,8 +82,7 @@ def __init__( head_size: hidden_size per attention head. scale: scale factor. num_kv_heads: number of kv heads. - prefix: This has no effect, it is only here to make it easier to - swap between Attention and MultiHeadAttention + prefix: layer name prefix, used to look up FP8 scales. multimodal_config: configs for multi-modal. """ super().__init__() @@ -64,6 +102,7 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() + self.dtype = dtype # Try to get vision attention backend from multimodal_config. attn_backend_override = None @@ -90,6 +129,74 @@ def __init__( logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") + # FP8 attention support (currently only FlashInfer cuDNN backend) + self.fp8_enabled = False + self.fp8_scales: dict[str, float] | None = None + self.fp8_quant: QuantFP8 | None = None + + if envs.VLLM_MM_ENCODER_FP8_ATTN: + if self.attn_backend != AttentionBackendEnum.FLASHINFER: + raise ValueError( + "VLLM_MM_ENCODER_FP8_ATTN requires the FlashInfer " + "cuDNN backend (FLASHINFER), but the current ViT " + f"attention backend is {self.attn_backend}." + ) + self._init_fp8_attention(prefix) + + def _init_fp8_attention(self, layer_name: str) -> None: + """Initialize FP8 attention for this layer.""" + scale_path = envs.VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH + all_scales = _load_fp8_scales_file(scale_path) + + if scale_path is None: + # No scale path provided - use scale=1.0 and warn + logger.warning_once( + "VLLM_MM_ENCODER_FP8_ATTN enabled but " + "VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH not set. " + "Using scale=1.0 for all Q/K/V (may cause accuracy issues)." + ) + self.fp8_scales = {"q": 1.0, "k": 1.0, "v": 1.0} + else: + # Scale path provided - layer must exist + layer_scales = all_scales.get(layer_name) + if layer_scales is None: + raise ValueError( + "FP8 attention enabled but scales not found for layer " + f"'{layer_name}' in {scale_path}. 
" + f"Available layers: {list(all_scales.keys())}" + ) + self.fp8_scales = layer_scales + + # Register scale tensors as buffers (auto-move to device with module) + # Shape (1, 1, 1, 1) as required by cuDNN + self.register_buffer( + "_fp8_q_scale", + torch.tensor([self.fp8_scales["q"]], dtype=torch.float32).view(1, 1, 1, 1), + ) + self.register_buffer( + "_fp8_k_scale", + torch.tensor([self.fp8_scales["k"]], dtype=torch.float32).view(1, 1, 1, 1), + ) + self.register_buffer( + "_fp8_v_scale", + torch.tensor([self.fp8_scales["v"]], dtype=torch.float32).view(1, 1, 1, 1), + ) + + # Create QuantFP8 for efficient quantization + self.fp8_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + self.fp8_enabled = True + self.skip_scale_q = self.fp8_scales["q"] == 1.0 + self.skip_scale_k = self.fp8_scales["k"] == 1.0 + self.skip_scale_v = self.fp8_scales["v"] == 1.0 + + logger.debug( + "FP8 attention enabled for %s: q=%.4f, k=%.4f, v=%.4f", + layer_name if layer_name else "MMEncoderAttention", + self.fp8_scales["q"], + self.fp8_scales["k"], + self.fp8_scales["v"], + ) + @classmethod def enabled(cls) -> bool: return True @@ -187,6 +294,34 @@ def _forward_fa( output = output.reshape(bsz, q_len, -1) return output + def _quantize_to_fp8( + self, + tensor: torch.Tensor, + scale: torch.Tensor, + skip_scale: bool = False, + ) -> torch.Tensor: + """Quantize a 3D (S, H, D) tensor to FP8. + + Uses QuantFP8 CustomOp when head_dim is aligned to 16; otherwise + falls back to a stride-aware Triton kernel that pads head_dim to + a multiple of 16 — no extra copy even for non-contiguous inputs. + """ + assert self.fp8_quant is not None + orig_shape = tensor.shape + head_dim = orig_shape[-1] + + if head_dim % 16 == 0: + if skip_scale: + return tensor.to(torch.float8_e4m3fn) + + # QuantFP8 expects 2D input: (total_tokens, num_heads * head_dim) + tensor_2d = tensor.reshape(orig_shape[0], -1) + fp8_tensor, _ = self.fp8_quant.forward_cuda(tensor_2d, scale=scale) + return fp8_tensor.reshape(orig_shape) + + # Fall back to Triton kernel for padding head_dim to a multiple of 16 + return quantize_fp8_pad_head_dim_triton(tensor, scale, skip_scale=skip_scale) + def _forward_flashinfer( self, query: torch.Tensor, @@ -197,7 +332,19 @@ def _forward_flashinfer( sequence_lengths: torch.Tensor | None = None, # Only used for FlashInfer CuDNN backend ) -> torch.Tensor: - return vit_flashinfer_wrapper( + if self.fp8_enabled: + assert self.fp8_quant is not None and self.fp8_scales is not None + query = self._quantize_to_fp8( + query, self._fp8_q_scale, skip_scale=self.skip_scale_q + ) + key = self._quantize_to_fp8( + key, self._fp8_k_scale, skip_scale=self.skip_scale_k + ) + value = self._quantize_to_fp8( + value, self._fp8_v_scale, skip_scale=self.skip_scale_v + ) + + output = vit_flashinfer_wrapper( q=query, k=key, v=value, @@ -206,8 +353,19 @@ def _forward_flashinfer( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, sequence_lengths=sequence_lengths, + q_scale=self._fp8_q_scale if self.fp8_enabled else None, + k_scale=self._fp8_k_scale if self.fp8_enabled else None, + v_scale=self._fp8_v_scale if self.fp8_enabled else None, + o_data_type=self.dtype if self.fp8_enabled else None, ) + # Un-pad head dimension if it was padded during FP8 quantization + if self.fp8_enabled and output.shape[-1] != self.head_size: + output = output[..., : self.head_size] + output = output.contiguous() + + return output + def _forward_fa4( self, query: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py 
b/vllm/model_executor/layers/quantization/input_quant_fp8.py index c1a901c37a0b..97aa7cb2dca5 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + import torch import torch.nn.functional as F @@ -13,12 +15,160 @@ group_broadcast, ) from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON, tl, triton _FP8_DTYPE = current_platform.fp8_dtype() _FP8_MIN, _FP8_MAX = get_fp8_min_max() _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) +@triton.jit +def _quantize_pad_fp8_kernel( + x_ptr, + y_ptr, + scale_ptr, + stride_xs, # input stride along token (seq) dim — may be non-contiguous + stride_xh, # input stride along head dim + stride_xd, # input stride along head_dim dim (usually 1) + stride_ys, # output stride along token dim (contiguous) + stride_yh, # output stride along head dim + stride_yd, # output stride along head_dim dim (usually 1) + num_heads, + n_rows, # total rows = S * H + n_cols, + n_cols_padded, + fp8_min, + fp8_max, + SKIP_SCALE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < n_rows + mask_out = mask_m[:, None] & (offs_n[None, :] < n_cols_padded) + mask_in = mask_m[:, None] & (offs_n[None, :] < n_cols) + + # Decompose flattened row into (token, head) for 3D stride indexing. + # This lets the kernel read directly from non-contiguous QKV views. + s = offs_m // num_heads + h = offs_m % num_heads + + x_ptrs = ( + x_ptr + + s[:, None] * stride_xs + + h[:, None] * stride_xh + + offs_n[None, :] * stride_xd + ) + x = tl.load(x_ptrs, mask=mask_in, other=0.0).to(tl.float32) + if SKIP_SCALE: + x_q = x + else: + scale = tl.load(scale_ptr) + x_q = x / scale + x_q = tl.where(mask_in, x_q, 0.0) + x_q = tl.clamp(x_q, fp8_min, fp8_max).to(y_ptr.dtype.element_ty) + + y_ptrs = ( + y_ptr + + s[:, None] * stride_ys + + h[:, None] * stride_yh + + offs_n[None, :] * stride_yd + ) + tl.store(y_ptrs, x_q, mask=mask_out) + + +def _get_fp8_pad_quant_config(padded_head_dim: int) -> tuple[int, int, int]: + # Blackwell: use a single static config to avoid recompiles. + if current_platform.is_device_capability_family(100): + block_n, num_warps, block_m = 128, 4, 16 + else: + block_n = triton.next_power_of_2(padded_head_dim) + block_n = max(16, min(block_n, 256)) + num_warps = 4 if block_n >= 128 else 2 + block_m = 16 + + env_block_n = os.getenv("VLLM_FP8_PAD_QUANT_BLOCK_N") + env_num_warps = os.getenv("VLLM_FP8_PAD_QUANT_NUM_WARPS") + env_block_m = os.getenv("VLLM_FP8_PAD_QUANT_BLOCK_M") + if env_block_n is not None: + block_n = max(16, min(int(env_block_n), 256)) + if env_num_warps is not None: + num_warps = int(env_num_warps) + if env_block_m is not None: + block_m = max(1, int(env_block_m)) + + return block_n, num_warps, block_m + + +def quantize_fp8_pad_head_dim_triton( + tensor: torch.Tensor, + scale: torch.Tensor, + skip_scale: bool = False, + block_n: int | None = None, + num_warps: int | None = None, + block_m: int | None = None, +) -> torch.Tensor: + """Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16. + + Reads directly from the input using its 3D strides, so non-contiguous + views (e.g. 
Q/K/V slices from an interleaved QKV buffer) are handled + without an extra copy. Output is always a fresh contiguous tensor + with shape (S, H, padded_D). + """ + if not HAS_TRITON: + raise RuntimeError("Triton is required to quantize with head_dim padding.") + + original_shape = tensor.shape + if tensor.dim() == 4: + tensor = tensor.view(-1, tensor.shape[-2], tensor.shape[-1]) + assert tensor.dim() == 3, f"Expected 3D input (S, H, D), got {tensor.dim()}D" + S, H, D = tensor.shape + padded_head_dim = (D + 15) // 16 * 16 + out_dtype = current_platform.fp8_dtype() + output = torch.empty( + (S, H, padded_head_dim), + device=tensor.device, + dtype=out_dtype, + ) + + scale_1d = scale.reshape(-1) + fp8_min, fp8_max = get_fp8_min_max() + n_rows = S * H + + if block_n is None or num_warps is None or block_m is None: + block_n, num_warps, block_m = _get_fp8_pad_quant_config(padded_head_dim) + + grid = (triton.cdiv(n_rows, block_m), triton.cdiv(padded_head_dim, block_n)) + + _quantize_pad_fp8_kernel[grid]( + tensor, + output, + scale_1d, + tensor.stride(0), + tensor.stride(1), + tensor.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + H, + n_rows, + D, + padded_head_dim, + fp8_min, + fp8_max, + SKIP_SCALE=skip_scale, + BLOCK_M=block_m, + BLOCK_N=block_n, + num_warps=num_warps, + ) + + return output.view((*original_shape[:-1], padded_head_dim)) + + # --8<-- [start:quant_fp8] @CustomOp.register("quant_fp8") class QuantFP8(CustomOp): diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index e3ab33bfa5b7..54f335200996 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -610,6 +610,18 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads + + # When FP8 attention is enabled and head_dim is not a multiple of 16, + # the quantization kernel pads head_dim (e.g. 72 -> 80). Store the + # padded hidden_size so compute_flashinfer_cu_seqlens can produce + # element offsets that match the contiguous FP8 tensor strides. + self.fp8_vit_attn = envs.VLLM_MM_ENCODER_FP8_ATTN + if self.fp8_vit_attn and head_dim % 16 != 0: + padded_head_dim = ((head_dim + 15) // 16) * 16 + self.fp8_padded_hidden_size = self.num_heads * padded_head_dim + else: + self.fp8_padded_hidden_size = None + self.rotary_pos_emb = get_rope( head_size=head_dim, max_position=8192, @@ -927,6 +939,27 @@ def compute_flashinfer_cu_seqlens( rotary_pos_emb_sin: torch.Tensor | None = None, ) -> np.ndarray: batch_size = len(cu_seqlens) - 1 + + if self.fp8_padded_hidden_size is not None: + # FP8 path: after quantization Q/K/V are each independent + # contiguous tensors with stride H * padded_D per token. + # All sections (QK, V, O) use the same element stride. + # The wrapper overrides QK/V batch_offsets to use O offsets. + scale = self.fp8_padded_hidden_size // self.tp_size + cu_seqlens = cu_seqlens * scale + cu_seqlens_padded = self.add_padding_to_fi_seqlens( + cu_seqlens, batch_size, cu_seqlens[-1] + ) + return np.concatenate( + [cu_seqlens_padded, cu_seqlens_padded, cu_seqlens_padded] + ) + + # BF16 path: Q/K/V are non-contiguous views into shared buffers. 
+ # Element stride per token differs by tensor: + # After rotary: Q,K in [Q,K] buffer -> stride 2×H×D + # No rotary: Q,K in [Q,K,V] buffer -> stride 3×H×D + # V always in [Q,K,V] buffer -> stride 3×H×D + # O is contiguous -> stride H×D scale = self.hidden_size // self.tp_size cu_seqlens = cu_seqlens * scale if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index 84b1438fb1b0..fb9bab7d0fff 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -273,6 +273,10 @@ def flashinfer_wrapper( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, sequence_lengths: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + k_scale: torch.Tensor | None = None, + v_scale: torch.Tensor | None = None, + o_data_type: torch.dtype | None = None, ) -> torch.Tensor: from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache @@ -306,6 +310,10 @@ def flashinfer_wrapper( batch_offsets_k=batch_offsets_qk, batch_offsets_v=batch_offsets_v, batch_offsets_o=batch_offsets_o, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + o_data_type=o_data_type, ) if is_reshaped: @@ -323,7 +331,13 @@ def vit_flashinfer_wrapper_fake( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, sequence_lengths: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + k_scale: torch.Tensor | None = None, + v_scale: torch.Tensor | None = None, + o_data_type: torch.dtype | None = None, ) -> torch.Tensor: + if o_data_type is not None and o_data_type != q.dtype: + return torch.empty(q.shape, device=q.device, dtype=o_data_type) return torch.empty_like(q) @@ -343,7 +357,22 @@ def vit_flashinfer_wrapper( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, sequence_lengths: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + k_scale: torch.Tensor | None = None, + v_scale: torch.Tensor | None = None, + o_data_type: torch.dtype | None = None, ) -> torch.Tensor: return torch.ops.vllm.flashinfer_wrapper( - q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths + q, + k, + v, + scale, + workspace_buffer, + cu_seqlens, + max_seqlen, + sequence_lengths, + q_scale, + k_scale, + v_scale, + o_data_type, )
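
FP8 encoder attention is opt-in via `VLLM_MM_ENCODER_FP8_ATTN` and currently requires the FlashInfer cuDNN ViT backend; if `VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH` is unset, all Q/K/V scales fall back to 1.0 with a warning. The loader accepts either a flat `{layer: scales}` mapping or a top-level `"layers"` object, and per-layer keys may be `q`/`k`/`v` or `q_scale`/`k_scale`/`v_scale`; the layer keys must match the `prefix` passed to each `MMEncoderAttention`. A minimal sketch of producing such a file is below — the file name, layer names, and scale values are made-up examples, not calibrated values.

```python
# Illustrative only: layer names and scale values are hypothetical.
import json
import os

scales = {
    "layers": {
        # Keys must match the `prefix` of each ViT MMEncoderAttention layer.
        "visual.blocks.0.attn": {"q_scale": 1.25, "k_scale": 0.98, "v_scale": 1.07},
        # The loader also accepts the short "q"/"k"/"v" key form.
        "visual.blocks.1.attn": {"q": 1.31, "k": 1.02, "v": 0.94},
    }
}

with open("vit_fp8_scales.json", "w", encoding="utf-8") as f:
    json.dump(scales, f, indent=2)

# Must be set before the vision encoder is constructed.
os.environ["VLLM_MM_ENCODER_FP8_ATTN"] = "1"
os.environ["VLLM_MM_ENCODER_FP8_ATTN_SCALE_PATH"] = "vit_fp8_scales.json"
```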
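
The `compute_flashinfer_cu_seqlens` change in `qwen3_vl.py` is easiest to see with numbers. On the FP8 path with a head_dim that is not a multiple of 16 (e.g. 72 -> 80, as in the comment), Q/K/V become independent contiguous tensors, so every section advances by `num_heads * padded_head_dim` elements per token and the same offsets can be reused for QK, V and O. A small arithmetic sketch follows, with made-up `num_heads` and patch counts; the real method additionally applies `add_padding_to_fi_seqlens` and concatenates three copies of the result.

```python
# Numeric sketch of the FP8 offset scaling; num_heads and sequence lengths
# are illustrative, head_dim 72 -> 80 mirrors the comment in the diff.
import numpy as np

num_heads, head_dim, tp_size = 16, 72, 1
padded_head_dim = ((head_dim + 15) // 16) * 16        # 72 -> 80
fp8_padded_hidden_size = num_heads * padded_head_dim  # 1280

# Prefix sums of patch counts for two images (64 and 100 patches).
cu_seqlens = np.array([0, 64, 164], dtype=np.int64)

# FP8 path: one element stride (H * padded_D) per token for QK, V and O.
elem_offsets = cu_seqlens * (fp8_padded_hidden_size // tp_size)
print(elem_offsets.tolist())  # [0, 81920, 209920]
```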
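
`quantize_fp8_pad_head_dim_triton` is designed to be fed non-contiguous `(S, H, D)` views directly and returns a fresh contiguous FP8 tensor with head_dim padded to a multiple of 16. A rough usage sketch, assuming a CUDA device with FP8 support and Triton installed; the interleaved QKV layout and shapes below are illustrative, not the exact layout used by the model.

```python
# Rough sketch; requires CUDA + Triton and an FP8-capable GPU.
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import (
    quantize_fp8_pad_head_dim_triton,
)

S, H, D = 164, 16, 72                        # head_dim not a multiple of 16
qkv = torch.randn(S, 3, H, D, dtype=torch.bfloat16, device="cuda")
q = qkv[:, 0]                                # non-contiguous (S, H, D) view
scale = torch.tensor([1.25], dtype=torch.float32, device="cuda")

q_fp8 = quantize_fp8_pad_head_dim_triton(q, scale)
# -> contiguous (164, 16, 80) tensor in the platform FP8 dtype
# (e.g. torch.float8_e4m3fn); pass skip_scale=True to skip division by scale.
print(q_fp8.shape, q_fp8.dtype)
```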