From d6e659145da2f7e71403d44ba8deb4b8c405bc89 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Fri, 20 Feb 2026 08:57:13 -0600
Subject: [PATCH 1/4] Updated test to include CK/AITER V2/V3 test in single backend case

---
 tests/pytorch/attention/test_attention.py | 37 ++++++++++++++++++++---
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index bbd16ba0e..0f1d78acf 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -258,7 +258,7 @@ def test_dot_product_attention(

     # FusedAttention backend
     if fused_attn_supported:
-        if len(fused_attn_backends) == 1:
+        if len(fused_attn_backends) == 1 and FusedAttnBackend["CK"] not in fused_attn_backends:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
@@ -269,7 +269,33 @@ def test_dot_product_attention(
                 pad_between_seqs,
                 is_training,
             )
-        if len(fused_attn_backends) == 2:
+        # We can consider the CK backend as being two, since we have V2/V3 kernels
+        elif len(fused_attn_backends) == 1:
+            os.environ["NVTE_FUSED_ATTN_CK"] = "1"
+            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
+            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                pad_between_seqs,
+                is_training,
+            )
+            os.environ["NVTE_CK_USES_FWD_V3"] = "0"
+            os.environ["NVTE_CK_USES_BWD_V3"] = "0"
+            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                pad_between_seqs,
+                is_training,
+            )
+        elif len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
             os.environ["NVTE_FUSED_ATTN_CK"] = "0"
             os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
@@ -286,8 +312,6 @@ def test_dot_product_attention(
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
             os.environ["NVTE_FUSED_ATTN_CK"] = "1"
             os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
-            os.environ["NVTE_CK_USES_FWD_V3"] = "1"
-            os.environ["NVTE_CK_USES_BWD_V3"] = "1"
             fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
                 dtype,
                 config,
@@ -352,6 +376,11 @@ def test_dot_product_attention(
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
         for i, _ in enumerate(fused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
+    if fused_attn_supported and len(fused_attn_backends) == 1 and FusedAttnBackend["CK"] in fused_attn_backends:
+        logging.info("[test_dot_product_attention]: CK fused attn V2 vs V3")
+        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
+        for i, _ in enumerate(fused_attn_bwd):
+            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")

From 83c715dfee46d79de9253f3795449206f5dcce26 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Fri, 20 Feb 2026 13:16:19 -0600
Subject: [PATCH 2/4] Streamlined per PR review

---
 tests/pytorch/attention/test_attention.py | 42 ++++++++---------------
 1 file changed, 15 insertions(+), 27 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 0f1d78acf..75ce64126 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -258,7 +258,7 @@ def test_dot_product_attention(

     # FusedAttention backend
     if fused_attn_supported:
-        if len(fused_attn_backends) == 1 and FusedAttnBackend["CK"] not in fused_attn_backends:
+        if len(fused_attn_backends) == 1:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
@@ -269,32 +269,20 @@ def test_dot_product_attention(
                 pad_between_seqs,
                 is_training,
             )
-        # We can consider the CK backend as being two, since we have V2/V3 kernels
-        elif len(fused_attn_backends) == 1:
-            os.environ["NVTE_FUSED_ATTN_CK"] = "1"
-            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype,
-                config,
-                "FusedAttention",
-                ckpt_attn,
-                qkv_layout,
-                workspace_opt,
-                pad_between_seqs,
-                is_training,
-            )
-            os.environ["NVTE_CK_USES_FWD_V3"] = "0"
-            os.environ["NVTE_CK_USES_BWD_V3"] = "0"
-            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
-                dtype,
-                config,
-                "FusedAttention",
-                ckpt_attn,
-                qkv_layout,
-                workspace_opt,
-                pad_between_seqs,
-                is_training,
-            )
+            # We can consider the CK backend as being two, since we have V2/V3 kernels
+            if IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends:
+                os.environ["NVTE_CK_USES_FWD_V3"] = "0"
+                os.environ["NVTE_CK_USES_BWD_V3"] = "0"
+                fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+                    dtype,
+                    config,
+                    "FusedAttention",
+                    ckpt_attn,
+                    qkv_layout,
+                    workspace_opt,
+                    pad_between_seqs,
+                    is_training,
+                )
         elif len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
             os.environ["NVTE_FUSED_ATTN_CK"] = "0"
             os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"

From b0b9906e1a2c53d33098428811679a7f7bd8829b Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Fri, 20 Feb 2026 13:20:39 -0600
Subject: [PATCH 3/4] Updated skip condition

---
 tests/pytorch/attention/test_attention.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 75ce64126..6430efe91 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -240,7 +240,12 @@ def test_dot_product_attention(
         flash_attn_supported = True

     # Skip if only unfused backend is supported
-    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
+    # Double-count the CK backend since we want to compare V2/V3 kernels
+    if (
+        len(fused_attn_backends)
+        + int(IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends)
+        + flash_attn_supported + unfused_attn_supported
+    ) < 2:
         pytest.skip("Less than two backends to compare.")

     # UnfusedDotProductAttention backend
@@ -1276,7 +1281,12 @@ def test_transformer_layer(
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

     # Skip if only unfused backend is supported
-    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
+    # Double-count the CK backend since we want to compare V2/V3 kernels
+    if (
+        len(fused_attn_backends)
+        + int(IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends)
+        + flash_attn_supported + unfused_attn_supported
+    ) < 2:
         pytest.skip("Less than two backends to compare.")
     # Skip if qkv_format = thd and "padding" not in attn_mask_type
     if qkv_format == "thd" and "padding" not in config.attn_mask_type:

From 7ac6eaba7ddc6b45e69cb4349a50ff906c348f84 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Mon, 23 Feb 2026 09:17:52 -0600
Subject: [PATCH 4/4] Added IS_HIP_EXTENSION guard, and included CK/AITER comp in transformer test

---
 tests/pytorch/attention/test_attention.py | 48 ++++++++++++++++++-----
 1 file changed, 38 insertions(+), 10 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 6430efe91..e2818b299 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -369,7 +369,12 @@ def test_dot_product_attention(
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
         for i, _ in enumerate(fused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
-    if fused_attn_supported and len(fused_attn_backends) == 1 and FusedAttnBackend["CK"] in fused_attn_backends:
+    if (
+        fused_attn_supported and
+        len(fused_attn_backends) == 1 and
+        IS_HIP_EXTENSION and
+        FusedAttnBackend["CK"] in fused_attn_backends
+    ):
         logging.info("[test_dot_product_attention]: CK fused attn V2 vs V3")
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
         for i, _ in enumerate(fused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@@ -1320,6 +1325,20 @@ def test_transformer_layer(
                 RoPE,
                 is_training,
             )
+            if IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends:
+                os.environ["NVTE_CK_USES_FWD_V3"] = "0"
+                os.environ["NVTE_CK_USES_BWD_V3"] = "0"
+                fused_attn_fwd_1, fused_attn_bwd_1 = _run_transformer_layer(
+                    dtype,
+                    config,
+                    "FusedAttention",
+                    ckpt_attn,
+                    qkv_format,
+                    workspace_opt,
+                    fused_qkv_params,
+                    RoPE,
+                    is_training,
+                )
         elif len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_CK"] = "0"
             os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
@@ -1390,15 +1409,24 @@ def test_transformer_layer(
         logging.info("[test_transformer_layer]: fused attn vs flash attn")
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
         torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
-    if IS_HIP_EXTENSION and fused_attn_supported and len(fused_attn_backends) == 2:
-        logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
-        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
-        for i, _ in enumerate(fused_attn_bwd):
-            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
-        logging.info("[test_transformer_layer]: fused attn backend 0 vs 2")
-        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
-        for i, _ in enumerate(fused_attn_bwd):
-            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
+    if IS_HIP_EXTENSION and fused_attn_supported:
+        if len(fused_attn_backends) == 2:
+            logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
+            torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
+            for i, _ in enumerate(fused_attn_bwd):
+                torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
+            logging.info("[test_transformer_layer]: fused attn backend 0 vs 2")
+            torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
+            for i, _ in enumerate(fused_attn_bwd):
+                torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
+        elif (
+            len(fused_attn_backends) == 1 and
+            FusedAttnBackend["CK"] in fused_attn_backends
+        ):
+            logging.info("[test_transformer_layer]: CK fused attn V2 vs V3")
+            torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
+            for i, _ in enumerate(fused_attn_bwd):
+                torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
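
Note: every hunk in this series repeats the same V2-vs-V3 comparison pattern, which in isolation reduces to the sketch below. This is illustrative only and not part of the patches: compare_ck_v2_v3 and the run_case callable are hypothetical stand-ins for the suite's _run_dot_product_attention / _run_transformer_layer invocations; only the NVTE_CK_USES_FWD_V3 / NVTE_CK_USES_BWD_V3 toggles and the torch.testing.assert_close checks are taken from the patches themselves.

import os

import torch


def compare_ck_v2_v3(run_case, tols):
    # First pass: leave NVTE_CK_USES_{FWD,BWD}_V3 at their defaults, so the
    # CK V3 forward/backward kernels may be selected.
    fwd_v3, bwd_v3 = run_case()
    # Second pass: force the CK V2 kernels for both forward and backward.
    os.environ["NVTE_CK_USES_FWD_V3"] = "0"
    os.environ["NVTE_CK_USES_BWD_V3"] = "0"
    fwd_v2, bwd_v2 = run_case()
    # The two kernel generations must agree within the test tolerances.
    torch.testing.assert_close(fwd_v3, fwd_v2, **tols)
    for grad_v3, grad_v2 in zip(bwd_v3, bwd_v2):
        torch.testing.assert_close(grad_v3, grad_v2, **tols)

As in the hunks shown, the environment variables are set but not restored afterwards, so V2 stays forced for anything running later in the same process unless a fixture resets the environment.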