Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f02c48e
Initial inclusion of new API in fwd as well as part 1 of refactor
Micky774 Feb 2, 2026
0b0ad93
Initial implementation of refactor/API update across ALL CK funcs
Micky774 Feb 3, 2026
c198cbd
Updated logging
Micky774 Feb 6, 2026
a52bb32
Add script for comparing AITER/TE API
Micky774 Feb 6, 2026
77f0a05
Reconcile new AITER mask type
Micky774 Feb 9, 2026
1637266
Updated API helper tool
Micky774 Feb 9, 2026
568e9b5
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 9, 2026
2cb6d82
Formatting
Micky774 Feb 9, 2026
cf4aa9e
Added sys exit to script
Micky774 Feb 9, 2026
e25cea8
Slightly better error message
Micky774 Feb 9, 2026
2122479
Updated AITER_ASM_DIR implementation
Micky774 Feb 11, 2026
4817e72
Update AITER
Micky774 Feb 11, 2026
837b827
Updated AITER_ASM_DIR logic to allow for hip-free use
Micky774 Feb 12, 2026
68ca0fe
Re-introduce setup AITER API check
Micky774 Feb 16, 2026
ae688ab
Update AITER to custom feature branch
Micky774 Feb 16, 2026
762b91b
Reduce AITER build verbosity
Micky774 Feb 16, 2026
0a7187d
Updated API
Micky774 Feb 16, 2026
39b27bc
Address PR comments
Micky774 Feb 17, 2026
29878cf
Updated bias stride calculations
Micky774 Feb 17, 2026
6846a27
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 18, 2026
47592ac
Reverted AITER feature branch use due to verbosity changes
Micky774 Feb 18, 2026
357b5ce
PR review comments
Micky774 Feb 18, 2026
1f080c1
Reintroduced warning suppression in AITER
Micky774 Feb 18, 2026
a657bdd
Removes auto-setting of AITER_LOG_MORE, corrects batch stride impl
Micky774 Feb 18, 2026
c225448
Removes AITER_LOG_MORE from CI runs
Micky774 Feb 18, 2026
4193158
Minor corrections
Micky774 Feb 18, 2026
dbb6106
PR feedback
Micky774 Feb 19, 2026
899162e
Formatting
Micky774 Feb 19, 2026
1081c5e
Copyright
Micky774 Feb 19, 2026
9514855
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 19, 2026
5fb6bc3
Updated ck_fused_attn lib build to include copying HSA
Micky774 Feb 20, 2026
7adfb0d
Build update
Micky774 Feb 23, 2026
a8d98e6
Removed old file
Micky774 Feb 23, 2026
c69efa2
Added check args script
Micky774 Feb 23, 2026
badd286
Updated CMake with explicit permissions
Micky774 Feb 23, 2026
efb4204
Revert to working commit
Micky774 Feb 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 1180 files
11 changes: 11 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from importlib import metadata
import os
import sys
import time
from pathlib import Path
from typing import List, Tuple
Expand Down Expand Up @@ -88,6 +89,16 @@ def setup_common_extension() -> CMakeExtension:
cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF")
elif os.getenv("NVTE_FUSED_ATTN_CK") or os.getenv("NVTE_FUSED_ATTN"):
cmake_flags.append("-DUSE_FUSED_ATTN_CK=ON")
# This helper script scans a fixed set of AITER headers and TE source files to check that
# the fields of AITER's mha_{fwd,bwd}_args structs match the fields our CK fused-attention code assigns.
try:
subprocess.run(
sys.executable + " tools/check_aiter_mha_args_usage.py --mode both",
shell=True, check=True
)
except subprocess.CalledProcessError:
print("Error checking AITER mha_args usage.")
sys.exit(1)

if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None:
os.environ["NVTE_ENABLE_ROCSHMEM"] = '1'
Expand Down
106 changes: 106 additions & 0 deletions tools/check_aiter_mha_args_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.

"""
This script is run during setup through setup.py, and can be run independently
to check that the fields defined in the mha_{fwd,bwd}_args structs in the AITER
headers are correctly referenced in the source code.
"""

import argparse
import re
from pathlib import Path
from typing import List, Set
import sys

