diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 3f51b96b6..0bd062f6b 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -16,6 +16,34 @@ namespace ck_fused_attn{ +// We want to cache and reuse the log stream so we use thread_local here. +namespace { +std::ostream* get_bwd_log_stream() { + thread_local std::ofstream log_file; + thread_local bool attempted = false; + thread_local bool opened = false; + thread_local bool requested = false; + thread_local std::string log_dir_str; + if (!attempted) { + attempted = true; + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG")) { + log_dir_str = std::string(env_p); + requested = !log_dir_str.empty() && log_dir_str != "0"; + } + if (requested) { + opened = open_ck_fused_attn_log_file(log_file, "ck_fused_attn_bwd", log_dir_str); + } + } + if (!requested) { + return nullptr; + } + if (!opened) { + return &std::cout; + } + return &log_file; +} +} // namespace + // TODO: unify with binary search in TE/common/fused_attn(rocm)/util // no device std::upper_bound // in an increasing array with given size len, search for the index that: @@ -346,110 +374,104 @@ void log_bwd_config(const char* func_name, const bool is_v3_atomic_fp32, const int how_v3_bf16_cvt, const fmha_bwd_args& fmha_args){ - - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - if (ck_fused_attn_log_config) { - std::cout<::type>(mask_type)<::type>(bias_type)<::type>(mask_type) << "\n"; + *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; + *log_file << "has_dbias: " << has_dbias << "\n"; + *log_file << "has_dropout: " << has_dropout << "\n"; + *log_file << "is_store_randval: " << is_store_randval << "\n"; + *log_file << "is_deterministic: " << is_deterministic << "\n"; + *log_file << "uses_bwd_v3: " << uses_bwd_v3 << "\n"; + *log_file << "is_v3_atomic_fp32: " << is_v3_atomic_fp32 << "\n"; + *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; // fmha_args debug - std::cout<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset)) << "\n"; + *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; } } @@ -529,15 +551,10 @@ hipError_t ck_attn_bwd( right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_bwd_log_stream() != nullptr}; ck_tile::index_t shape_seqlen_q = seqlen_q; ck_tile::index_t shape_seqlen_k = seqlen_k; @@ -707,18 +724,18 @@ hipError_t ck_attn_bwd( dim3 grid(b, s_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ - std::cout<(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (ck_fused_attn_log_config){ - std::cout<(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (ck_fused_attn_log_config){ - std::cout<(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_bwd_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -1034,8 +1046,9 @@ hipError_t ck_attn_varlen_bwd( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if (auto* log_file = get_bwd_log_stream()) { + *log_file + << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); @@ -1068,18 +1081,18 @@ hipError_t ck_attn_varlen_bwd( dim3 grid(max_tokens_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ - std::cout<::type>(mask_type)<::type>(bias_type)<::type>(mask_type) << "\n"; + *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; + *log_file << "has_lse: " << has_lse << "\n"; + *log_file << "has_dropout: " << has_dropout << "\n"; + *log_file << "do_fp8_static_quant: " << do_fp8_static_quant << "\n"; + *log_file << "skip_min_seqlen_q: " << (fmha_args.min_seqlen_q != 0) << "\n"; + *log_file << "uses_fwd_v3: " << uses_fwd_v3 << "\n"; + *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; // debug fmha_args - std::cout<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset)) << "\n"; + *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; } } @@ -179,14 +201,9 @@ hipError_t ck_attn_fwd( right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_fwd_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -354,14 +371,9 @@ hipError_t ck_attn_varlen_fwd( bias_enum bias_type = bias_enum::no_bias; - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_fwd_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -457,8 +469,9 @@ hipError_t ck_attn_varlen_fwd( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if (auto* log_file = get_fwd_log_stream()) { + *log_file + << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); } diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 26c92ca2b..6bbfbda4f 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -1,10 +1,14 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #include +#include +#include +#include +#include #include "ck_fused_attn_utils.hpp" #include "ck_fused_attn/ck_fused_attn.hpp" #include "mask.hpp" @@ -13,6 +17,22 @@ namespace ck_fused_attn{ +bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix, const std::string& log_dir_str) { + // Explicitly use std::cout as a fallback + if (log_dir_str == "1") { + return false; + } + std::filesystem::path log_dir(log_dir_str); + std::ostringstream filename; + filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; + log_file.open(log_dir / filename.str(), std::ios_base::app); + if (!log_file.is_open()) { + std::cerr << "Failed to open log file: " << (log_dir / filename.str()) << "\n"; + return false; + } + return true; +} + std::string get_data_type_str(DType dtype){ std::string data_type_str; if(dtype==DType::kFloat16){ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index a75915ee2..a0ea13d81 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -7,8 +7,9 @@ #ifndef CK_FUSED_ATTN_UTILS_H #define CK_FUSED_ATTN_UTILS_H -#include -#include +#include +#include +#include #include //forward declaration for ck_tile enum @@ -56,5 +57,7 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); +bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix, const std::string& log_dir_str); + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_UTILS_H