Additionally, with q_head = 64 and kv_head = 8 (which is Qwen3-32B's config), some shapes fail; there may be a bug in corner cases. The run below gets through 32,1,4096 and then crashes, presumably on the next shape in the sweep (num_batch = 64).
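For reference, the failing configuration can be reproduced in the test script below by flipping the head counts (the 64/8 values are present but commented out in the paste); a minimal sketch of the assumed setup:

# Qwen3-32B GQA layout: 64 query heads sharing 8 KV heads (group size 8),
# head_dim = 128. block_size = 32 crashes; block_size = 64 passes (see end).
num_head_q = 64
num_head_kv = 8
dim = 128
block_size = 32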
Error log:
1,1,1024,flash_attn: 0.066 ms,hpc: 0.021 ms
1,1,4096,flash_attn: 0.066 ms,hpc: 0.021 ms
8,1,1024,flash_attn: 0.067 ms,hpc: 0.022 ms
8,1,4096,flash_attn: 0.110 ms,hpc: 0.039 ms
16,1,1024,flash_attn: 0.067 ms,hpc: 0.021 ms
16,1,4096,flash_attn: 0.161 ms,hpc: 0.058 ms
32,1,1024,flash_attn: 0.088 ms,hpc: 0.038 ms
32,1,4096,flash_attn: 0.305 ms,hpc: 0.097 ms
Traceback (most recent call last):
File "/data/autovision-cbs/hpc-ops/tests/bench_code.py", line 171, in <module>
f, h = test_attention_decode_bf16(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/autovision-cbs/hpc-ops/tests/bench_code.py", line 137, in test_attention_decode_bf16
torch.cuda.synchronize()
File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 1083, in synchronize
return torch._C._cuda_synchronize()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
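Not from the original report: to make the stack trace point at the actual faulting kernel rather than a later synchronize, launches can be made blocking, as the error message suggests. A minimal sketch:

import os

# Must be set before the first CUDA context is created (i.e. before any
# tensor touches the GPU), otherwise it has no effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # imported after setting the env var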
Test code:
import os
import sys
from pathlib import Path

# Make the locally built hpc extension importable.
sys.path.insert(
    0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])
)

import time
import math
import torch
import hpc
from utils import allclose
from flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache


def test_attention_decode_bf16(
    num_batch,
    num_seq_q,
    num_seq_kv,
    block_size,
    num_head_q,
    num_head_kv,
    num_dim_qk,
    num_dim_v,
    use_output=True,
    rep=1000,
):
    q = torch.randn(
        (num_batch, num_seq_q, num_head_q, num_dim_qk),
        dtype=torch.bfloat16,
        device="cuda",
    ) / math.sqrt(num_dim_qk)
    k = torch.randn(
        (num_batch * num_seq_q, num_head_kv, num_dim_qk),
        dtype=torch.bfloat16,
        device="cuda",
    ) / math.sqrt(num_dim_qk)
    v = torch.randn(
        (num_batch * num_seq_q, num_head_kv, num_dim_v),
        dtype=torch.bfloat16,
        device="cuda",
    )
    seqlens_q = torch.full((num_batch,), num_seq_q, dtype=torch.int32, device="cuda")
    seqlens_kvcache = torch.full(
        (num_batch,), num_seq_kv, dtype=torch.int32, device="cuda"
    )
    cu_seqlens_q = torch.cumsum(
        torch.cat([torch.tensor([0], dtype=torch.int32, device="cuda"), seqlens_q]),
        dim=0,
    ).to(torch.int32)
    cu_seqlens_kvcache = torch.cumsum(
        torch.cat(
            [torch.tensor([0], dtype=torch.int32, device="cuda"), seqlens_kvcache]
        ),
        dim=0,
    ).to(torch.int32)
    max_num_blocks = num_batch * (num_seq_kv + block_size - 1) // block_size * 2
    kvcache_blocks = (seqlens_kvcache + block_size - 1) // block_size
    total_kvcache_blocks = sum(kvcache_blocks)
    max_kvcache_blocks = max(kvcache_blocks)
    max_seqlens_q = max(seqlens_q)
    max_seqlens_kvcache = max(seqlens_kvcache)
    kvcache = torch.randn(
        max_num_blocks * 2,
        2,
        block_size,
        num_head_kv,
        num_dim_qk,
        dtype=torch.bfloat16,
        device="cuda",
    )
    packed_block_ids = (
        torch.randperm(max_num_blocks)[:total_kvcache_blocks].to(torch.int32).cuda()
    )
    cu_blocks = 0
    block_ids = torch.empty(
        num_batch, max_kvcache_blocks, dtype=torch.int32, device="cuda"
    )
    for i in range(num_batch):
        block_ids[i, : kvcache_blocks[i]] = packed_block_ids[
            cu_blocks : cu_blocks + kvcache_blocks[i]
        ]
        cu_blocks += kvcache_blocks[i]
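    # At this point block_ids[i, :kvcache_blocks[i]] maps each logical block of
    # sequence i to a physical block in kvcache; randperm scatters the blocks so
    # that the non-contiguous paged layout is actually exercised.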
    # Warmup, then timed loop for FA-3's paged-KV decode path.
    for i in range(10):
        gt = flash_attn_with_kvcache(
            q=q,
            k_cache=kvcache[:, 0, :, :],
            v_cache=kvcache[:, 1, :, :],
            cache_seqlens=seqlens_kvcache,
            page_table=block_ids,
            causal=True,
        )
    torch.cuda.synchronize()
    st = time.time()
    for i in range(rep):
        gt = flash_attn_with_kvcache(
            q=q,
            k_cache=kvcache[:, 0, :, :],
            v_cache=kvcache[:, 1, :, :],
            cache_seqlens=seqlens_kvcache,
            page_table=block_ids,
            causal=True,
        )
    torch.cuda.synchronize()
    en = time.time()
    f_time = (en - st) / rep * 1000  # ms per call
    gt = gt.reshape(-1, num_head_q, num_dim_v)
    num_seq_kvcache = (
        torch.randint(1, num_seq_kv, (num_batch,), dtype=torch.int32, device="cuda") * 0
        + num_seq_kv
    )
    new_kv_included = True
    my = torch.zeros_like(gt, device="cuda")
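    # NOTE (assumption, not confirmed by the hpc API docs): new_kv_included=True
    # appears to mean the current step's KV is already present in the cache,
    # which is why the lengths passed below are num_seq_kvcache + 1.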
    # Warmup, then timed loop for the hpc decode kernel.
    for i in range(10):
        hpc.attention_decode_bf16(
            q.reshape(-1, num_head_q, num_dim_qk),
            kvcache[:, 0, :, :, :],
            kvcache[:, 1, :, :, :],
            block_ids,
            num_seq_kvcache + 1 if new_kv_included else num_seq_kvcache,
            new_kv_included=True,
            splitk=True,
            output=my,
        )
    torch.cuda.synchronize()
    st = time.time()
    for i in range(rep):
        hpc.attention_decode_bf16(
            q.reshape(-1, num_head_q, num_dim_qk),
            kvcache[:, 0, :, :, :],
            kvcache[:, 1, :, :, :],
            block_ids,
            num_seq_kvcache + 1 if new_kv_included else num_seq_kvcache,
            new_kv_included=True,
            splitk=True,
            output=my,
        )
    torch.cuda.synchronize()
    en = time.time()
    h_time = (en - st) / rep * 1000  # ms per call
    assert allclose(gt, my, atol=0.016)
    return f"flash_attn: {f_time:.3f} ms", f"hpc: {h_time:.3f} ms"


if __name__ == "__main__":
    num_batchs = [1, 8, 16, 32, 64, 128]
    num_seq_qs = [1]
    num_seq_kvs = [1024, 4096]
    block_size = 32  # 64
    # num_head_q = 64
    # num_head_kv = 8
    num_head_q = 32
    num_head_kv = 4
    dim = 128
    for num_batch in num_batchs:
        for num_seq_q in num_seq_qs:
            for num_seq_kv in num_seq_kvs:
                f, h = test_attention_decode_bf16(
                    num_batch,
                    num_seq_q,
                    num_seq_kv,
                    block_size,
                    num_head_q,
                    num_head_kv,
                    dim,
                    dim,
                    True,
                )
print(",".join(map(str, [num_batch, num_seq_q, num_seq_kv, f, h])))Nevertheless, set block_size=64, the decode kernel is much faster than FA-3, which is very impressive.
1,1,1024,flash_attn: 0.063 ms,hpc: 0.021 ms
1,1,4096,flash_attn: 0.062 ms,hpc: 0.021 ms
8,1,1024,flash_attn: 0.063 ms,hpc: 0.021 ms
8,1,4096,flash_attn: 0.111 ms,hpc: 0.039 ms
16,1,1024,flash_attn: 0.062 ms,hpc: 0.021 ms
16,1,4096,flash_attn: 0.161 ms,hpc: 0.059 ms
32,1,1024,flash_attn: 0.088 ms,hpc: 0.039 ms
32,1,4096,flash_attn: 0.305 ms,hpc: 0.097 ms
64,1,1024,flash_attn: 0.165 ms,hpc: 0.060 ms
64,1,4096,flash_attn: 0.599 ms,hpc: 0.193 ms
128,1,1024,flash_attn: 0.283 ms,hpc: 0.092 ms
128,1,4096,flash_attn: 1.044 ms,hpc: 0.332 ms
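Since block_size is the only difference between the failing and passing runs, a quick sweep could narrow down which paging granularities trigger the illegal access. A hypothetical sketch (not in the original script), reusing test_attention_decode_bf16 from above with the Qwen3-32B head layout and assumed shapes:

# Stop at the first failure: the CUDA context is typically unusable after an
# illegal memory access.
for bs in (16, 32, 64, 128):
    try:
        test_attention_decode_bf16(8, 1, 4096, bs, 64, 8, 128, 128, True)
        torch.cuda.synchronize()
        print(f"block_size={bs}: ok")
    except RuntimeError as e:  # CUDA errors surface as RuntimeError subclasses
        print(f"block_size={bs}: failed ({e})")
        break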