diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py
index b79cd82b21..936e1e1666 100644
--- a/flashinfer/cudnn/prefill.py
+++ b/flashinfer/cudnn/prefill.py
@@ -16,9 +16,20 @@
 # Global cudnn handle. need to make it per device in future
 _cudnn_handle = None
 
+_dummy_scale_tensors: dict[torch.device, torch.Tensor] = {}
+
+
+def _get_dummy_scale_tensor(device: torch.device):
+    t = _dummy_scale_tensors.get(device)
+    if t is None:
+        t = torch.tensor([1.0], device=device, dtype=torch.float32).reshape(1, 1, 1, 1)
+        _dummy_scale_tensors[device] = t
+    return t
+
 
 def _create_cudnn_handle(stream: torch.cuda.Stream):
     global _cudnn_handle
+
     if _cudnn_handle is None:
         _cudnn_handle = cudnn.create_handle()
     cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
@@ -49,6 +60,16 @@ class UIDs(Enum):
     O_UID = 1000  # Output tensor
     STATS_UID = 1001  # Stats tensor
 
+    Q_SCALE_UID = 150  # Query scale tensor (FP8 descale factor)
+    K_SCALE_UID = 151  # Key scale tensor (FP8 descale factor)
+    V_SCALE_UID = 152  # Value scale tensor (FP8 descale factor)
+    S_SCALE_UID = 153  # Softmax (S) scale tensor
+    S_DESCALE_UID = 154  # Softmax (S) descale tensor
+    O_SCALE_UID = 155  # Output scale tensor
+
+    S_AMAX_UID = 160  # Softmax (S) amax tensor
+    O_AMAX_UID = 161  # Output amax tensor
+
 
 def _sdpa_prefill_key_fn(
     q: torch.Tensor,
@@ -71,6 +92,7 @@ def _sdpa_prefill_key_fn(
     batch_offsets_stats: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    o_data_type: Optional[torch.dtype] = None,
 ):
     graph_b = actual_seq_lens_q.shape[0]
 
@@ -90,6 +112,7 @@ def _sdpa_prefill_key_fn(
     key = (
         graph_b,
         q.dim(),
+        q.dtype,
         k_cache.dim(),
         max_token_seq_q,
         max_sequence_kv,
@@ -101,6 +124,7 @@ def _sdpa_prefill_key_fn(
         return_lse,
         bottom_right_causal_mask,
         page_size,
+        o_data_type,
     )
     return key
 
@@ -129,6 +153,7 @@ def _build_prefill_graph(
     batch_offsets_stats: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    o_data_type: Optional[torch.dtype] = None,
 ):
     handle = _create_cudnn_handle(torch.cuda.current_stream(q.device))
 
@@ -136,23 +161,97 @@ def _build_prefill_graph(
     graph_s_qo = max_token_seq_q
     graph_s_kv = max_sequence_kv
 
+    if not cudnn.datatypes.is_torch_available():
+        raise RuntimeError("torch is not available")
+
+    cudnn_q_data_type = cudnn.datatypes._torch_to_cudnn_data_type(q.dtype)
+    cudnn_k_data_type = cudnn.datatypes._torch_to_cudnn_data_type(k_cache.dtype)
+    cudnn_v_data_type = cudnn.datatypes._torch_to_cudnn_data_type(v_cache.dtype)
+
+    if o_data_type is None:
+        o_data_type = q.dtype
+
+    cudnn_o_data_type = cudnn.datatypes._torch_to_cudnn_data_type(o_data_type)
+
+    if (
+        cudnn_q_data_type == cudnn.data_type.FP8_E4M3
+        or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
+    ) and cudnn.backend_version() < 91701:
+        raise RuntimeError(
+            f"FP8 is not supported in cuDNN backend version < 9.17.1, current version is {cudnn.backend_version()}"
+        )
+
     with cudnn.graph(handle) as (g, _):
         # Create tensors from the input tensors
         if q.dim() == 3:
             h_qo, d_qk = q.shape[1], q.shape[2]
+            s_stride, h_stride, d_stride = q.stride()
         elif q.dim() == 4:
             h_qo, d_qk = q.shape[2], q.shape[3]
+            _, s_stride, h_stride, d_stride = q.stride()  # drop the batch stride for 4-D q
         else:
             raise ValueError(f"Invalid query tensor shape: {q.shape}")
 
-        s_stride, h_stride, d_stride = q.stride()
         cudnn_q = g.tensor(
             name="q",
             dim=(graph_b, h_qo, graph_s_qo, d_qk),
             stride=(h_qo * d_qk, h_stride, s_stride, d_stride),
-            data_type=cudnn.data_type.BFLOAT16,
+            data_type=cudnn_q_data_type,
         )
 
+        if (
+            cudnn_q_data_type == cudnn.data_type.FP8_E4M3
+            or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
+        ):
+            cudnn_q_scale = g.tensor(
+                name="q_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_k_scale = g.tensor(
+                name="k_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_v_scale = g.tensor(
+                name="v_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_s_scale = g.tensor(
+                name="s_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_s_descale = g.tensor(
+                name="s_descale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_o_scale = g.tensor(
+                name="o_scale",
+                dim=(1, 1, 1, 1),
+                stride=(1, 1, 1, 1),
+                data_type=cudnn.data_type.FLOAT,
+            )
+
+            cudnn_q_scale.set_uid(UIDs.Q_SCALE_UID.value)
+            cudnn_k_scale.set_uid(UIDs.K_SCALE_UID.value)
+            cudnn_v_scale.set_uid(UIDs.V_SCALE_UID.value)
+            cudnn_s_scale.set_uid(UIDs.S_SCALE_UID.value)
+            cudnn_s_descale.set_uid(UIDs.S_DESCALE_UID.value)
+            cudnn_o_scale.set_uid(UIDs.O_SCALE_UID.value)
+
         if batch_offsets_q is not None:
             ragged_q = g.tensor_like(batch_offsets_q)
             ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
@@ -177,7 +276,7 @@ def _build_prefill_graph(
                 name="k_cache",
                 dim=(graph_b, h_kv, graph_s_kv, d_qk),
                 stride=(h_kv * d_qk * graph_s_kv, h_stride, s_stride, d_stride),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_k_data_type,
             )
 
             if batch_offsets_k is not None:
@@ -185,12 +284,15 @@ def _build_prefill_graph(
                 ragged_k.set_uid(UIDs.RAGGED_K_UID.value)
                 cudnn_k_cache.set_ragged_offset(ragged_k)
 
+            assert v_cache.dim() == 3, (
+                "v_cache must have 3 dimensions since k_cache has 3 dimensions"
+            )
             s_stride, h_stride, d_stride = v_cache.stride()
             cudnn_v_cache = g.tensor(
                 name="v_cache",
                 dim=(graph_b, h_kv, graph_s_kv, d_vo),
                 stride=(h_kv * d_vo * graph_s_kv, h_stride, s_stride, d_stride),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_v_data_type,
             )
 
             if batch_offsets_v is not None:
@@ -203,14 +305,14 @@ def _build_prefill_graph(
                 name="k_cache",
                 dim=k_cache.shape,
                 stride=k_cache.stride(),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_k_data_type,
             )
 
             cudnn_v_cache = g.tensor(
                 name="v_cache",
                 dim=v_cache.shape,
                 stride=v_cache.stride(),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_v_data_type,
             )
 
         cudnn_q.set_uid(UIDs.Q_UID.value)
@@ -241,32 +343,86 @@ def _build_prefill_graph(
             actual_seq_lens_q is not None and actual_seq_lens_kv is not None
         )
 
-        O, Stats = g.sdpa(
-            name="sdpa",
-            q=cudnn_q,
-            k=cudnn_k_cache,
-            v=cudnn_v_cache,
-            seq_len_q=(
-                cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
-            ),
-            seq_len_kv=(
-                cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
-            ),
-            use_padding_mask=padding_mask,
-            attn_scale=scale,
-            generate_stats=return_lse,
-            use_causal_mask_bottom_right=bottom_right_causal_mask,
-            paged_attention_k_table=(
-                cudnn_k_block_tables if block_tables is not None else None
-            ),
-            paged_attention_v_table=(
-                cudnn_v_block_tables if block_tables is not None else None
-            ),
-            paged_attention_max_seq_len_kv=(
-                graph_s_kv if block_tables is not None else None
-            ),
-            compute_data_type=cudnn.data_type.FLOAT,
-        )
+        if (
+            cudnn_q_data_type == cudnn.data_type.BFLOAT16
+            or cudnn_q_data_type == cudnn.data_type.HALF
+        ):
+            O, Stats = g.sdpa(
+                name="sdpa",
+                q=cudnn_q,
+                k=cudnn_k_cache,
+                v=cudnn_v_cache,
+                seq_len_q=(
+                    cudnn_actual_seq_lens_q
+                    if actual_seq_lens_q is not None
+                    else None
+                ),
+                seq_len_kv=(
+                    cudnn_actual_seq_lens_kv
+                    if actual_seq_lens_kv is not None
+                    else None
+                ),
+                use_padding_mask=padding_mask,
+                attn_scale=scale,
+                generate_stats=return_lse,
+                use_causal_mask_bottom_right=bottom_right_causal_mask,
+                paged_attention_k_table=(
+                    cudnn_k_block_tables if block_tables is not None else None
+                ),
+                paged_attention_v_table=(
+                    cudnn_v_block_tables if block_tables is not None else None
+                ),
+                paged_attention_max_seq_len_kv=(
+                    graph_s_kv if block_tables is not None else None
+                ),
+                compute_data_type=cudnn.data_type.FLOAT,
+            )
+
+        elif (
+            cudnn_q_data_type == cudnn.data_type.FP8_E4M3
+            or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
+        ):
+            O, Stats, amax_s, amax_o = g.sdpa_fp8(
+                q=cudnn_q,
+                k=cudnn_k_cache,
+                v=cudnn_v_cache,
+                descale_q=cudnn_q_scale,
+                descale_k=cudnn_k_scale,
+                descale_v=cudnn_v_scale,
+                scale_s=cudnn_s_scale,
+                descale_s=cudnn_s_descale,
+                scale_o=cudnn_o_scale,
+                generate_stats=True,
+                attn_scale=scale,
+                use_causal_mask_bottom_right=bottom_right_causal_mask,
+                use_padding_mask=padding_mask,
+                seq_len_q=(
+                    cudnn_actual_seq_lens_q
+                    if actual_seq_lens_q is not None
+                    else None
+                ),
+                seq_len_kv=(
+                    cudnn_actual_seq_lens_kv
+                    if actual_seq_lens_kv is not None
+                    else None
+                ),
+                paged_attention_k_table=(
+                    cudnn_k_block_tables if block_tables is not None else None
+                ),
+                paged_attention_v_table=(
+                    cudnn_v_block_tables if block_tables is not None else None
+                ),
+                paged_attention_max_seq_len_kv=(
+                    graph_s_kv if block_tables is not None else None
+                ),
+            )
+
+            amax_s.set_uid(UIDs.S_AMAX_UID.value).set_output(False).set_dim(
+                (1, 1, 1, 1)
+            ).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT)
+            amax_o.set_uid(UIDs.O_AMAX_UID.value).set_output(False).set_dim(
+                (1, 1, 1, 1)
+            ).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT)
 
         if batch_offsets_o is not None:
             ragged_o = g.tensor_like(batch_offsets_o)
@@ -282,7 +438,7 @@ def _build_prefill_graph(
             [graph_b, h_qo, graph_s_qo, d_vo]
         ).set_stride(
             [graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1]
-        ).set_data_type(cudnn.data_type.BFLOAT16)
+        ).set_data_type(cudnn_o_data_type)
 
         if return_lse:
             Stats.set_uid(UIDs.STATS_UID.value).set_output(
@@ -317,6 +473,9 @@ def _batch_prefill_with_kv_cache(
     block_tables: Optional[torch.Tensor] = None,
     causal: bool,
     return_lse: bool,
+    q_scale: Optional[torch.Tensor] = None,
+    k_scale: Optional[torch.Tensor] = None,
+    v_scale: Optional[torch.Tensor] = None,
     batch_offsets_q: Optional[torch.Tensor] = None,
     batch_offsets_o: Optional[torch.Tensor] = None,
     batch_offsets_k: Optional[torch.Tensor] = None,
@@ -324,6 +483,7 @@ def _batch_prefill_with_kv_cache(
     batch_offsets_stats: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    o_data_type: Optional[torch.dtype] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     graph, tensors = _build_prefill_graph(
         q=q,
@@ -344,6 +504,7 @@ def _batch_prefill_with_kv_cache(
         batch_offsets_stats=batch_offsets_stats,
         out=out,
         lse=lse,
+        o_data_type=o_data_type,
     )
 
     var_map = {
@@ -377,6 +538,17 @@ def _batch_prefill_with_kv_cache(
     if batch_offsets_stats is not None:
         var_map[UIDs.RAGGED_STATS_UID.value] = batch_offsets_stats
 
+    if q_scale is not None:
+        dummy_scale_tensor = _get_dummy_scale_tensor(q.device)
+        var_map[UIDs.Q_SCALE_UID.value] = q_scale
+        var_map[UIDs.S_SCALE_UID.value] = dummy_scale_tensor
+        var_map[UIDs.S_DESCALE_UID.value] = dummy_scale_tensor
+        var_map[UIDs.O_SCALE_UID.value] = dummy_scale_tensor
+    if k_scale is not None:
+        var_map[UIDs.K_SCALE_UID.value] = k_scale
+    if v_scale is not None:
+        var_map[UIDs.V_SCALE_UID.value] = v_scale
+
     handle = _create_cudnn_handle(torch.cuda.current_stream(q.device))
 
     graph.execute(var_map, workspace=workspace_buffer, handle=handle)
@@ -400,6 +572,9 @@ def cudnn_batch_prefill_with_kv_cache(
     block_tables: Optional[torch.Tensor] = None,
     causal: bool,
    return_lse: bool,
+    q_scale: Optional[torch.Tensor] = None,
+    k_scale: Optional[torch.Tensor] = None,
+    v_scale: Optional[torch.Tensor] = None,
     batch_offsets_q: Optional[torch.Tensor] = None,
     batch_offsets_o: Optional[torch.Tensor] = None,
     batch_offsets_k: Optional[torch.Tensor] = None,
@@ -409,6 +584,7 @@ def cudnn_batch_prefill_with_kv_cache(
     lse: Optional[torch.Tensor] = None,
     is_cuda_graph_compatible: bool = False,
     backend: Optional[str] = None,
+    o_data_type: Optional[torch.dtype] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Performs batched prefill attention with paged KV cache using cuDNN.
 
@@ -428,11 +604,14 @@ def cudnn_batch_prefill_with_kv_cache(
         out: Optional pre-allocated output tensor
         lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None
         is_cuda_graph_compatible: Whether the prefill operation is compatible with CUDA graph
+        q_scale: Optional scale tensor for the query, shape (1, 1, 1, 1), float32, on GPU; bound as the FP8 descale factor
+        k_scale: Optional scale tensor for the key, shape (1, 1, 1, 1), float32, on GPU; bound as the FP8 descale factor
+        v_scale: Optional scale tensor for the value, shape (1, 1, 1, 1), float32, on GPU; bound as the FP8 descale factor
         batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
         batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
         batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
         batch_offsets_v: Optional batch offsets for value tensor of shape (batch_size,) on GPU
-
+        o_data_type: Optional torch dtype for the output tensor; defaults to the query dtype
     Returns:
         Output tensor of shape (batch_size * seq_len_q, num_heads_qo, head_dim)
         If return_lse is True, also returns log-sum-exp tensor of shape (batch_size, seq_len_q, num_heads_qo)
@@ -473,9 +652,12 @@ def cudnn_batch_prefill_with_kv_cache(
             "lse must have shape (num_sequences, max_token_per_sequence, h_qo)"
         )
 
+    if o_data_type is None:
+        o_data_type = q.dtype
+
     if out is None:
         out_shape = (num_tokens, h_qo, d_vo)
-        out = torch.empty(out_shape, device=q.device, dtype=q.dtype)
+        out = torch.empty(out_shape, device=q.device, dtype=o_data_type)
 
     if CUDNN_AVAILABLE and backend != "cubin":
         return _batch_prefill_with_kv_cache(
@@ -491,6 +673,9 @@ def cudnn_batch_prefill_with_kv_cache(
             block_tables=block_tables,
            causal=causal,
             return_lse=return_lse,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
             batch_offsets_q=batch_offsets_q,
             batch_offsets_o=batch_offsets_o,
             batch_offsets_k=batch_offsets_k,
@@ -498,6 +683,7 @@ def cudnn_batch_prefill_with_kv_cache(
             batch_offsets_stats=batch_offsets_stats,
             out=out,
             lse=lse,
+            o_data_type=o_data_type,
         )
     else:
        assert return_lse, "Currently only supports return_lse = True"
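
Reviewer note: a minimal usage sketch of the new FP8 path. The sizes, the `torch.float8_e4m3fn` dtype choice, the amax-based per-tensor scaling recipe, the sequence-length tensor shape, and the keyword names of the pre-existing parameters (`scale`, `workspace_buffer`, `actual_seq_lens_q`, `actual_seq_lens_kv`, `max_token_per_sequence`, `max_sequence_kv`) are illustrative assumptions, not part of this diff; only `q_scale`/`k_scale`/`v_scale`/`o_data_type` are introduced here.

```python
import math
import torch
from flashinfer.cudnn.prefill import cudnn_batch_prefill_with_kv_cache

b, s_qo, s_kv, h, d = 2, 128, 128, 8, 128  # illustrative sizes only
device = torch.device("cuda")

def to_fp8_e4m3(x: torch.Tensor):
    """Per-tensor FP8 quantization (assumed recipe): returns the FP8 tensor and
    the (1, 1, 1, 1) float32 descale factor expected by q_scale/k_scale/v_scale."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448 for e4m3
    amax = x.abs().amax().float().clamp(min=1e-12)
    x_fp8 = (x.float() * (fp8_max / amax)).to(torch.float8_e4m3fn)
    descale = (amax / fp8_max).reshape(1, 1, 1, 1)          # multiplied back in-kernel
    return x_fp8, descale

q_hp = torch.randn(b * s_qo, h, d, device=device, dtype=torch.bfloat16)
k_hp = torch.randn(b * s_kv, h, d, device=device, dtype=torch.bfloat16)
v_hp = torch.randn(b * s_kv, h, d, device=device, dtype=torch.bfloat16)

q, q_scale = to_fp8_e4m3(q_hp)
k, k_scale = to_fp8_e4m3(k_hp)
v, v_scale = to_fp8_e4m3(v_hp)

actual_seq_lens_q = torch.full((b, 1, 1, 1), s_qo, device=device, dtype=torch.int32)
actual_seq_lens_kv = torch.full((b, 1, 1, 1), s_kv, device=device, dtype=torch.int32)
workspace = torch.empty(128 * 1024 * 1024, device=device, dtype=torch.uint8)

out, lse = cudnn_batch_prefill_with_kv_cache(
    q, k, v,
    scale=1.0 / math.sqrt(d),
    workspace_buffer=workspace,
    actual_seq_lens_q=actual_seq_lens_q,
    actual_seq_lens_kv=actual_seq_lens_kv,
    max_token_per_sequence=s_qo,
    max_sequence_kv=s_kv,
    causal=True,
    return_lse=True,
    q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
    o_data_type=torch.bfloat16,  # FP8 inputs, high-precision output
)
```

Passing `o_data_type` separately from `q.dtype` is what lets the FP8 path return a BF16 output without an extra cast on the caller's side.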
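
The per-device `_dummy_scale_tensors` cache exists so the S/O scale UIDs, which this API pins at a neutral 1.0, can be bound on every `graph.execute` without a fresh allocation. A quick check of the helper's contract as written in this diff (imports assume the helper stays private to this module):

```python
import torch
from flashinfer.cudnn.prefill import _get_dummy_scale_tensor, _dummy_scale_tensors

dev = torch.device("cuda:0")
a = _get_dummy_scale_tensor(dev)
b = _get_dummy_scale_tensor(dev)

assert a is b                                  # cached: one allocation per device
assert a.shape == (1, 1, 1, 1)                 # layout the graph's scale tensors expect
assert a.dtype == torch.float32                # declared FLOAT in the graph
assert a.item() == 1.0                         # neutral S/O scale and descale
assert dev in _dummy_scale_tensors             # backing dict is keyed by device
```

Binding the same cached tensor to S_SCALE_UID, S_DESCALE_UID, and O_SCALE_UID is safe because all three are read-only inputs holding the constant 1.0.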
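
Folding `q.dtype` and `o_data_type` into `_sdpa_prefill_key_fn`'s key tuple is what keeps a BF16 call and an FP8 call from colliding in the graph cache; without it, the second call would reuse a graph whose tensor data types were baked in by the first. A sketch of the failure mode the key change prevents (cache mechanics deliberately simplified; the real key has many more fields):

```python
from typing import NamedTuple
import torch

class OldKey(NamedTuple):      # pre-diff: dtypes absent from the key
    batch: int
    q_dim: int

class NewKey(NamedTuple):      # post-diff: query and output dtypes included
    batch: int
    q_dim: int
    q_dtype: torch.dtype
    o_dtype: torch.dtype

old_cache = {OldKey(8, 3): "graph built for bf16"}
# An FP8 call hashes to the same OldKey -> silently reuses the bf16 graph.
assert OldKey(8, 3) in old_cache

new_cache = {NewKey(8, 3, torch.bfloat16, torch.bfloat16): "bf16 graph"}
# The FP8 call now misses and triggers a fresh graph build, as intended.
assert NewKey(8, 3, torch.float8_e4m3fn, torch.bfloat16) not in new_cache
```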