def parse_with_skip_comments(buffer: List[str], line: str, regex: re.Pattern, outputs: List[str]) -> None:
    """Feed one C/C++ line into a statement accumulator and record regex hits.

    Lines are joined in ``buffer[0]`` until a ``;`` shows up. Each completed
    ``;``-terminated statement is then searched with ``regex`` and group 1 of
    a hit is appended to ``outputs``. Text after the last ``;`` stays in
    ``buffer[0]`` so a statement that starts on this line and ends on a later
    one is still seen as a whole.

    Args:
        buffer: One-element list; ``buffer[0]`` holds the text of the
            statement still being assembled between calls.
        line: Next raw line of the file being scanned.
        regex: Compiled pattern with at least one capture group.
        outputs: List extended in place with ``match.group(1)`` for every
            statement the pattern matches.
    """
    stripped = line.strip()
    # Blank lines and whole-line ``//`` comments contribute nothing.
    if not stripped or stripped.startswith("//"):
        return
    # Drop a trailing ``//`` comment so a ';' inside it cannot end a statement.
    line_no_comment = re.sub(r"//.*", "", line)
    buffer[0] += " " + line_no_comment.strip()
    if ";" not in line_no_comment:
        return
    # Examine every finished statement on its own: several statements may
    # share one line, and only the part after the final ';' is unfinished.
    *statements, remainder = buffer[0].split(";")
    for statement in statements:
        match = regex.search(statement + ";")
        if match:
            outputs.append(match.group(1))
    buffer[0] = remainder.strip()


def extract_fields_from_header(text: str, struct_name: str) -> List[str]:
    """Return the member names declared in ``struct <struct_name>`` in ``text``.

    Scanning begins on the line after the first line that mentions
    ``struct <struct_name>`` (forward declarations are ignored) and stops at
    the first line consisting only of ``};``. Lines in between are handed to
    ``parse_with_skip_comments``, which records the identifier directly
    before each terminating ``;`` (or before a ``= default`` initializer).

    Args:
        text: Full contents of the C/C++ header.
        struct_name: Name of the struct whose members are wanted.

    Returns:
        Member names in declaration order; empty if the struct is not found.
    """
    struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")
    struct_end_re = re.compile(r"^\s*};\s*$")
    # ``(?!\s*;)`` rejects a forward declaration such as ``struct foo;`` --
    # without it everything up to the next ``};`` would be read as fields.
    struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b(?!\s*;)")

    fields: List[str] = []
    buffer = [""]
    in_struct = False
    for line in text.splitlines():
        if not in_struct:
            # The line that opens the struct is itself never parsed for fields.
            in_struct = bool(struct_start_re.search(line))
            continue
        if struct_end_re.search(line):
            break
        parse_with_skip_comments(buffer, line, struct_field_re, fields)
    return fields


def extract_usage_from_source(text: str, var_name: str) -> Set[str]:
    """Return the set of members assigned through ``<var_name>.<member> =``.

    Args:
        text: Full contents of the C/C++ source file.
        var_name: Name of the struct variable whose member assignments are
            collected.

    Returns:
        The distinct member names found on the left of a plain ``=``.
    """
    # ``=(?!=)`` keeps comparisons such as ``var.field == x`` from being
    # mistaken for assignments.
    assign_re = re.compile(
        rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=(?!=)"
    )
    assignments: List[str] = []
    buffer = [""]
    for line in text.splitlines():
        parse_with_skip_comments(buffer, line, assign_re, assignments)
    return set(assignments)


