From 7413caa8979dee0127b6b227b7945f583c7f1c1d Mon Sep 17 00:00:00 2001
From: Djordje Antic
Date: Wed, 3 Dec 2025 14:08:31 +0000
Subject: [PATCH 1/2] Sync rocmlir_tables for attn

---
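Notes: a sketch of the intended call order (illustrative only -- this patch
does not add the call site, so where initializeAttentionDatatypes() gets
invoked is an assumption):

    # Probe the GPU once so DATA_TYPES_ATTENTION is populated before any
    # attention configs are expanded; raises ValueError on unknown chips.
    initializeAttentionDatatypes()
    # get_configurations() then reads the cached list back, e.g.
    # ['i8', 'f32', 'f16', 'bf16'] on a gfx9xx (MFMA) machine.
    datatypes = getAttentionDatatypes()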
""" + datatypes = getAttentionDatatypes() configs = [] with open(filename, 'r', encoding='utf8') as config_file: @@ -740,7 +797,7 @@ def get_configurations(self, filename): # All combinations of types and transposition (A and B) for datatype, transQ, transK, transV, transO, withAttnScale, withAttnBias, causal, return_lse, line in \ - itertools.product(['f32', 'f16'], ['false', 'true'], + itertools.product(datatypes, ['false', 'true'], ['false', 'true'], ['false', 'true'], ['false', 'true'], ['false'], ['false'], ['false'], ['false'], lines): line = line.strip() From bf3acc95dae8e7864c7b30e83b2a1a0a200ce371 Mon Sep 17 00:00:00 2001 From: Djordje Antic Date: Wed, 3 Dec 2025 14:13:22 +0000 Subject: [PATCH 2/2] Use MITunas format --- tuna/rocmlir/rocmlir_tables.py | 56 +++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tuna/rocmlir/rocmlir_tables.py b/tuna/rocmlir/rocmlir_tables.py index 70c48c2d..a5416c79 100644 --- a/tuna/rocmlir/rocmlir_tables.py +++ b/tuna/rocmlir/rocmlir_tables.py @@ -60,38 +60,38 @@ DATA_TYPES_ATTENTION = None def hipCheck(call_result): - err = call_result[0] - result = call_result[1:] - if len(result) == 1: - result = result[0] - if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: - raise RuntimeError(str(err)) - return result + err = call_result[0] + result = call_result[1:] + if len(result) == 1: + result = result[0] + if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: + raise RuntimeError(str(err)) + return result def getArch() -> str: - agents = set() - device_count = hipCheck(hip.hipGetDeviceCount()) - for device in range(device_count): - props = hip.hipDeviceProp_t() - hipCheck(hip.hipGetDeviceProperties(props, device)) - agent = props.gcnArchName.decode('utf-8') - agents.add(agent) - if (len(agents) > 1): - print( - f"WARNING: Found {len(agents)} different kinds of agents on the same machine : {', '.join(agents)}" - ) - print( - "WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable." - ) - # select first agent by default - return list(agents)[0] + agents = set() + device_count = hipCheck(hip.hipGetDeviceCount()) + for device in range(device_count): + props = hip.hipDeviceProp_t() + hipCheck(hip.hipGetDeviceProperties(props, device)) + agent = props.gcnArchName.decode('utf-8') + agents.add(agent) + if len(agents) > 1: + print( + f"WARNING: Found {len(agents)} different kinds of agents on the same machine : {', '.join(agents)}" + ) + print( + "WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable." + ) + # select first agent by default + return list(agents)[0] def getChip(): - arch = getArch() - chip = GFX_CHIP_RE.search(arch).group(0) - return chip + arch = getArch() + chip = GFX_CHIP_RE.search(arch).group(0) + return chip def initializeAttentionDatatypes(): """Initialize attention data types based on architecture. 
 tuna/rocmlir/rocmlir_tables.py | 56 ++++++++++++++++++++++++----------------------------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/tuna/rocmlir/rocmlir_tables.py b/tuna/rocmlir/rocmlir_tables.py
index 70c48c2d..a5416c79 100644
--- a/tuna/rocmlir/rocmlir_tables.py
+++ b/tuna/rocmlir/rocmlir_tables.py
@@ -60,38 +60,38 @@ DATA_TYPES_ATTENTION = None
 
 
 def hipCheck(call_result):
-    err = call_result[0]
-    result = call_result[1:]
-    if len(result) == 1:
-        result = result[0]
-    if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
-        raise RuntimeError(str(err))
-    return result
+  err = call_result[0]
+  result = call_result[1:]
+  if len(result) == 1:
+    result = result[0]
+  if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
+    raise RuntimeError(str(err))
+  return result
 
 
 def getArch() -> str:
-    agents = set()
-    device_count = hipCheck(hip.hipGetDeviceCount())
-    for device in range(device_count):
-        props = hip.hipDeviceProp_t()
-        hipCheck(hip.hipGetDeviceProperties(props, device))
-        agent = props.gcnArchName.decode('utf-8')
-        agents.add(agent)
-    if (len(agents) > 1):
-        print(
-            f"WARNING: Found {len(agents)} different kinds of agents on the same machine: {', '.join(agents)}"
-        )
-        print(
-            "WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable."
-        )
-    # select first agent by default
-    return list(agents)[0]
+  agents = set()
+  device_count = hipCheck(hip.hipGetDeviceCount())
+  for device in range(device_count):
+    props = hip.hipDeviceProp_t()
+    hipCheck(hip.hipGetDeviceProperties(props, device))
+    agent = props.gcnArchName.decode('utf-8')
+    agents.add(agent)
+  if len(agents) > 1:
+    print(
+        f"WARNING: Found {len(agents)} different kinds of agents on the same machine: {', '.join(agents)}"
+    )
+    print(
+        "WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable."
+    )
+  # select first agent by default
+  return list(agents)[0]
 
 
 def getChip():
-    arch = getArch()
-    chip = GFX_CHIP_RE.search(arch).group(0)
-    return chip
+  arch = getArch()
+  chip = GFX_CHIP_RE.search(arch).group(0)
+  return chip
 
 def initializeAttentionDatatypes():
   """Initialize attention data types based on architecture.
@@ -101,10 +101,10 @@ def initializeAttentionDatatypes():
   chip = getChip()
   if chip.startswith('gfx9'):
     DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_MFMA
-  else if chip.startswith('gfx1'):
+  elif chip.startswith('gfx1'):
     DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_WMMA
   else:
-    raise ValueError(f"Could not determine attention data types for architecture: {arch_name}")
+    raise ValueError(f"Could not determine attention data types for architecture: {chip}")
   return DATA_TYPES_ATTENTION
 
 def getAttentionDatatypes():