diff --git a/tuna/rocmlir/config_type.py b/tuna/rocmlir/config_type.py index f55fe4ac..e4c965d2 100644 --- a/tuna/rocmlir/config_type.py +++ b/tuna/rocmlir/config_type.py @@ -33,6 +33,7 @@ class ConfigType(Enum): convolution: str = 'convolution' gemm: str = 'gemm' attention: str = 'attention' + gemm_gemm: str = 'gemm_gemm' def __str__(self) -> str: return self.value diff --git a/tuna/rocmlir/rocmlir_worker.py b/tuna/rocmlir/rocmlir_worker.py index 8c9ab951..30e31fe6 100644 --- a/tuna/rocmlir/rocmlir_worker.py +++ b/tuna/rocmlir/rocmlir_worker.py @@ -180,6 +180,8 @@ def run_cmd(self): special_args = "--operation gemm" elif self.dbt.config_type == ConfigType.attention: special_args = "--operation attention --verify-mode none" + elif self.dbt.config_type == ConfigType.gemm_gemm: + special_args = "--operation gemm_gemm" else: raise ValueError(f"Config type {self.dbt.config_type} not yet supported.") if self.dbt.session.tuning_space: