diff --git a/api/common/launch.py b/api/common/launch.py index 69e29dd2ab..7e5f531cfc 100644 --- a/api/common/launch.py +++ b/api/common/launch.py @@ -26,10 +26,15 @@ def is_ampere_gpu(): stdout, exit_code = system.run_command("nvidia-smi -L") if exit_code == 0: gpu_list = stdout.split("\n") + # have nvidia gpu if len(gpu_list) >= 1: - #print(gpu_list[0]) - # GPU 0: NVIDIA A100-SXM4-40GB (UUID: xxxx) - return gpu_list[0].find("A100") > 0 + stdout, exit_code = system.run_command("nvidia-smi --query-gpu=compute_cap --format=csv -i=0") + if exit_code == 0: + compute_cap_list = stdout.split("\n") + compute_cap = float(compute_cap_list[1]) + #Capability for ampere is 8.x, Ada lovelace is 8.9, H100 is 9.0 + if compute_cap>8: + return True return False