def main() -> int:
    """Compare AITER's ``mha_{fwd,bwd}_args`` fields with the fields TE assigns.

    For every selected mode the AITER header and the matching
    ``ck_fused_attn_{mode}.cpp`` under ``--te-dir`` are read, and a report is
    printed listing header fields never assigned through ``fmha_args`` and
    assigned fields that the header does not declare.

    Returns:
        0 when all modes match exactly; 1 when any field is mismatched or an
        input file cannot be read.
    """
    parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition")
    parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both")
    # Resolve first so the default root does not depend on how the script
    # path was spelled on the command line.
    parser.add_argument("--te-dir", type=Path, default=Path(__file__).resolve().parent.parent, help="Root directory of TransformerEngine")
    args = parser.parse_args()
    modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode]
    mismatch = 0
    for mode in modes:
        header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h"
        source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp"
        try:
            header_text = header_path.read_text(encoding="utf-8")
            source_text = source_path.read_text(encoding="utf-8")
        except OSError as err:
            # E.g. the 3rdparty/aiter checkout is missing: report it plainly
            # instead of dying with a traceback.
            print(f"Error: unable to read input file: {err}", file=sys.stderr)
            return 1

        header_set = set(extract_fields_from_header(header_text, f"mha_{mode}_args"))
        used_fields = extract_usage_from_source(source_text, "fmha_args")

        missing_in_usage = sorted(header_set - used_fields)
        missing_in_header = sorted(used_fields - header_set)
        mismatch += len(missing_in_usage) + len(missing_in_header)

        print(f"\nAnalyzing mha_{mode}_args\n")
        print(f"mha_{mode}_args fields in header:", len(header_set))
        print(f"mha_{mode}_args fields referenced in source:", len(used_fields))

        if missing_in_usage:
            print("\nFields present in header but not referenced in source:")
            for name in missing_in_usage:
                print(f" - {name}")
        else:
            print("\nAll header fields are referenced in source.")

        if missing_in_header:
            print("\nFields referenced in source but not in header:")
            for name in missing_in_header:
                print(f" - {name}")
        else:
            print("\nNo unknown fields referenced in source.")

    if mismatch:
        print(f"\nTotal mismatched fields: {mismatch}")
        return 1
    return 0


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())
1 change: 0 additions & 1 deletion transformer_engine/common/ck_fused_attn/aiter_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ if [[ -z "${AITER_DIR}" || -z "${AITER_TEST_DIR}" || -z "${GPU_ARCHS_VAL}" ]]; t
fi

rm -rf "${AITER_DIR}/aiter/jit/build"
AITER_LOG_MORE=1 \
CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT="${CK_TILE_BF16_DEFAULT}" \
GPU_ARCHS="${GPU_ARCHS_VAL}" \
python3 "${AITER_TEST_DIR}/compile.py"
Expand Down
109 changes: 109 additions & 0 deletions transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.


"""
This script is run during setup through setup.py, and can be run independently
to check that the fields defined in the mha_{fwd,bwd}_args structs in the AITER
headers are correctly referenced in the source code.
"""

import argparse
import re
from pathlib import Path
from typing import List, Set
import sys

def parse_with_skip_comments(buffer: List[str], line: str, regex: re.Pattern, outputs: List[str]) -> None:
    """Feed one C/C++ line into a statement accumulator and record regex hits.

    Lines are joined in ``buffer[0]`` until a ``;`` shows up. Each completed
    ``;``-terminated statement is then searched with ``regex`` and group 1 of
    a hit is appended to ``outputs``. Text after the last ``;`` stays in
    ``buffer[0]`` so a statement that starts on this line and ends on a later
    one is still seen as a whole.

    Args:
        buffer: One-element list; ``buffer[0]`` holds the text of the
            statement still being assembled between calls.
        line: Next raw line of the file being scanned.
        regex: Compiled pattern with at least one capture group.
        outputs: List extended in place with ``match.group(1)`` for every
            statement the pattern matches.
    """
    stripped = line.strip()
    # Blank lines and whole-line ``//`` comments contribute nothing.
    if not stripped or stripped.startswith("//"):
        return
    # Drop a trailing ``//`` comment so a ';' inside it cannot end a statement.
    line_no_comment = re.sub(r"//.*", "", line)
    buffer[0] += " " + line_no_comment.strip()
    if ";" not in line_no_comment:
        return
    # Examine every finished statement on its own: several statements may
    # share one line, and only the part after the final ';' is unfinished.
    *statements, remainder = buffer[0].split(";")
    for statement in statements:
        match = regex.search(statement + ";")
        if match:
            outputs.append(match.group(1))
    buffer[0] = remainder.strip()


