diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index bbd16ba0e..e2818b299 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -240,7 +240,12 @@ def test_dot_product_attention( flash_attn_supported = True # Skip if only unfused backend is supported - if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: + # Double-count the CK backend since we want to compare V2/V3 kernels + if ( + len(fused_attn_backends) + + int(IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends) + + flash_attn_supported + unfused_attn_supported + ) < 2: pytest.skip("Less than two backends to compare.") # UnfusedDotProductAttention backend @@ -269,7 +274,21 @@ def test_dot_product_attention( pad_between_seqs, is_training, ) - if len(fused_attn_backends) == 2: + # We can consider the CK backend as being two, since we have V2/V3 kernels + if IS_HIP_EXTENSION and len(fused_attn_backends) == 1 and FusedAttnBackend["CK"] in fused_attn_backends: + os.environ["NVTE_CK_USES_FWD_V3"] = "0" + os.environ["NVTE_CK_USES_BWD_V3"] = "0" + fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( + dtype, + config, + "FusedAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + pad_between_seqs, + is_training, + ) + elif len(fused_attn_backends) == 2: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" os.environ["NVTE_FUSED_ATTN_CK"] = "0" os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1" @@ -286,8 +305,6 @@ def test_dot_product_attention( os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" os.environ["NVTE_FUSED_ATTN_CK"] = "1" os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0" - os.environ["NVTE_CK_USES_FWD_V3"] = "1" - os.environ["NVTE_CK_USES_BWD_V3"] = "1" fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( dtype, config, @@ -352,6 +369,16 @@ def test_dot_product_attention( torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols) for i, _ in enumerate(fused_attn_bwd): 
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols) + if ( + fused_attn_supported and + len(fused_attn_backends) == 1 and + IS_HIP_EXTENSION and + FusedAttnBackend["CK"] in fused_attn_backends + ): + logging.info("[test_dot_product_attention]: CK fused attn V2 vs V3") + torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols) + for i, _ in enumerate(fused_attn_bwd): + torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols) @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @@ -1259,7 +1286,12 @@ def test_transformer_layer( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # Skip if only unfused backend is supported - if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: + # Double-count the CK backend since we want to compare V2/V3 kernels + if ( + len(fused_attn_backends) + + int(IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends) + + flash_attn_supported + unfused_attn_supported + ) < 2: pytest.skip("Less than two backends to compare.") # Skip if qkv_format = thd and "padding" not in attn_mask_type if qkv_format == "thd" and "padding" not in config.attn_mask_type: @@ -1293,6 +1325,20 @@ def test_transformer_layer( RoPE, is_training, ) + if IS_HIP_EXTENSION and len(fused_attn_backends) == 1 and FusedAttnBackend["CK"] in fused_attn_backends: + os.environ["NVTE_CK_USES_FWD_V3"] = "0" + os.environ["NVTE_CK_USES_BWD_V3"] = "0" + fused_attn_fwd_1, fused_attn_bwd_1 = _run_transformer_layer( + dtype, + config, + "FusedAttention", + ckpt_attn, + qkv_format, + workspace_opt, + fused_qkv_params, + RoPE, + is_training, + ) elif len(fused_attn_backends) == 2: os.environ["NVTE_FUSED_ATTN_CK"] = "0" os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1" @@ -1363,15 +1409,24 @@ def test_transformer_layer( logging.info("[test_transformer_layer]: fused attn vs flash attn") torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) 
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols) - if IS_HIP_EXTENSION and fused_attn_supported and len(fused_attn_backends) == 2: -         logging.info("[test_transformer_layer]: fused attn backend 0 vs 1") - torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols) - for i, _ in enumerate(fused_attn_bwd): - torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols) - logging.info("[test_transformer_layer]: fused attn backend 0 vs 2") - torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols) - for i, _ in enumerate(fused_attn_bwd): - torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols) + if IS_HIP_EXTENSION and fused_attn_supported: + if len(fused_attn_backends) == 2: + logging.info("[test_transformer_layer]: fused attn backend 0 vs 1") + torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols) + for i, _ in enumerate(fused_attn_bwd): + torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols) + logging.info("[test_transformer_layer]: fused attn backend 0 vs 2") + torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols) + for i, _ in enumerate(fused_attn_bwd): + torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols) + elif ( + len(fused_attn_backends) == 1 and + FusedAttnBackend["CK"] in fused_attn_backends + ): + logging.info("[test_transformer_layer]: CK fused attn V2 vs V3") + torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols) + for i, _ in enumerate(fused_attn_bwd): + torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols) @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")