59 changes: 58 additions & 1 deletion tuna/rocmlir/rocmlir_tables.py
@@ -51,8 +51,64 @@
from tuna.rocmlir.config_type import ConfigType
from tuna.rocmlir.tuning_space import TuningSpace

from hip import hip
#pylint: disable=too-few-public-methods

# Attention data types based on architecture
DATA_TYPES_ATTENTION_WMMA = ['i8', 'f16', 'bf16']
DATA_TYPES_ATTENTION_MFMA = ['i8', 'f32', 'f16', 'bf16']
DATA_TYPES_ATTENTION = None

def hipCheck(call_result):
  """Unpack a hip-python call result, raising RuntimeError on any HIP error."""
  err = call_result[0]
  result = call_result[1:]
  if len(result) == 1:
    result = result[0]
  if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
    raise RuntimeError(str(err))
  return result
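
# hip-python calls return a tuple of (hipError_t, value, ...); hipCheck above
# unpacks that tuple and raises on a non-success status.  A minimal sketch of
# the pattern, assuming a HIP-capable device is visible:
#
#   count = hipCheck(hip.hipGetDeviceCount())  # plain int, or RuntimeError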


def getArch() -> str:
  """Return the gcnArchName of the first visible HIP device."""
  agents = set()
  device_count = hipCheck(hip.hipGetDeviceCount())
  for device in range(device_count):
    props = hip.hipDeviceProp_t()
    hipCheck(hip.hipGetDeviceProperties(props, device))
    agent = props.gcnArchName.decode('utf-8')
    agents.add(agent)
  if len(agents) > 1:
    print(f"WARNING: Found {len(agents)} different kinds of agents on the "
          f"same machine: {', '.join(agents)}")
    print("WARNING: Using the first agent by default. If you want to use a "
          "different agent, please set the HIP_VISIBLE_DEVICES environment "
          "variable.")
  # select first agent by default
  return list(agents)[0]


def getChip():
  """Extract the chip name (e.g. 'gfx90a') from the full architecture string."""
  arch = getArch()
  chip = GFX_CHIP_RE.search(arch).group(0)
  return chip

def initializeAttentionDatatypes():
  """Initialize attention data types based on architecture.
  SYNC WITH perfRunner.py's initialize_dtypes_attn()
  """
  global DATA_TYPES_ATTENTION
  chip = getChip()
  if chip.startswith('gfx9'):
    DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_MFMA
  elif chip.startswith('gfx1'):
    DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_WMMA
  else:
    raise ValueError(
        f"Could not determine attention data types for architecture: {chip}")
  return DATA_TYPES_ATTENTION

def getAttentionDatatypes():
  """Return the data types chosen by initializeAttentionDatatypes()."""
  return DATA_TYPES_ATTENTION
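
# A minimal usage sketch, assuming initializeAttentionDatatypes() is called
# once (e.g. during session setup, not shown in this hunk) before any lookup;
# getAttentionDatatypes() returns None if that initialization never ran:
#
#   initializeAttentionDatatypes()    # picks MFMA vs. WMMA types via getChip()
#   dtypes = getAttentionDatatypes()  # e.g. ['i8', 'f32', 'f16', 'bf16'] on gfx9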

class SessionRocMLIR(BASE, SessionMixin):
"""Session table to keep track of tuning sessions"""
@@ -733,14 +789,15 @@ def get_configurations(self, filename):
"""Read attention-configs from filename and expand into all combinations of
type and transpose.
"""
datatypes = getAttentionDatatypes()

configs = []
with open(filename, 'r', encoding='utf8') as config_file:
lines = config_file.readlines()

# All combinations of types and transposition (A and B)
for datatype, transQ, transK, transV, transO, withAttnScale, withAttnBias, causal, return_lse, line in \
itertools.product(['f32', 'f16'], ['false', 'true'],
itertools.product(datatypes, ['false', 'true'],
['false', 'true'], ['false', 'true'],
['false', 'true'], ['false'], ['false'], ['false'], ['false'], lines):
line = line.strip()
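
A minimal sketch of the expansion the loop above performs, assuming a gfx9
chip so that DATA_TYPES_ATTENTION_MFMA is selected: each input config line is
expanded over len(datatypes) * 2^4 transpose settings, with the remaining
flags pinned to 'false'.

import itertools

datatypes = ['i8', 'f32', 'f16', 'bf16']  # DATA_TYPES_ATTENTION_MFMA on gfx9
combos = list(
    itertools.product(datatypes, ['false', 'true'], ['false', 'true'],
                      ['false', 'true'], ['false', 'true']))
# 4 data types x 16 transpose combinations = 64 expanded configs per line
assert len(combos) == 64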