diff --git a/minference/ops/block_sparse_flash_attention.py b/minference/ops/block_sparse_flash_attention.py
index b651189..4f9bd09 100644
--- a/minference/ops/block_sparse_flash_attention.py
+++ b/minference/ops/block_sparse_flash_attention.py
@@ -175,11 +175,12 @@ def block_sparse_attention(
     block_size_N: int = 64,
 ):
     batch_size, num_heads, context_size, head_dim = query.shape
-    pad = block_size_M - (query.shape[2] & (block_size_M - 1))
+    mask_M = block_size_M - 1
+    pad = (block_size_M - (query.shape[2] & mask_M)) & mask_M
     query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
     key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
     value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
-    seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
+    seqlens = torch.tensor([context_size] * batch_size, dtype=torch.int32, device=query.device)
     sm_scale = head_dim ** -0.5
     block_index = _build_block_index(query, key, top_k, block_size_N, block_size_N)
     out = _triton_block_sparse_attention(query, key, value, seqlens, block_index, sm_scale, block_size_M, block_size_N)
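
For context on the first hunk, here is a minimal standalone sketch (hypothetical helper names, not part of the patch) comparing the old and new padding computations. It assumes `block_size_M` is a power of two, which the bitwise mask already requires: the old expression pads an already block-aligned sequence length by a full extra block, while the new expression pads it by zero.

```python
# Sketch: old vs. new padding amount (assumes block_size_M is a power of 2).
def pad_old(context_size: int, block_size_M: int) -> int:
    # Old: returns block_size_M when context_size is already block-aligned.
    return block_size_M - (context_size & (block_size_M - 1))

def pad_new(context_size: int, block_size_M: int) -> int:
    # New: the trailing "& mask_M" wraps block_size_M back to 0 for aligned lengths.
    mask_M = block_size_M - 1
    return (block_size_M - (context_size & mask_M)) & mask_M

if __name__ == "__main__":
    for n in (127, 128, 129):
        print(n, pad_old(n, 64), pad_new(n, 64))
    # 127 -> old 1,  new 1
    # 128 -> old 64, new 0   (old pads an aligned length by a full block)
    # 129 -> old 63, new 63
```

The second hunk makes `seqlens` hold one entry per batch element rather than a single scalar, so the kernel sees a per-sequence length for every item in the padded batch.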