83 changes: 69 additions & 14 deletions tests/pytorch/attention/test_attention.py
@@ -240,7 +240,12 @@ def test_dot_product_attention(
flash_attn_supported = True

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
# Double-count the CK backend since we want to compare V2/V3 kernels
if (
len(fused_attn_backends) +
int(IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends) +
flash_attn_supported + unfused_attn_supported
) < 2:
pytest.skip("Less than two backends to compare.")

# UnfusedDotProductAttention backend
@@ -269,7 +274,21 @@ def test_dot_product_attention(
pad_between_seqs,
is_training,
)
if len(fused_attn_backends) == 2:
# Treat the CK backend as two backends, since it provides both V2 and V3 kernels
if IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends:
os.environ["NVTE_CK_USES_FWD_V3"] = "0"
os.environ["NVTE_CK_USES_BWD_V3"] = "0"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
elif len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
os.environ["NVTE_FUSED_ATTN_CK"] = "0"
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
@@ -286,8 +305,6 @@ def test_dot_product_attention(
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
os.environ["NVTE_FUSED_ATTN_CK"] = "1"
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
os.environ["NVTE_CK_USES_FWD_V3"] = "1"
os.environ["NVTE_CK_USES_BWD_V3"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
@@ -352,6 +369,16 @@ def test_dot_product_attention(
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
if (
fused_attn_supported and
len(fused_attn_backends) == 1 and
IS_HIP_EXTENSION and
FusedAttnBackend["CK"] in fused_attn_backends
):
logging.info("[test_dot_product_attention]: CK fused attn V2 vs V3")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@@ -1259,7 +1286,12 @@ def test_transformer_layer(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
# Double-count the CK backend since we want to compare V2/V3 kernels
if (
len(fused_attn_backends) +
int(IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends) +
flash_attn_supported + unfused_attn_supported
) < 2:
pytest.skip("Less than two backends to compare.")
# Skip if qkv_format = thd and "padding" not in attn_mask_type
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
@@ -1293,6 +1325,20 @@ def test_transformer_layer(
RoPE,
is_training,
)
if IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends:
os.environ["NVTE_CK_USES_FWD_V3"] = "0"
os.environ["NVTE_CK_USES_BWD_V3"] = "0"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
elif len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_CK"] = "0"
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
@@ -1363,15 +1409,24 @@ def test_transformer_layer(
logging.info("[test_transformer_layer]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
if IS_HIP_EXTENSION and fused_attn_supported and len(fused_attn_backends) == 2:
logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
logging.info("[test_transformer_layer]: fused attn backend 0 vs 2")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
if IS_HIP_EXTENSION and fused_attn_supported:
if len(fused_attn_backends) == 2:
logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
logging.info("[test_transformer_layer]: fused attn backend 0 vs 2")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
elif (
len(fused_attn_backends) == 1 and
FusedAttnBackend["CK"] in fused_attn_backends
):
logging.info("[test_dot_product_attention]: CK fused attn V2 vs V3")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
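Note on the pattern these hunks exercise: both tests force the CK V2 kernels for a second run by setting NVTE_CK_USES_FWD_V3 and NVTE_CK_USES_BWD_V3 to "0", then compare the forward/backward outputs against the default run within the usual tolerances. A minimal standalone sketch of that comparison follows; run_fused_attention is a hypothetical stand-in for the test helpers (_run_dot_product_attention / _run_transformer_layer) and is assumed to return a forward tensor and a list of backward tensors, and tols is assumed to hold assert_close keyword tolerances (e.g. atol/rtol).

import os
import torch

def compare_ck_v2_vs_v3(run_fused_attention, dtype, config, tols):
    # Reference run with the default CK kernel selection
    # (assumed here to pick the V3 kernels where available).
    fwd_ref, bwd_ref = run_fused_attention(dtype, config, "FusedAttention")

    # Second run with the CK V3 forward/backward kernels disabled,
    # which falls back to the V2 kernels.
    os.environ["NVTE_CK_USES_FWD_V3"] = "0"
    os.environ["NVTE_CK_USES_BWD_V3"] = "0"
    fwd_v2, bwd_v2 = run_fused_attention(dtype, config, "FusedAttention")

    # Both kernel generations should agree within the test tolerances.
    torch.testing.assert_close(fwd_v2, fwd_ref, **tols)
    for out_v2, out_ref in zip(bwd_v2, bwd_ref):
        torch.testing.assert_close(out_v2, out_ref, **tols)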