def extract_fields_from_header(text: str, struct_name: str) -> List[str]:
    """Return the member names declared in ``struct <struct_name>`` in ``text``.

    Scanning begins on the line after the first line that mentions
    ``struct <struct_name>`` (forward declarations are ignored) and stops at
    the first line consisting only of ``};``. Lines in between are handed to
    ``parse_with_skip_comments``, which records the identifier directly
    before each terminating ``;`` (or before a ``= default`` initializer).

    Args:
        text: Full contents of the C/C++ header.
        struct_name: Name of the struct whose members are wanted.

    Returns:
        Member names in declaration order; empty if the struct is not found.
    """
    struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")
    struct_end_re = re.compile(r"^\s*};\s*$")
    # ``(?!\s*;)`` rejects a forward declaration such as ``struct foo;`` --
    # without it everything up to the next ``};`` would be read as fields.
    struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b(?!\s*;)")

    fields: List[str] = []
    buffer = [""]
    in_struct = False
    for line in text.splitlines():
        if not in_struct:
            # The line that opens the struct is itself never parsed for fields.
            in_struct = bool(struct_start_re.search(line))
            continue
        if struct_end_re.search(line):
            break
        parse_with_skip_comments(buffer, line, struct_field_re, fields)
    return fields


def extract_usage_from_source(text: str, var_name: str) -> Set[str]:
    """Return the set of members assigned through ``<var_name>.<member> =``.

    Args:
        text: Full contents of the C/C++ source file.
        var_name: Name of the struct variable whose member assignments are
            collected.

    Returns:
        The distinct member names found on the left of a plain ``=``.
    """
    # ``=(?!=)`` keeps comparisons such as ``var.field == x`` from being
    # mistaken for assignments.
    assign_re = re.compile(
        rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=(?!=)"
    )
    assignments: List[str] = []
    buffer = [""]
    for line in text.splitlines():
        parse_with_skip_comments(buffer, line, assign_re, assignments)
    return set(assignments)


def main() -> int:
    """Compare AITER's ``mha_{fwd,bwd}_args`` fields with the fields TE assigns.

    For every selected mode the AITER header and the matching
    ``ck_fused_attn_{mode}.cpp`` under ``--te-dir`` are read, and a report is
    printed listing header fields never assigned through ``fmha_args`` and
    assigned fields that the header does not declare.

    Returns:
        0 when all modes match exactly; 1 when any field is mismatched or an
        input file cannot be read.
    """
    parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition")
    parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both")
    # Resolve first so the default root does not depend on how the script
    # path was spelled on the command line.
    parser.add_argument("--te-dir", type=Path, default=Path(__file__).resolve().parent.parent.parent.parent, help="Root directory of TransformerEngine")
    args = parser.parse_args()
    modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode]
    mismatch = 0
    for mode in modes:
        header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h"
        source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp"
        try:
            header_text = header_path.read_text(encoding="utf-8")
            source_text = source_path.read_text(encoding="utf-8")
        except OSError as err:
            # E.g. the 3rdparty/aiter checkout is missing: report it plainly
            # instead of dying with a traceback.
            print(f"Error: unable to read input file: {err}", file=sys.stderr)
            return 1

        header_set = set(extract_fields_from_header(header_text, f"mha_{mode}_args"))
        used_fields = extract_usage_from_source(source_text, "fmha_args")

        missing_in_usage = sorted(header_set - used_fields)
        missing_in_header = sorted(used_fields - header_set)
        mismatch += len(missing_in_usage) + len(missing_in_header)

        print(f"\nAnalyzing mha_{mode}_args\n")
        print(f"mha_{mode}_args fields in header:", len(header_set))
        print(f"mha_{mode}_args fields referenced in source:", len(used_fields))

        if missing_in_usage:
            print("\nFields present in header but not referenced in source:")
            for name in missing_in_usage:
                print(f" - {name}")
        else:
            print("\nAll header fields are referenced in source.")

        if missing_in_header:
            print("\nFields referenced in source but not in header:")
            for name in missing_in_header:
                print(f" - {name}")
        else:
            print("\nNo unknown fields referenced in source.")

    if mismatch:
        print(f"\nTotal mismatched fields: {mismatch}")
        return 1
    return 0


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())
Loading