From 80ba7a0d8895a418fad06bb774849f7d0756f7bc Mon Sep 17 00:00:00 2001
From: Livinfly
Date: Mon, 22 Dec 2025 17:26:39 +0800
Subject: [PATCH 1/2] Fix(MInference): fix redundant padding in block_sparse_flash_attention

Signed-off-by: Livinfly
---
 minference/ops/block_sparse_flash_attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/minference/ops/block_sparse_flash_attention.py b/minference/ops/block_sparse_flash_attention.py
index b651189..b736e13 100644
--- a/minference/ops/block_sparse_flash_attention.py
+++ b/minference/ops/block_sparse_flash_attention.py
@@ -175,7 +175,8 @@ def block_sparse_attention(
     block_size_N: int = 64,
 ):
     batch_size, num_heads, context_size, head_dim = query.shape
-    pad = block_size_M - (query.shape[2] & (block_size_M - 1))
+    mask_M = block_size_M - 1
+    pad = (block_size_M - (query.shape[2] & mask_M)) & mask_M
     query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
     key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
     value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])

From 2cd80cab4849c4757086930534cf729928951218 Mon Sep 17 00:00:00 2001
From: Livinfly
Date: Mon, 22 Dec 2025 21:03:41 +0800
Subject: [PATCH 2/2] Fix(MInference): fix out-of-bounds access in block_sparse_flash_attention

Signed-off-by: Livinfly
---
 minference/ops/block_sparse_flash_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/minference/ops/block_sparse_flash_attention.py b/minference/ops/block_sparse_flash_attention.py
index b736e13..4f9bd09 100644
--- a/minference/ops/block_sparse_flash_attention.py
+++ b/minference/ops/block_sparse_flash_attention.py
@@ -180,7 +180,7 @@ def block_sparse_attention(
     query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
     key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
     value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
-    seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
+    seqlens = torch.tensor([context_size] * batch_size, dtype=torch.int32, device=query.device)
     sm_scale = head_dim ** -0.5
     block_index = _build_block_index(query, key, top_k, block_size_N, block_size_N)
     out = _triton_block_sparse_attention(query, key, value, seqlens, block_index, sm_scale, block_size_M, block_size_N)
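
Note (not part of the patches): a minimal, self-contained sketch of why both changes matter. The toy sizes below are illustrative only, and like the original code it assumes block_size_M is a power of two, so that (x & (block_size_M - 1)) equals x % block_size_M.

import torch

# Patch 1/2: padding arithmetic.
block_size_M = 64
mask_M = block_size_M - 1
for context_size in (100, 128):
    old_pad = block_size_M - (context_size & mask_M)
    new_pad = (block_size_M - (context_size & mask_M)) & mask_M
    # context_size=100: both formulas give 28, padding the length up to 128.
    # context_size=128 (already a multiple of 64): old_pad == 64, i.e. one whole
    # redundant block of padding; new_pad == 0.
    print(context_size, old_pad, new_pad)

# Patch 2/2: seqlens needs one entry per batch element. With batch_size > 1, a
# 1-element tensor is too short when the kernel looks up the sequence length for
# the second and later batch indices, which is the out-of-bounds access the
# commit message refers to.
batch_size, context_size = 2, 128
old_seqlens = torch.tensor([context_size], dtype=torch.int32)               # shape [1]
new_seqlens = torch.tensor([context_size] * batch_size, dtype=torch.int32)  # shape [batch_size]
print(old_seqlens.shape, new_seqlens.shape)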