
[bug] test_attention_decode_bf16 can't pass with block_size=32 #26

@WingEdge777

Description

Additionally, setting q_head=64 and kv_head=8 (Qwen3-32B's config) also fails for some shapes, so there may be a bug in corner cases.
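
For reference, a minimal sketch of the driver change used for that run (it reuses test_attention_decode_bf16 from the script below; the 64/8 values are Qwen3-32B's GQA head layout, everything else mirrors the existing __main__ sweep):

# Sketch only: same sweep as the __main__ block in the test code below,
# but with Qwen3-32B's head layout (64 query heads, 8 KV heads).
num_head_q = 64
num_head_kv = 8
block_size = 32  # as in the log below

for num_batch in [1, 8, 16, 32, 64, 128]:
    for num_seq_kv in [1024, 4096]:
        f, h = test_attention_decode_bf16(
            num_batch, 1, num_seq_kv, block_size,
            num_head_q, num_head_kv, 128, 128, True,
        )
        print(",".join(map(str, [num_batch, num_seq_kv, f, h])))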

Error log:

1,1,1024,flash_attn: 0.066 ms,hpc: 0.021 ms
1,1,4096,flash_attn: 0.066 ms,hpc: 0.021 ms
8,1,1024,flash_attn: 0.067 ms,hpc: 0.022 ms
8,1,4096,flash_attn: 0.110 ms,hpc: 0.039 ms
16,1,1024,flash_attn: 0.067 ms,hpc: 0.021 ms
16,1,4096,flash_attn: 0.161 ms,hpc: 0.058 ms
32,1,1024,flash_attn: 0.088 ms,hpc: 0.038 ms
32,1,4096,flash_attn: 0.305 ms,hpc: 0.097 ms
Traceback (most recent call last):
  File "/data/autovision-cbs/hpc-ops/tests/bench_code.py", line 171, in <module>
    f, h = test_attention_decode_bf16(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/autovision-cbs/hpc-ops/tests/bench_code.py", line 137, in test_attention_decode_bf16
    torch.cuda.synchronize()
  File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 1083, in synchronize
    return torch._C._cuda_synchronize()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
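
Since the illegal access is only reported at the later torch.cuda.synchronize(), a quick way to localize the offending launch (as the error message itself suggests) is to force blocking launches. A minimal sketch, assuming the variable is set before anything touches the GPU:

# Sketch: make kernel launches synchronous so the failing launch itself raises,
# instead of a later synchronize(). Must be set before the first CUDA call.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # import torch only after the env var is set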

Test code:

import os
import sys
from pathlib import Path

# Make the locally built hpc extension (build/lib.*) importable.
sys.path.insert(
    0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])
)

import time
import math
import torch
import hpc
from utils import allclose
from flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache


def test_attention_decode_bf16(
    num_batch,
    num_seq_q,
    num_seq_kv,
    block_size,
    num_head_q,
    num_head_kv,
    num_dim_qk,
    num_dim_v,
    use_output=True,
    rep=1000,
):

    q = torch.randn(
        (num_batch, num_seq_q, num_head_q, num_dim_qk),
        dtype=torch.bfloat16,
        device="cuda",
    ) / math.sqrt(num_dim_qk)
    k = torch.randn(
        (num_batch * num_seq_q, num_head_kv, num_dim_qk),
        dtype=torch.bfloat16,
        device="cuda",
    ) / math.sqrt(num_dim_qk)
    v = torch.randn(
        (num_batch * num_seq_q, num_head_kv, num_dim_v),
        dtype=torch.bfloat16,
        device="cuda",
    )

    seqlens_q = torch.full((num_batch,), num_seq_q, dtype=torch.int32, device="cuda")
    seqlens_kvcache = torch.full(
        (num_batch,), num_seq_kv, dtype=torch.int32, device="cuda"
    )
    cu_seqlens_q = torch.cumsum(
        torch.cat([torch.tensor([0], dtype=torch.int32, device="cuda"), seqlens_q]),
        dim=0,
    ).to(torch.int32)
    cu_seqlens_kvcache = torch.cumsum(
        torch.cat(
            [torch.tensor([0], dtype=torch.int32, device="cuda"), seqlens_kvcache]
        ),
        dim=0,
    ).to(torch.int32)

    max_num_blocks = num_batch * (num_seq_kv + block_size - 1) // block_size * 2
    kvcache_blocks = (seqlens_kvcache + block_size - 1) // block_size
    total_kvcache_blocks = sum(kvcache_blocks)
    max_kvcache_blocks = max(kvcache_blocks)
    max_seqlens_q = max(seqlens_q)
    max_seqlens_kvcache = max(seqlens_kvcache)

    kvcache = torch.randn(
        max_num_blocks * 2,
        2,
        block_size,
        num_head_kv,
        num_dim_qk,
        dtype=torch.bfloat16,
        device="cuda",
    )
    packed_block_ids = (
        torch.randperm(max_num_blocks)[:total_kvcache_blocks].to(torch.int32).cuda()
    )

    # Scatter the packed block ids into a per-batch block table (page table).
    cu_blocks = 0
    block_ids = torch.empty(
        num_batch, max_kvcache_blocks, dtype=torch.int32, device="cuda"
    )
    for i in range(num_batch):
        block_ids[i, : kvcache_blocks[i]] = packed_block_ids[
            cu_blocks : cu_blocks + kvcache_blocks[i]
        ]
        cu_blocks += kvcache_blocks[i]

    # Warm up, then time the FlashAttention-3 reference over the paged KV cache.
    for i in range(10):
        gt = flash_attn_with_kvcache(
            q=q,
            k_cache=kvcache[:, 0, :, :],
            v_cache=kvcache[:, 1, :, :],
            cache_seqlens=seqlens_kvcache,
            page_table=block_ids,
            causal=True,
        )
    torch.cuda.synchronize()
    st = time.time()
    for i in range(rep):
        gt = flash_attn_with_kvcache(
            q=q,
            k_cache=kvcache[:, 0, :, :],
            v_cache=kvcache[:, 1, :, :],
            cache_seqlens=seqlens_kvcache,
            page_table=block_ids,
            causal=True,
        )
    torch.cuda.synchronize()
    en = time.time()
    f_time = (en - st) / rep * 1000

    gt = gt.reshape(-1, num_head_q, num_dim_v)

    # Effectively a constant num_seq_kv per batch entry (the randint is zeroed out).
    num_seq_kvcache = (
        torch.randint(1, num_seq_kv, (num_batch,), dtype=torch.int32, device="cuda") * 0
        + num_seq_kv
    )
    new_kv_included = True

    my = torch.zeros_like(gt, device="cuda")

    # Warm up, then time the hpc decode kernel under test.
    for i in range(10):
        hpc.attention_decode_bf16(
            q.reshape(-1, num_head_q, num_dim_qk),
            kvcache[:, 0, :, :, :],
            kvcache[:, 1, :, :, :],
            block_ids,
            num_seq_kvcache + 1 if new_kv_included else num_seq_kvcache,
            new_kv_included=True,
            splitk=True,
            output=my,
        )

    torch.cuda.synchronize()
    st = time.time()
    for i in range(rep):
        hpc.attention_decode_bf16(
            q.reshape(-1, num_head_q, num_dim_qk),
            kvcache[:, 0, :, :, :],
            kvcache[:, 1, :, :, :],
            block_ids,
            num_seq_kvcache + 1 if new_kv_included else num_seq_kvcache,
            new_kv_included=True,
            splitk=True,
            output=my,
        )
    torch.cuda.synchronize()
    en = time.time()
    h_time = (en - st) / rep * 1000

    assert allclose(gt, my, atol=0.016)
    return f"flash_attn: {f_time:.3f} ms", f"hpc: {h_time:.3f} ms"


if __name__ == "__main__":
    num_batchs = [1, 8, 16, 32, 64, 128]
    num_seq_qs = [1]
    num_seq_kvs = [1024, 4096]
    block_size = 32 #64
    # num_head_q = 64
    # num_head_kv = 8
    num_head_q = 32
    num_head_kv = 4
    dim = 128
    for num_batch in num_batchs:
        for num_seq_q in num_seq_qs:
            for num_seq_kv in num_seq_kvs:
                f, h = test_attention_decode_bf16(
                    num_batch,
                    num_seq_q,
                    num_seq_kv,
                    block_size,
                    num_head_q,
                    num_head_kv,
                    dim,
                    dim,
                    True,
                )
                print(",".join(map(str, [num_batch, num_seq_q, num_seq_kv, f, h])))

Nevertheless, with block_size=64 the decode kernel is much faster than FA-3, which is very impressive.

1,1,1024,flash_attn: 0.063 ms,hpc: 0.021 ms
1,1,4096,flash_attn: 0.062 ms,hpc: 0.021 ms
8,1,1024,flash_attn: 0.063 ms,hpc: 0.021 ms
8,1,4096,flash_attn: 0.111 ms,hpc: 0.039 ms
16,1,1024,flash_attn: 0.062 ms,hpc: 0.021 ms
16,1,4096,flash_attn: 0.161 ms,hpc: 0.059 ms
32,1,1024,flash_attn: 0.088 ms,hpc: 0.039 ms
32,1,4096,flash_attn: 0.305 ms,hpc: 0.097 ms
64,1,1024,flash_attn: 0.165 ms,hpc: 0.060 ms
64,1,4096,flash_attn: 0.599 ms,hpc: 0.193 ms
128,1,1024,flash_attn: 0.283 ms,hpc: 0.092 ms
128,1,4096,flash_attn: 1.044 ms,hpc: 0.332 ms
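
For completeness, a hedged sketch that sweeps both page sizes with the same driver (the block_size=32 configs have to run last, or in a fresh process, because the illegal memory access leaves the CUDA context unusable afterwards):

# Sketch: run the passing page size first, then the failing one.
for block_size in (64, 32):
    for num_batch in [1, 8, 16, 32, 64, 128]:
        for num_seq_kv in [1024, 4096]:
            f, h = test_attention_decode_bf16(
                num_batch, 1, num_seq_kv, block_size, 32, 4, 128, 128, True
            )
            print(",".join(map(str, [block_size, num_batch, num_seq_kv, f, h])))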
