import re

try:
  from hip import hip
except ImportError:  # HIP bindings are only required when the GPU is queried
  hip = None

# Pattern extracting the chip name (e.g. 'gfx90a') from a gcnArchName string
# such as 'gfx90a:sramecc+:xnack-'.
# NOTE(review): presumably mirrors GFX_CHIP_RE in rocMLIR's perfRunner.py --
# confirm, and drop this definition if the module already provides one.
GFX_CHIP_RE = re.compile(r'gfx[0-9a-f]+')

# Attention data types based on architecture
DATA_TYPES_ATTENTION_WMMA = ['i8', 'f16', 'bf16']
DATA_TYPES_ATTENTION_MFMA = ['i8', 'f32', 'f16', 'bf16']
# Filled in by initializeAttentionDatatypes(); None until then.
DATA_TYPES_ATTENTION = None


def hipCheck(call_result):
  """Unpack a HIP call result of the form (err, *values).

  Returns the single value when exactly one follows the error code, otherwise
  the remaining values as a sequence (empty when the call returns nothing).

  Raises:
    RuntimeError: if the error code is a hipError_t other than hipSuccess.
  """
  err = call_result[0]
  result = call_result[1:]
  if len(result) == 1:
    result = result[0]
  if (hip is not None and isinstance(err, hip.hipError_t) and
      err != hip.hipError_t.hipSuccess):
    raise RuntimeError(str(err))
  return result


def getArch() -> str:
  """Return the gcnArchName of the first visible HIP device.

  When devices of several architectures are present, a warning is printed and
  the architecture of the lowest-numbered device is used.

  Raises:
    RuntimeError: if the HIP bindings are missing, a HIP call fails, or no
      device is visible.
  """
  if hip is None:
    raise RuntimeError(
        "The HIP Python bindings ('hip') are required to query the GPU "
        "architecture but could not be imported.")

  device_count = hipCheck(hip.hipGetDeviceCount())
  found = []
  for device in range(device_count):
    props = hip.hipDeviceProp_t()
    hipCheck(hip.hipGetDeviceProperties(props, device))
    found.append(props.gcnArchName.decode('utf-8'))

  if not found:
    raise RuntimeError(
        "No HIP devices found; cannot determine the GPU architecture.")

  # De-duplicate while keeping device order, so that "first agent" really is
  # the first device (a set has no defined order).
  agents = list(dict.fromkeys(found))
  if len(agents) > 1:
    print(
        f"WARNING: Found {len(agents)} different kinds of agents on the same machine : {', '.join(agents)}"
    )
    print(
        "WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable."
    )
  # select first agent by default
  return agents[0]


def getChip(arch=None):
  """Return the chip name (e.g. 'gfx942') found in an architecture string.

  Args:
    arch: architecture string to parse; when None, the local GPU is queried
      through getArch().

  Raises:
    ValueError: if the string contains no chip name.
  """
  if arch is None:
    arch = getArch()
  match = GFX_CHIP_RE.search(arch)
  if match is None:
    raise ValueError(
        f"Could not find a gfx chip name in architecture string: {arch!r}")
  return match.group(0)


def initializeAttentionDatatypes(chip=None):
  """Initialize attention data types based on architecture.

  SYNC WITH perfRunner.py's initialize_dtypes_attn()

  Args:
    chip: chip name such as 'gfx942'; when None, it is detected with getChip().

  Returns:
    The list now stored in DATA_TYPES_ATTENTION.

  Raises:
    ValueError: if the chip is neither a gfx9* nor a gfx1* part.
  """
  global DATA_TYPES_ATTENTION  # pylint: disable=global-statement
  if chip is None:
    chip = getChip()
  if chip.startswith('gfx9'):
    DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_MFMA
  elif chip.startswith('gfx1'):
    DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_WMMA
  else:
    raise ValueError(
        f"Could not determine attention data types for architecture: {chip}")
  return DATA_TYPES_ATTENTION


def getAttentionDatatypes():
  """Return the attention data types for the current architecture.

  get_configurations() iterates over this value to expand attention configs,
  so it must never be None: the data types are initialized on first use when
  initializeAttentionDatatypes() has not been called yet.
  """
  if DATA_TYPES_ATTENTION is None:
    return initializeAttentionDatatypes()
  return DATA_TYPES_ATTENTION