From 002a9632d7135f1fed50301da3aae1cf42f21e3d Mon Sep 17 00:00:00 2001
From: Djordje Antic
Date: Wed, 30 Jul 2025 08:23:38 +0000
Subject: [PATCH 1/7] Sync rocMLIR's branch with develop
Signed-off-by: Djordje Antic
---
.coveragerc | 6 +-
.jenkins/export_db | 2 +-
.jenkins/init_session | 30 +-
.jenkins/load_jobs | 2 +-
.jenkins/perf_compile | 8 +-
.jenkins/perf_eval | 11 +-
.jenkins/query_config | 4 +-
.jenkins/update_solvers | 7 +-
.readthedocs.yaml | 32 +
Dockerfile | 81 ++-
Dockerfile.flask | 2 +-
Jenkinsfile | 46 +-
README.md | 326 +++++++---
build_docker.sh | 4 +-
doc/GoldDB.md | 22 -
doc/findocs.rst | 2 -
doc/flaskapp.rst | 4 -
doc/src/FinDocs.md | 18 -
doc/src/FlaskAppDoc.md | 83 ---
doc/src/TuningCycle.md | 171 -----
doc/tuningcycle.rst | 4 -
docker-compose-flower.yaml | 33 +
docker-compose-flower_rabbitmq.yaml | 34 +
docker-compose.yaml | 43 ++
{doc => docs}/Doxyfile | 2 +-
{doc => docs}/Makefile | 0
{doc => docs}/conf.py | 9 +-
docs/findocs.rst | 2 +
{doc => docs}/index.rst | 6 +-
{doc => docs}/readme.rst | 3 -
docs/sphinx/requirements.txt | 72 ++
{doc => docs/src}/DBVersioning.md | 22 +-
docs/tuning.rst | 10 +
flaskapp/templates/display_keys.html | 27 -
flaskapp/templates/input-form.html | 5 -
flaskapp/views/example_grafana.py | 81 ---
flaskapp/views/fdb_key.py | 72 --
flaskapp/views/grafana.py | 87 ---
requirements.txt | 47 +-
setup.py | 5 +-
tests/dummy_machine.py | 3 +
tests/test_abort_file.py | 3 +-
tests/test_add_session.py | 6 +-
tests/test_add_session_rocmlir.py | 1 -
tests/test_celery.py | 314 +++++++++
tests/test_dbBase.py | 1 -
tests/test_driver.py | 85 +--
tests/test_example.py | 56 +-
tests/test_export_db.py | 2 +-
tests/test_fin_builder.py | 139 ++--
tests/test_fin_class.py | 1 -
tests/test_fin_evaluator.py | 235 +++----
tests/test_fin_utils.py | 5 +-
tests/test_helper.py | 44 ++
tests/test_importconfigs.py | 2 +-
tests/test_load_job.py | 8 +-
tests/test_mituna_interface.py | 58 ++
tests/test_rocmlir.py | 2 +-
tests/test_update_golden.py | 2 +-
tests/test_worker.py | 3 +-
tests/utils.py | 105 ++-
tuna/celery_app/README.md | 85 +++
{flaskapp => tuna/celery_app}/__init__.py | 5 +-
tuna/celery_app/celery_app.py | 150 +++++
tuna/celery_app/celery_workers.py | 103 +++
tuna/celery_app/utility.py | 45 ++
flaskapp/setup.py => tuna/custom_errors.py | 20 +-
tuna/db/session_mixin.py | 2 +-
tuna/driver.py | 65 ++
tuna/example/README.md | 53 --
.../example/build_schema.py | 25 +-
tuna/example/celery_tuning/celery_tasks.py | 76 +++
tuna/example/doc/Tuning.md | 78 +++
tuna/example/example_lib.py | 125 +++-
tuna/example/example_tables.py | 4 +-
tuna/example/example_worker.py | 38 +-
tuna/example/load_job.py | 1 -
tuna/flask_example.py | 64 --
tuna/go_fish.py | 31 +-
tuna/lib_utils.py | 2 +-
tuna/libraries.py | 9 +
tuna/machine.py | 18 +-
tuna/miopen/celery_tuning/celery_tasks.py | 94 +++
tuna/miopen/db/__init__.py | 0
tuna/miopen/db/batch_norm_tables.py | 151 +++++
tuna/miopen/db/benchmark.py | 27 +
tuna/miopen/db/bn_golden_tables.py | 46 ++
tuna/miopen/db/build_schema.py | 4 +-
tuna/miopen/db/convolutionjob_tables.py | 235 +++++++
tuna/miopen/db/fusion_config_tables.py | 78 +++
tuna/miopen/db/get_db_tables.py | 56 ++
tuna/miopen/db/miopen_tables.py | 613 +-----------------
tuna/miopen/db/mixin_tables.py | 171 +++++
tuna/miopen/db/session.py | 3 +-
tuna/miopen/db/solver.py | 82 +++
tuna/miopen/db/tables.py | 27 +-
tuna/miopen/db/tensortable.py | 55 ++
tuna/miopen/doc/FinDocs.md | 21 +
tuna/miopen/doc/Tuning.md | 250 +++++++
tuna/miopen/driver/base.py | 84 ++-
tuna/miopen/driver/batchnorm.py | 10 +-
tuna/miopen/driver/convolution.py | 41 +-
tuna/miopen/metadata.py | 5 +
tuna/miopen/miopen_lib.py | 488 ++++++++++++--
tuna/miopen/scripts/query_db.py | 2 +-
tuna/miopen/scripts/report.py | 2 +-
tuna/miopen/subcmd/export_db.py | 165 ++++-
tuna/miopen/subcmd/import_configs.py | 16 -
tuna/miopen/subcmd/import_db.py | 3 +-
tuna/miopen/subcmd/load_job.py | 16 +-
tuna/miopen/subcmd/update_golden.py | 20 +-
tuna/miopen/utils/config_type.py | 3 +
tuna/miopen/utils/helper.py | 41 +-
tuna/miopen/utils/json_to_sql.py | 312 +++++++++
tuna/miopen/utils/lib_helper.py | 44 ++
tuna/miopen/utils/metadata.py | 2 +
tuna/miopen/worker/fin_builder.py | 95 +--
tuna/miopen/worker/fin_class.py | 467 ++++---------
tuna/miopen/worker/fin_eval.py | 207 ++----
tuna/mituna_interface.py | 539 +++++++++++++--
tuna/parse_args.py | 38 +-
tuna/rocmlir/config_type.py | 10 +-
tuna/rocmlir/rocmlir_lib.py | 32 +-
tuna/rocmlir/rocmlir_tables.py | 22 +-
tuna/rocmlir/rocmlir_worker.py | 3 +-
tuna/rocmlir/tuning_space.py | 10 +-
flaskapp/app.py => tuna/utils/celery_utils.py | 30 +-
tuna/utils/db_utility.py | 101 +--
tuna/utils/logger.py | 47 +-
.../{miopen_utility.py => machine_utility.py} | 2 +-
tuna/utils/utility.py | 34 +
tuna/worker_interface.py | 133 ++--
vars/utils.groovy | 439 ++++++++-----
133 files changed, 5786 insertions(+), 3089 deletions(-)
create mode 100644 .readthedocs.yaml
delete mode 100644 doc/GoldDB.md
delete mode 100644 doc/findocs.rst
delete mode 100644 doc/flaskapp.rst
delete mode 100644 doc/src/FinDocs.md
delete mode 100644 doc/src/FlaskAppDoc.md
delete mode 100644 doc/src/TuningCycle.md
delete mode 100644 doc/tuningcycle.rst
create mode 100644 docker-compose-flower.yaml
create mode 100644 docker-compose-flower_rabbitmq.yaml
create mode 100644 docker-compose.yaml
rename {doc => docs}/Doxyfile (99%)
rename {doc => docs}/Makefile (100%)
rename {doc => docs}/conf.py (94%)
create mode 100644 docs/findocs.rst
rename {doc => docs}/index.rst (87%)
rename {doc => docs}/readme.rst (61%)
create mode 100644 docs/sphinx/requirements.txt
rename {doc => docs/src}/DBVersioning.md (63%)
create mode 100644 docs/tuning.rst
delete mode 100644 flaskapp/templates/display_keys.html
delete mode 100644 flaskapp/templates/input-form.html
delete mode 100644 flaskapp/views/example_grafana.py
delete mode 100644 flaskapp/views/fdb_key.py
delete mode 100644 flaskapp/views/grafana.py
create mode 100644 tests/test_celery.py
create mode 100644 tests/test_helper.py
create mode 100644 tests/test_mituna_interface.py
create mode 100644 tuna/celery_app/README.md
rename {flaskapp => tuna/celery_app}/__init__.py (91%)
create mode 100644 tuna/celery_app/celery_app.py
create mode 100644 tuna/celery_app/celery_workers.py
create mode 100644 tuna/celery_app/utility.py
rename flaskapp/setup.py => tuna/custom_errors.py (82%)
create mode 100644 tuna/driver.py
delete mode 100644 tuna/example/README.md
rename flaskapp/example_app.py => tuna/example/build_schema.py (73%)
mode change 100644 => 100755
create mode 100644 tuna/example/celery_tuning/celery_tasks.py
create mode 100644 tuna/example/doc/Tuning.md
delete mode 100644 tuna/flask_example.py
create mode 100644 tuna/miopen/celery_tuning/celery_tasks.py
create mode 100644 tuna/miopen/db/__init__.py
create mode 100644 tuna/miopen/db/batch_norm_tables.py
create mode 100644 tuna/miopen/db/bn_golden_tables.py
create mode 100644 tuna/miopen/db/convolutionjob_tables.py
create mode 100644 tuna/miopen/db/fusion_config_tables.py
create mode 100644 tuna/miopen/db/get_db_tables.py
create mode 100644 tuna/miopen/db/mixin_tables.py
create mode 100644 tuna/miopen/db/solver.py
create mode 100644 tuna/miopen/db/tensortable.py
create mode 100644 tuna/miopen/doc/FinDocs.md
create mode 100644 tuna/miopen/doc/Tuning.md
create mode 100644 tuna/miopen/utils/json_to_sql.py
create mode 100644 tuna/miopen/utils/lib_helper.py
rename flaskapp/app.py => tuna/utils/celery_utils.py (69%)
rename tuna/utils/{miopen_utility.py => machine_utility.py} (98%)
diff --git a/.coveragerc b/.coveragerc
index 3a48c28c4..3e20d3ba9 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -6,8 +6,6 @@ omit =
tuna/corrupt_configs.py
tuna/data_migration_updated_layouts.py
tuna/export_configs.py
- tuna/flask.py
- tuna/flask_example.py
tuna/miopen/analyze_fdb.py
tuna/prune_db.py
tuna/query_db.py
@@ -17,10 +15,12 @@ omit =
tuna/utils/dupe_resolve.py
tuna/build_driver_cmd.py
tuna/solver_res.py
+ tuna/celery_app/celery_app.py
+ tuna/miopen/celery_tuning/celery_tasks.py
[report]
exclude_lines =
def __repr__
raise NotImplementedError
if __name__ == .__main__.:
-
\ No newline at end of file
+
diff --git a/.jenkins/export_db b/.jenkins/export_db
index 7aa280981..3d63d5a57 100644
--- a/.jenkins/export_db
+++ b/.jenkins/export_db
@@ -75,7 +75,7 @@ pipeline {
string(name: 'db_name', defaultValue: ${TUNA_DB_NAME}, description: 'Name of the database schema')
string(name: 'db_user', defaultValue: ${JENKINS_USER}, description: 'Username for the database')
string(name: 'db_password', defaultValue: ${JENKINS_PWD}, description: 'Password for the user')
- string(name: 'docker_registry', defaultValue: '${headnode}:5000', description: 'Name of the docker registry for pushing images')
+ string(name: 'docker_registry', defaultValue: '${DOCKER_REGISTRY}', description: 'Name of the docker registry for pushing images')
}
stages {
stage("Check params")
diff --git a/.jenkins/init_session b/.jenkins/init_session
index 214eea78c..5cc4d9bb6 100644
--- a/.jenkins/init_session
+++ b/.jenkins/init_session
@@ -8,32 +8,35 @@ echo "${util_lib}"
library "${util_lib}"
def initSession(){
- backend = "HIPNOGPU"
+ backend = "HIP"
def tuna_docker
- def tuna_docker_name = utils.getDockerName("HIPNOGPU")
- def build_args = " --network host --build-arg ROCMVERSION=${params.rocm_version} --build-arg OSDB_BKC_VERSION=${params.osdb_bkc_version} --build-arg BACKEND=${backend} --build-arg MIOPEN_BRANCH=${miopen_branch_name} --build-arg DB_NAME=${params.db_name} --build-arg DB_USER_NAME=${db_user} --build-arg DB_USER_PASSWORD=${db_password} --build-arg DB_HOSTNAME=${db_host} --build-arg MIOPEN_USE_MLIR=${params.use_mlir}"
+ def tuna_docker_name = utils.getDockerName("HIP")
+ def build_args = " --network host --build-arg ROCMVERSION=${params.rocm_version} --build-arg OSDB_BKC_VERSION=${params.osdb_bkc_version} --build-arg BUILD_MIOPEN_DEPS=${params.build_miopen_deps} --build-arg BACKEND=${backend} --build-arg MIOPEN_BRANCH=${miopen_branch_name} --build-arg DB_NAME=${params.db_name} --build-arg DB_USER_NAME=${db_user} --build-arg DB_USER_PASSWORD=${db_password} --build-arg DB_HOSTNAME=${db_host} --build-arg MIOPEN_USE_MLIR=${params.use_mlir}"
if(params.base_image != '')
{
- build_args = build_args + " --build-arg BASEIMAGE=${params.base_image} --build-arg ROCM_PRE=1"
+ build_args = build_args + " --build-arg BASEIMAGE=${params.base_image}"
+ }
+ if(params.arch != '')
+ {
+ build_args = build_args + " --build-arg ARCH_TARGET=${params.arch}"
}
sh "echo ${build_args}"
tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
- def docker_run_args = "--network host --dns 8.8.8.8 --device=/dev/kfd --device /dev/dri:/dev/dri:rw --volume /dev/dri:/dev/dri:rw --group-add video -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password} -e gateway_ip=${gateway_ip} -e gateway_port=${gateway_port} -e gateway_user=${gateway_user} -e TUNA_LOGLEVEL=${params.tuna_loglevel}"
-
def num_session_prev = utils.runsql("SELECT count(*) from session where reason='${job_label}';")
if(params.arch != '' && params.num_cu != '')
{
def margs = "-a ${params.arch} -n ${params.num_cu}"
- sh "docker run ${docker_run_args} ${tuna_docker_name} ./tuna/go_fish.py miopen --init_session -l ${job_label} ${margs}"
+ sh "docker run ${docker_args} ${tuna_docker_name} ./tuna/go_fish.py miopen --init_session -l ${job_label} ${margs} --docker_name ${base_image}"
+ tuna_docker.push()
}
else
{
tuna_docker.push()
- sh "srun --no-kill -p ${slurm_partition} -N 1 -l bash -c 'docker run ${docker_run_args} ${tuna_docker_name} ./tuna/go_fish.py miopen --init_session -l ${job_label}'"
+ sh "srun --no-kill -p ${slurm_partition} -N 1 -l bash -c 'docker run ${docker_args} ${tuna_docker_name} ./tuna/go_fish.py miopen --init_session -l ${job_label} --docker_name ${base_image}'"
}
def num_session_now = utils.runsql("SELECT count(*) from session where reason='${job_label}';")
@@ -48,11 +51,6 @@ def initSession(){
def VerifyArgs()
{
- if(params.rocm_version == '' && params.osdb_bkc_version == '')
- {
- error "Either ROCm version or OSDB build number is required"
- }
-
if(params.rocm_version != '' && params.osdb_bkc_version != '')
{
error "Can only specify either the ROCm version or the OSDB build number"
@@ -65,13 +63,14 @@ def VerifyArgs()
}
pipeline {
- agent { node { label 'mysql' } }
+ agent { node { label 'build-node' } }
environment {
db_name = "${params.db_name}"
db_host = "${params.db_host}"
db_user = "${params.db_user}"
db_password = "${params.db_password}"
branch_id = "${params.branch_name}_${BUILD_ID}"
+ docker_args = "--network host --device=/dev/kfd --device /dev/dri:/dev/dri:rw --volume /dev/dri:/dev/dri:rw --group-add video -e TUNA_LOGLEVEL=${tuna_loglevel} -e TUNA_CELERY_BROKER_HOST=${PIPELINE_CELERY_BROKER_HOST} -e TUNA_CELERY_BROKER_USER=${TUNA_CELERY_BROKER_USER} -e TUNA_CELERY_BROKER_PWD=${TUNA_CELERY_BROKER_PWD} -e TUNA_CELERY_BROKER_PORT=${TUNA_CELERY_BROKER_PORT} -e TUNA_CELERY_BACKEND_HOST=${PIPELINE_CELERY_BACKEND_HOST} -e TUNA_CELERY_BACKEND_PORT=${TUNA_CELERY_BACKEND_PORT} -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password} -e gateway_ip=${gateway_ip} -e gateway_port=${gateway_port} -e gateway_user=${gateway_user}"
}
parameters {
string(name: 'branch_name', defaultValue: 'init_session', description: '')
@@ -85,10 +84,11 @@ pipeline {
string(name: 'db_host', defaultValue: "${headnode}", description: 'Name of the machine hosting the database instance')
string(name: 'rocm_version', defaultValue: '', description: 'Version of ROCm for base docker packages, exclusive with osdb_bkc_version')
string(name: 'osdb_bkc_version', defaultValue: '', description: 'Build number for OSDB, exclusive with rocm_version')
+ string(name: 'build_miopen_deps', defaultValue: '', description: 'Build miopen dependencies: set to 1 to build dependencies, else leave blank')
string(name: 'db_name', defaultValue: "${PIPELINE_DB_NAME}", description: 'Name of the database schema')
string(name: 'db_user', defaultValue: "${PIPELINE_USER}", description: 'Username for the database')
string(name: 'db_password', defaultValue: "${PIPELINE_PWD}", description: 'Password for the user')
- string(name: 'docker_registry', defaultValue: "${headnode}:5000", description: 'Name of the docker registry for pushing images')
+ string(name: 'docker_registry', defaultValue: "${DOCKER_REGISTRY}", description: 'Name of the docker registry for pushing images')
string(name: 'base_image', defaultValue: '', description: 'Put a fully qualified docker name here to use (optional)')
}
stages {
diff --git a/.jenkins/load_jobs b/.jenkins/load_jobs
index 41efc30cf..06c8f2d79 100644
--- a/.jenkins/load_jobs
+++ b/.jenkins/load_jobs
@@ -64,7 +64,7 @@ pipeline {
string(name: 'db_name', defaultValue: "${PIPELINE_DB_NAME}", description: 'Name of the database schema')
string(name: 'db_user', defaultValue: "${PIPELINE_USER}", description: 'Username for the database')
string(name: 'db_password', defaultValue: "${PIPELINE_PWD}", description: 'Password for the user')
- string(name: 'docker_registry', defaultValue: "${headnode}:5000", description: 'Name of the docker registry for pushing images')
+ string(name: 'docker_registry', defaultValue: "${DOCKER_REGISTRY}", description: 'Name of the docker registry for pushing images')
string(name: 'base_image', defaultValue: '', description: 'Put a fully qualified docker name here to use (optional)')
choice(name: 'cmd', choices: ['', 'conv', 'convfp16', 'convbfp16'], description: 'get configs for cmd type')
choice(name: 'stage', choices: ['perf', 'fin_find'], description: 'Load jobs args')
diff --git a/.jenkins/perf_compile b/.jenkins/perf_compile
index 47c272931..39d92a5df 100644
--- a/.jenkins/perf_compile
+++ b/.jenkins/perf_compile
@@ -21,13 +21,14 @@ pipeline {
agent { node { label 'slurm' } }
environment {
backend = 'HIPNOGPU'
- docker_args = "--network host -e TUNA_LOGLEVEL=${tuna_loglevel} -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password}"
+ docker_args = "--network host -e TUNA_LOGLEVEL=${tuna_loglevel} -e TUNA_CELERY_BROKER_HOST=${PIPELINE_CELERY_BROKER_HOST} -e TUNA_CELERY_BROKER_USER=${TUNA_CELERY_BROKER_USER} -e TUNA_CELERY_BROKER_PWD=${TUNA_CELERY_BROKER_PWD} -e TUNA_CELERY_BROKER_PORT=${TUNA_CELERY_BROKER_PORT} -e TUNA_CELERY_BACKEND_HOST=${PIPELINE_CELERY_BACKEND_HOST} -e TUNA_CELERY_BACKEND_PORT=${TUNA_CELERY_BACKEND_PORT} -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password}"
db_name = "${params.db_name}"
partition = "${params.slurm_partition}"
branch_id = "${params.branch_name}_${BUILD_ID}"
+ CREDS = credentials("$DOCKER_CRED")
}
parameters {
- string(name: 'branch_name', defaultValue: 'compile_pipe_gold', description: '')
+ string(name: 'branch_name', defaultValue: 'compile_pipe_celery', description: '')
choice(name: 'use_mlir', choices: ['On', 'Off'], description: 'Build MIOpen with MLIR enabled')
booleanParam(name: 'dynamic_solvers_only', defaultValue: false, description: 'Only use dynamic solvers in tuning')
string(name: 'session_id', defaultValue: '', description: 'session id for compile')
@@ -38,8 +39,7 @@ pipeline {
string(name: 'db_name', defaultValue: "${PIPELINE_DB_NAME}", description: 'Name of the database schema')
string(name: 'db_user', defaultValue: "${PIPELINE_USER}", description: 'Username for the database')
string(name: 'db_password', defaultValue: "${PIPELINE_PWD}", description: 'Password for the user')
- string(name: 'docker_registry', defaultValue: "${headnode}:5000", description: 'Name of the docker registry for pushing images')
- string(name: 'base_image', defaultValue: '', description: 'Put a fully qualified docker name here to use (optional)')
+ string(name: 'docker_registry', defaultValue: "${DOCKER_REGISTRY}", description: 'Name of the docker registry for pushing images')
choice(name: 'stage', choices: ['perf', 'fin_find'], description: 'Compile method')
}
stages {
diff --git a/.jenkins/perf_eval b/.jenkins/perf_eval
index c96e0916b..f389e0ab4 100644
--- a/.jenkins/perf_eval
+++ b/.jenkins/perf_eval
@@ -18,23 +18,24 @@ pipeline {
agent { node { label 'slurm' } }
environment {
backend = 'HIP'
- docker_args = "--network host -e TUNA_LOGLEVEL=${tuna_loglevel} -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password} -e gateway_ip=${gateway_ip} -e gateway_port=${gateway_port} -e gateway_user=${gateway_user} --privileged --device=/dev/kfd --device /dev/dri:/dev/dri:rw --volume /dev/dri:/dev/dri:rw --group-add video"
+ docker_args = "--network host -e TUNA_LOGLEVEL=${tuna_loglevel} -e TUNA_CELERY_BROKER_HOST=${PIPELINE_CELERY_BROKER_HOST} -e TUNA_CELERY_BROKER_USER=${TUNA_CELERY_BROKER_USER} -e TUNA_CELERY_BROKER_PWD=${TUNA_CELERY_BROKER_PWD} -e TUNA_CELERY_BROKER_PORT=${TUNA_CELERY_BROKER_PORT} -e TUNA_CELERY_BACKEND_HOST=${PIPELINE_CELERY_BACKEND_HOST} -e TUNA_CELERY_BACKEND_PORT=${TUNA_CELERY_BACKEND_PORT} -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password} -e gateway_ip=${gateway_ip} -e gateway_port=${gateway_port} -e gateway_user=${gateway_user} --privileged --device=/dev/kfd --device /dev/dri:/dev/dri:rw --volume /dev/dri:/dev/dri:rw --group-add video"
db_name = "${params.db_name}"
+ partition = ""
branch_id = "${params.branch_name}_${BUILD_ID}"
+ CREDS = credentials("$DOCKER_CRED")
}
parameters {
- string(name: 'branch_name', defaultValue: 'eval_pipe_gold', description: '')
+ string(name: 'branch_name', defaultValue: 'eval_pipe_celery', description: '')
choice(name: 'use_mlir', choices: ['On', 'Off'], description: 'Build MIOpen with MLIR enabled')
booleanParam(name: 'dynamic_solvers_only', defaultValue: false, description: 'Only use dynamic solvers in tuning')
string(name: 'session_id', defaultValue: '', description: 'session id for evaluation')
choice(name: 'tuna_loglevel', choices: ['WARN', 'ERROR', 'INFO'], description: 'Log level for TUNA')
- string(name: 'env', defaultValue: '', description: 'Additional environment variables for compilation.')
+ string(name: 'env', defaultValue: 'HIP_FORCE_DEV_KERNARG=1', description: 'Additional environment variables for compilation.')
string(name: 'db_host', defaultValue: "${headnode}", description: 'Name of the machine hosting the database instance')
string(name: 'db_name', defaultValue: "${PIPELINE_DB_NAME}", description: 'Name of the database schema')
string(name: 'db_user', defaultValue: "${PIPELINE_USER}", description: 'Username for the database')
string(name: 'db_password', defaultValue: "${PIPELINE_PWD}", description: 'Password for the user')
- string(name: 'docker_registry', defaultValue: "${headnode}:5000", description: 'Name of the docker registry for pushing images')
- string(name: 'base_image', defaultValue: '', description: 'Put a fully qualified docker name here to use (optional)')
+ string(name: 'docker_registry', defaultValue: "${DOCKER_REGISTRY}", description: 'Name of the docker registry for pushing images')
choice(name: 'stage', choices: ['perf', 'fin_find'], description: 'Evaluate method')
}
stages {
diff --git a/.jenkins/query_config b/.jenkins/query_config
index 47d088bf4..673b4d99b 100644
--- a/.jenkins/query_config
+++ b/.jenkins/query_config
@@ -22,7 +22,7 @@ def QueryConfigs(arch, num_cu, fdb_prefix)
}
if(params.miopen_version != '')
{
- sh "wget https://github.com/ROCmSoftwarePlatform/MIOpen/blob/${params.miopen_version}/src/kernels/${fdb_prefix}.HIP.fdb.txt?raw=true -O ${fdb_prefix}.HIP.fdb.txt"
+ sh "wget https://github.com/ROCm/MIOpen/blob/${params.miopen_version}/src/kernels/${fdb_prefix}.HIP.fdb.txt?raw=true -O ${fdb_prefix}.HIP.fdb.txt"
script_args = script_args + " --fdb_filename ${fdb_prefix}.HIP.fdb.txt"
archiveArtifacts "${fdb_prefix}.HIP.fdb.txt"
}
@@ -45,7 +45,7 @@ def VerifyArgs()
if(params.miopen_version != '')
{
// Checking only one arch is sufficient
- statusCode = sh script:" wget -q --method=HEAD https://github.com/ROCmSoftwarePlatform/MIOpen/blob/${params.miopen_version}/src/kernels/gfx900_56.HIP.fdb.txt?raw=true ", returnStatus:true
+ statusCode = sh script:" wget -q --method=HEAD https://github.com/ROCm/MIOpen/blob/${params.miopen_version}/src/kernels/gfx900_56.HIP.fdb.txt?raw=true ", returnStatus:true
if(statusCode)
{
error "Invalid MIOpen version for find db"
diff --git a/.jenkins/update_solvers b/.jenkins/update_solvers
index c37b9c92a..982e2f5a6 100644
--- a/.jenkins/update_solvers
+++ b/.jenkins/update_solvers
@@ -23,8 +23,10 @@ pipeline {
db_user = "${params.db_user}"
db_password = "${params.db_password}"
branch_name = "${params.branch_name}"
- backend = 'HIPNOGPU'
+ backend = 'HIP'
+ docker_args = "--network host --device=/dev/kfd --device /dev/dri:/dev/dri:rw --volume /dev/dri:/dev/dri:rw --group-add video -e TUNA_LOGLEVEL=${tuna_loglevel} -e TUNA_CELERY_BROKER_HOST=${PIPELINE_CELERY_BROKER_HOST} -e TUNA_CELERY_BROKER_USER=${TUNA_CELERY_BROKER_USER} -e TUNA_CELERY_BROKER_PWD=${TUNA_CELERY_BROKER_PWD} -e TUNA_CELERY_BROKER_PORT=${TUNA_CELERY_BROKER_PORT} -e TUNA_CELERY_BACKEND_HOST=${PIPELINE_CELERY_BACKEND_HOST} -e TUNA_CELERY_BACKEND_PORT=${TUNA_CELERY_BACKEND_PORT} -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password} -e gateway_ip=${gateway_ip} -e gateway_port=${gateway_port} -e gateway_user=${gateway_user}"
branch_id = "${params.branch_name}_${BUILD_ID}"
+ CREDS = credentials("$DOCKER_CRED")
}
parameters {
string(name: 'branch_name', defaultValue: 'applic_pipe_gold', description: '')
@@ -38,8 +40,7 @@ pipeline {
string(name: 'db_name', defaultValue: "${PIPELINE_DB_NAME}", description: 'Name of the database schema')
string(name: 'db_user', defaultValue: "${PIPELINE_USER}", description: 'Username for the database')
string(name: 'db_password', defaultValue: "${PIPELINE_PWD}", description: 'Password for the user')
- string(name: 'docker_registry', defaultValue: "${headnode}:5000", description: 'Name of the docker registry for pushing images')
- string(name: 'base_image', defaultValue: '', description: 'Put a fully qualified docker name here to use (optional)')
+ string(name: 'docker_registry', defaultValue: "${DOCKER_REGISTRY}", description: 'Name of the docker registry for pushing images')
}
stages {
stage("Check params")
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 000000000..e185540ca
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,32 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the OS, Python version and other tools you might need
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.9"
+ # You can also specify other tool versions:
+ # nodejs: "19"
+ # rust: "1.64"
+ # golang: "1.19"
+
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+ configuration: docs/conf.py
+
+# Optionally build your docs in additional formats such as PDF and ePub
+# formats:
+# - pdf
+# - epub
+
+# Optional but recommended, declare the Python requirements required
+# to build your documentation
+# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
+python:
+ install:
+ - requirements: docs/sphinx/requirements.txt
diff --git a/Dockerfile b/Dockerfile
index 879be9dc4..8b0328dd4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,25 +1,52 @@
-#default image to ubuntu + install rocm
-ARG BASEIMAGE=rocm/miopen:ci_5450cc
-
-#FROM ubuntu:20.04 as dtuna-ver-0
-FROM $BASEIMAGE as dtuna-ver-0
#install rocm
ARG ROCMVERSION=
ARG OSDB_BKC_VERSION=
-#'12969'
-ENV NO_ROCM_INST=
+
+#test if rocm version was set
+ARG HASVER=${ROCMVERSION:+$ROCMVERSION}
+ARG HASVER=${HASVER:-$OSDB_BKC_VERSION}
+
+ARG BASEIMAGE=rocm/miopen:ci_d50fb6
+ARG UBUNTU=ubuntu:22.04
+
+#use UBUNTU with rocm version set
+ARG USEIMAGE=${HASVER:+${UBUNTU}}
+ARG USEIMAGE=${USEIMAGE:-$BASEIMAGE}
+
+FROM $USEIMAGE as dtuna-ver-0
+
+#args before from are wiped
+ARG ROCMVERSION=
+ARG OSDB_BKC_VERSION=
+
+RUN test -d /opt/rocm*; \
+ if [ $? -eq 0 ] ; then \
+ test -d /opt/rocm; \
if [ $? -ne 0 ] ; then \
+ ln -s /opt/rocm* /opt/rocm; \
+ fi \
+ fi
+
# Add rocm repository
RUN apt-get update && apt-get install -y wget gnupg
RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
RUN echo "" > /env; \
if ! [ -z $OSDB_BKC_VERSION ]; then \
echo "Using BKC VERSION: $OSDB_BKC_VERSION";\
- sh -c "echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-osdb-20.04-deb/ compute-rocm-dkms-no-npi-hipclang ${OSDB_BKC_VERSION} > /etc/apt/sources.list.d/rocm.list" ;\
- cat /etc/apt/sources.list.d/rocm.list;\
+ if [ "$(cat /opt/rocm/.info/version)" != "$OSDB_BKC_VERSION" ]; then \
+ sh -c "echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-osdb-20.04-deb/ compute-rocm-dkms-no-npi-hipclang ${OSDB_BKC_VERSION} > /etc/apt/sources.list.d/rocm.list" ;\
+ cat /etc/apt/sources.list.d/rocm.list;\
+ else \
+ echo "export NO_ROCM_INST=1" >> /env; \
+ fi \
elif ! [ -z $ROCMVERSION ]; then \
echo "Using Release VERSION: $ROCMVERSION";\
- sh -c "echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-osdb-20.04-deb/ compute-rocm-rel-${ROCMVERSION} > /etc/apt/sources.list.d/rocm.list" ;\
- cat /etc/apt/sources.list.d/rocm.list;\
+ if [ "$(cat /opt/rocm/.info/version)" != "$ROCMVERSION" ]; then \
+ sh -c "echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-osdb-20.04-deb/ compute-rocm-rel-${ROCMVERSION} > /etc/apt/sources.list.d/rocm.list" ;\
+ cat /etc/apt/sources.list.d/rocm.list;\
+ else \
+ echo "export NO_ROCM_INST=1" >> /env; \
+ fi \
else \
echo "export NO_ROCM_INST=1" >> /env; \
fi
@@ -90,25 +117,28 @@ ENV UBSAN_OPTIONS=print_stacktrace=1
RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb
RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb
-# Install cget
-#RUN pip install cget
-
-# Install rclone
-#RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz
ARG MIOPEN_DIR=/root/dMIOpen
#Clone MIOpen
-RUN git clone https://github.com/ROCmSoftwarePlatform/MIOpen.git $MIOPEN_DIR
+RUN git clone https://github.com/ROCm/MIOpen.git $MIOPEN_DIR
WORKDIR $MIOPEN_DIR
-ARG MIOPEN_BRANCH=b5c9cd5b0fa65bc77004dd59adcbb336ead031af
+ARG MIOPEN_BRANCH=ca2eb7538
RUN git pull && git checkout $MIOPEN_BRANCH
ARG PREFIX=/opt/rocm
-ARG MIOPEN_DEPS=/opt/rocm
+ARG MIOPEN_DEPS=$MIOPEN_DIR/deps
# Install dependencies # included in rocm/miopen:ci_xxxxxx
-#RUN cmake -P install_deps.cmake --prefix $MIOPEN_DEPS
-#RUN CXXFLAGS='-isystem $PREFIX/include' cget install -f ./mlir-requirements.txt
+ARG BUILD_MIOPEN_DEPS=
+ARG ARCH_TARGET=
+RUN . /env; if [ -z $NO_ROCM_INST ] || ! [ -z $BUILD_MIOPEN_DEPS ]; then\
+ pip install cget; \
+ if ! [ -z $ARCH_TARGET ]; then \
+ sed -i "s#\(composable_kernel.*\)#\1 -DGPU_TARGETS=\"$ARCH_TARGET\"#" requirements.txt; \
+ fi; \
+ CXX=/opt/rocm/llvm/bin/clang++ cget install -f ./dev-requirements.txt --prefix $MIOPEN_DEPS; \
+ git checkout requirements.txt; \
+ fi
ARG TUNA_USER=miopenpdb
ARG BACKEND=HIP
@@ -117,7 +147,8 @@ WORKDIR $MIOPEN_DIR/build
ARG MIOPEN_CACHE_DIR=/tmp/${TUNA_USER}/cache
ARG MIOPEN_USER_DB_PATH=/tmp/$TUNA_USER/config/miopen
ARG MIOPEN_USE_MLIR=ON
-ARG MIOPEN_CMAKE_ARGS="-DMIOPEN_USE_MLIR=${MIOPEN_USE_MLIR} -DMIOPEN_INSTALL_CXX_HEADERS=On -DMIOPEN_CACHE_DIR=${MIOPEN_CACHE_DIR} -DMIOPEN_USER_DB_PATH=${MIOPEN_USER_DB_PATH} -DMIOPEN_BACKEND=${BACKEND} -DCMAKE_PREFIX_PATH=${MIOPEN_DEPS}"
+# build kdb objects with offline clang compiler, disable comgr + hiprtc (which would make target id specific code objects)
+ARG MIOPEN_CMAKE_ARGS="-DMIOPEN_USE_COMGR=Off -DMIOPEN_USE_HIPRTC=Off -DMIOPEN_USE_MLIR=${MIOPEN_USE_MLIR} -DMIOPEN_INSTALL_CXX_HEADERS=On -DMIOPEN_CACHE_DIR=${MIOPEN_CACHE_DIR} -DMIOPEN_USER_DB_PATH=${MIOPEN_USER_DB_PATH} -DMIOPEN_BACKEND=${BACKEND} -DCMAKE_PREFIX_PATH=${MIOPEN_DEPS}"
RUN echo "MIOPEN: Selected $BACKEND backend."
RUN if [ $BACKEND = "OpenCL" ]; then \
@@ -135,12 +166,12 @@ RUN git submodule update --init --recursive
ARG FIN_DIR=$MIOPEN_DIR/fin
WORKDIR $FIN_DIR
# Can be a branch or a SHA
-ARG FIN_BRANCH=
+ARG FIN_BRANCH=develop
RUN if ! [ -z $FIN_BRANCH ]; then \
git fetch && git checkout $FIN_BRANCH; \
fi
# Install dependencies
-RUN cmake -P install_deps.cmake
+#RUN cmake -P install_deps.cmake
WORKDIR $FIN_DIR/_hip
RUN CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_BUILD_TYPE=Debug -DCMAKE_PREFIX_PATH=$MIOPEN_DEPS $FIN_DIR
@@ -150,7 +181,7 @@ RUN make install
#SET MIOPEN ENVIRONMENT VARIABLES
ENV MIOPEN_LOG_LEVEL=6
-ENV PATH=$PREFIX/miopen/bin:$PREFIX/bin:$PATH
+ENV PATH=$PREFIX/miopen/bin:$PREFIX/bin:$MIOPEN_DEPS/bin:$PATH
ENV LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH
RUN ulimit -c unlimited
# Should be over-ridden by the CI/launcher to point to new db
diff --git a/Dockerfile.flask b/Dockerfile.flask
index 3a0fff088..7bf124864 100644
--- a/Dockerfile.flask
+++ b/Dockerfile.flask
@@ -20,7 +20,7 @@ RUN apt-get update -y && apt install -y vim python3-pip git
WORKDIR /etc/flask
-RUN git clone https://rocm-mici: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@github.com/ROCmSoftwarePlatform/Tuna.git
+RUN git clone https://rocm-mici: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@github.com/ROCm/Tuna.git
RUN pip3 install -r /etc/flask/Tuna/requirements.txt
diff --git a/Jenkinsfile b/Jenkinsfile
index 6471681ef..524bb369a 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -9,22 +9,22 @@ pipeline {
branch_master = "develop"
db_name = "${TUNA_DB_NAME}_${branch}_${BUILD_ID}"
docker_args = '--privileged --device=/dev/kfd --device /dev/dri:/dev/dri:rw --volume /dev/dri:/dev/dri:rw -v /var/lib/docker/:/var/lib/docker --group-add video'
- db_host = 'localhost'
+ db_host = "${CI_DB_HOSTNAME}"
db_user = "${DB_USER_NAME}"
db_password = "${DB_USER_PASSWORD}"
+ broker_user = "${TUNA_CELERY_BROKER_USER}"
+ broker_pwd = "${TUNA_CELERY_BROKER_PWD}"
pipeline_user = "${PIPELINE_USER}"
pipeline_pwd = "${PIPELINE_PWD}"
- arch = 'gfx908'
- num_cu = '120'
- arch_908 = 'gfx908'
- num_cu_120 = '120'
+ arch = 'gfx90a'
+ num_cu = '104'
machine_ip = "${machine_ip}"
machine_local_ip = "${machine_local_ip}"
username = "${username}"
pwd = "${pwd}"
port = "${port}"
TUNA_ROCM_VERSION = '4.5'
-
+ docker_registry = "${DOCKER_REGISTRY}"
}
stages {
stage("docker build") {
@@ -85,24 +85,32 @@ pipeline {
}
}
}
- stage("pytest3 and Tests Coverage"){
- agent { label utils.rocmnode("tunatest") }
+ stage("pytest3"){
+ agent{ label "gfx90a" }
steps {
- script {
- utils.pytestSuite3AndCoverage(branch, branch_master)
- }
+ script {
+ utils.pytestSuite3()
+ }
}
}
- stage("fin find compile"){
- agent{ label utils.rocmnode("tunatest") }
- steps{
+ stage("Coverage"){
+ agent { label utils.rocmnode("tunatest") }
+ steps {
script {
- utils.finFindCompile()
+ utils.Coverage(branch, branch_master)
}
}
}
+ stage("fin find compile"){
+ agent{ label utils.rocmnode("tunatest") }
+ steps{
+ script {
+ utils.finFindCompileEnqueue()
+ }
+ }
+ }
stage("fin find eval"){
- agent{ label "gfx908" }
+ agent{ label "gfx90a" }
steps {
script {
utils.finFindEval()
@@ -125,11 +133,11 @@ pipeline {
}
}
}
- stage("perf eval gfx908"){
- agent{ label "gfx908" }
+ stage("perf eval gfx90a"){
+ agent{ label "gfx90a" }
steps{
script {
- utils.perfEval_gfx908()
+ utils.perfEval()
}
}
}
diff --git a/README.md b/README.md
index 92634d3ba..c3d02a01f 100644
--- a/README.md
+++ b/README.md
@@ -1,132 +1,250 @@
-TUNA
-====
+MITuna
+======
-Tuna is a distributed tuning infrastructure that provides pre-compiled kernels
-for MIOpen customers through automated Jenkins pipelines and SLURM scalable
-architecture. MITuna also provides a scalable task management infrastructure
-ready to integrate with external libaries.
+[MITuna](https://mituna.readthedocs.io/en/latest/index.html) is a distributed tuning infrastructure
+that provides pre-compiled kernels for MIOpen customers through automated Jenkins pipelines and
+SLURM scalable architecture. MITuna also provides a scalable task management infrastructure
+ready to integrate with external libraries. The `Example` library provides a sample of
+how to achieve this.
Prerequisites
-------------
-Install python3.9
-```
-apt-get update && apt-get install software-properties-common
-add-apt-repository ppa:deadsnakes/ppa
-apt install python3.9 python3.9-dev python3.9-venv
-```
-
-Install pip for python3.9
-```
-wget https://bootstrap.pypa.io/get-pip.py -o get-pip.py
-python3.9 get-pip.py
-rm get-pip.py
-```
-
-Install MySQL server
-```
-apt-get install mysql-server
-```
-
-```
-mysqld --initialize
-grep 'temporary password' /var/log/mysql/error.log
-```
-
-Enable the service
-```
-systemctl start mysql
-```
-
-```
-mysql -u root -p
-
--t - tag
--f - filepath
---model - model name
---md_version - model version
---framework - framework name
---fw_version - framework version
--
-```
-
-**Add Solvers (2)**
-
-The solver table contains MIOpen solvers and solver characteristics.
-This should be updated when an MIOpen version modifies solvers.
-
-```
-./go_fish.py miopen --update_solvers
-```
-
-**Add Tuning Session (3)**
-
-Session will track the architecture and skew, as well as the miopen version and
-rocm version for the tuning session.
-
-This command will need to be run from inside the tuning environment eg MITuna docker
-and will populate the table with the version and architecture information.
-
-[Use backend=HIPNOGPU docker]
-```
-./go_fish.py miopen --init_session -l reason
---init_session - create a session entry
--l - reference text description
-```
-
-**Add Applicability (4)**
-Each network configuration has a set of applicable solvers. This step will update the
-solver_applicability table with applicable solvers for each configuration for the session.
-
-[Use backend=HIPNOGPU docker]
-```
-./go_fish.py miopen --update_applicability --session_id 1
---session_id - tuning session id
-```
-
-**Load Jobs (5)**
-
-Time to create the jobs for the tuning session. Specify the session id, the configs that
-should be tuned, and the fin_step to be executed. Configs can be added by using the tag from
-the config_tags table. Jobs should have a compile and an eval fin step pair.
-
-Fin steps include: miopen_perf_compile, miopen_perf_eval, miopen_find_compile, and miopen_find_eval.
-
-```
-./load_job.py --session_id 1 -t resnet50 --fin_steps miopen_perf_compile,miopen_perf_eval -o -l reason
---session_id - tuning session id
--t - config tag
---fin_steps - operations to be performed by fin (tuning handle into miopen)
--o - only_applicable, will create a job for each applicable solver
--l - reference text description
-```
-
-**Compile Step (6)**
-
-Once prerequisites are set, tuning can begin. To compile the jobs,
-supply the session id along with the compile fin_step matching the one in the job table.
-
-[Use backend=HIPNOGPU docker]
-```
-./go_fish.py miopen --session_id 1 --fin_steps miopen_perf_compile
---session_id - tuning session id
---fin_steps - execute this operation
-```
-
-**Evaluation Step (7)**
-
-Once compilation has been started, evaluation can also be launched.
-This command is similar to the previous.
-
-[Use backend=HIP docker]
-```
-./go_fish.py miopen --session_id 1 --fin_steps miopen_perf_eval
---session_id - tuning session id
---fin_steps - execute this operation
-```
-
-**Database Export (8)**
-
-To export the results the export_db.py script can be run with options
-for selecting session as well as database type.
-
-The outputs of this function are database files in the format that MIOpen keeps and manages.
-eg for MI100, -p will produce a gfx90878.db file, -f will produce gfx90878.HIP.fdb.txt, and -k will produce gfx90878.kdb.
-
-```
-./export_db.py --session_id 1 -p
---session_id - tuning session id
--p - export performance db
--f - export find db
--k - export kernel db
-```
diff --git a/doc/tuningcycle.rst b/doc/tuningcycle.rst
deleted file mode 100644
index 5daddca51..000000000
--- a/doc/tuningcycle.rst
+++ /dev/null
@@ -1,4 +0,0 @@
-TuningCycle
-************
-
-.. include:: src/TuningCycle.md
diff --git a/docker-compose-flower.yaml b/docker-compose-flower.yaml
new file mode 100644
index 000000000..eac73c7ca
--- /dev/null
+++ b/docker-compose-flower.yaml
@@ -0,0 +1,33 @@
+version: "3.8"
+
+services:
+ celery:
+ build:
+ context: .
+ dockerfile: Dockerfile
+ container_name: mituna_celery_flower
+ privileged: true
+ group_add:
+ - video
+ devices:
+ - /dev/kfd:/dev/kfd
+ - /dev/dri:/dev/dri:rw
+ stdin_open: true
+ tty: true
+ volumes:
+ - /dev/dri:/dev/dri/:rw
+ - /var/lib/docker/:/var/lib/docker
+ env_file:
+ - .env
+ environment:
+ - TUNA_DB_HOSTNAME=${db_host}
+ - TUNA_DB_NAME=${db_name}
+ - TUNA_DB_USER_NAME=${db_user}
+ - TUNA_DB_USER_PASSWORD=${db_password}
+ - BROKER_URL=redis://localhost:6378/14
+ - TUNA_REDIS_PORT=6378
+ #purge offline workers older than 1 day in seconds
+ #- FLOWER_PURGE_OFFLINE_WORKERS=86400
+ command: "celery -A tuna.celery_app.celery_app flower --debug --persistent=True --purge_offline_workers=10800"
+ network_mode: "host"
+
diff --git a/docker-compose-flower_rabbitmq.yaml b/docker-compose-flower_rabbitmq.yaml
new file mode 100644
index 000000000..46dd9670a
--- /dev/null
+++ b/docker-compose-flower_rabbitmq.yaml
@@ -0,0 +1,34 @@
+version: "3.8"
+
+services:
+ celery:
+ build:
+ context: .
+ dockerfile: Dockerfile
+ container_name: mituna_celery_flower_rabbitmq
+ privileged: true
+ group_add:
+ - video
+ devices:
+ - /dev/kfd:/dev/kfd
+ - /dev/dri:/dev/dri:rw
+ stdin_open: true
+ tty: true
+ volumes:
+ - /dev/dri:/dev/dri/:rw
+ - /var/lib/docker/:/var/lib/docker
+ env_file:
+ - .env
+ environment:
+ - TUNA_DB_HOSTNAME=${db_host}
+ - TUNA_DB_NAME=${db_name}
+ - TUNA_DB_USER_NAME=${db_user}
+ - TUNA_DB_USER_PASSWORD=${db_password}
+ - TUNA_CELERY_BROKER_USER=${TUNA_CELERY_BROKER_USER}
+ - TUNA_CELERY_BROKER_PWD=${TUNA_CELERY_BROKER_PWD}
+ - TUNA_CELERY_BROKER_HOST=${TUNA_CELERY_BROKER_HOST}
+ - TUNA_CELERY_BROKER_PORT=${TUNA_CELERY_BROKER_PORT}
+ #purge offline workers older than 1 day in seconds
+ command: "celery -A tuna.celery_app.celery_app flower --debug --persistent=True --purge_offline_workers=10800 --broker_api=http://${TUNA_CELERY_BROKER_USER}:${TUNA_CELERY_BROKER_PWD}@${TUNA_CELERY_BROKER_HOST}:${TUNA_CELERY_BROKER_PORT}/api/vhost --port=5555"
+ network_mode: "host"
+
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 000000000..12447fa21
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,43 @@
+version: "3.8"
+
+services:
+ redis:
+ image: redis:latest
+ ports:
+ - "6380:6379"
+ expose:
+ - "6380"
+ restart: always
+ container_name: mituna_redis
+
+ celery:
+ build:
+ context: .
+ dockerfile: Dockerfile
+ #perf_compile step requires HIPNOGPU
+ #args:
+ # - "BACKEND=HIPNOGPU"
+ container_name: mituna_celery
+ privileged: true
+ group_add:
+ - video
+ devices:
+ - /dev/kfd:/dev/kfd
+ - /dev/dri:/dev/dri:rw
+ stdin_open: true
+ tty: true
+ volumes:
+ - /dev/dri:/dev/dri/:rw
+ - /var/lib/docker/:/var/lib/docker
+ env_file:
+ - .env
+ environment:
+ - TUNA_DB_HOSTNAME=${db_host}
+ - TUNA_DB_NAME=${db_name}
+ - TUNA_DB_USER_NAME=${db_user}
+ - TUNA_DB_USER_PASSWORD=${db_password}
+ - HIP_VISIBLE_DEVICES=1
+ command: "celery -A tuna.celery_app.celery worker -l info -E"
+ depends_on:
+ - redis
+
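The compose files above start a Redis broker and a Celery worker (and, in the flower variants, a Flower monitor) pointed at `tuna.celery_app`. As a rough illustration of the kind of Celery application module such a `celery -A ... worker` command expects, here is a minimal, generic sketch; it is not MITuna's actual `tuna/celery_app` code, and the broker/backend URLs and the task are assumptions pieced together from the environment variables used in this patch:

```python
# Hypothetical, minimal Celery application module (illustration only).
# The task name and URLs are assumptions; MITuna's real
# tuna/celery_app/celery_app.py is added elsewhere in this patch.
import os
from celery import Celery

BROKER_HOST = os.environ.get('TUNA_CELERY_BROKER_HOST', 'localhost')
REDIS_PORT = os.environ.get('TUNA_REDIS_PORT', '6379')
BROKER_URL = f'redis://{BROKER_HOST}:{REDIS_PORT}/0'

# A worker launched with `celery -A <module> worker -l info -E` imports
# this module and uses the Celery instance it defines.
app = Celery('tuna', broker=BROKER_URL, backend=BROKER_URL)


@app.task
def tune_job(job_id: int) -> int:
  """Placeholder task; with -E the worker emits events that Flower
  (started by the flower compose files) can display."""
  return job_id
```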
diff --git a/doc/Doxyfile b/docs/Doxyfile
similarity index 99%
rename from doc/Doxyfile
rename to docs/Doxyfile
index 1d629c160..06ac290a8 100644
--- a/doc/Doxyfile
+++ b/docs/Doxyfile
@@ -873,7 +873,7 @@ RECURSIVE = NO
# Note that relative paths are relative to the directory from which doxygen is
# run.
-EXCLUDE =
+EXCLUDE = README.md
# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or
# directories that are symbolic links (a Unix file system feature) are excluded
diff --git a/doc/Makefile b/docs/Makefile
similarity index 100%
rename from doc/Makefile
rename to docs/Makefile
diff --git a/doc/conf.py b/docs/conf.py
similarity index 94%
rename from doc/conf.py
rename to docs/conf.py
index f04cebc3a..9dbb2e63d 100644
--- a/doc/conf.py
+++ b/docs/conf.py
@@ -25,14 +25,14 @@
# You can specify multiple suffix as a list of string:
#
source_suffix = ['.rst', '.md']
-breathe_projects = { "MITuna": "xml/" }
+breathe_projects = {"MITuna": "xml/"}
#breathe_projects_source = { "tuna": "xml/" }
# General information about the project.
project = u'MITuna'
-copyright = u'2022, Advanced Micro Devices, Inc. All rights reserved'
+copyright = u'2024, Advanced Micro Devices, Inc. All rights reserved'
author = u'Advanced Micro Devices, Inc'
-version = '1.0'
+version = '2.0'
release = version
templates_path = ['_templates']
@@ -48,8 +48,6 @@
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
-
-
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
@@ -62,4 +60,3 @@
'collapse_navigation': False,
'display_version': True,
}
-
diff --git a/docs/findocs.rst b/docs/findocs.rst
new file mode 100644
index 000000000..d30c2bca2
--- /dev/null
+++ b/docs/findocs.rst
@@ -0,0 +1,2 @@
+
+.. include:: ../tuna/miopen/doc/FinDocs.md
diff --git a/doc/index.rst b/docs/index.rst
similarity index 87%
rename from doc/index.rst
rename to docs/index.rst
index 4266b2898..077f0657c 100644
--- a/doc/index.rst
+++ b/docs/index.rst
@@ -7,16 +7,16 @@ Welcome to MITuna
==================
**Advanced Micro Devices, Inc's tuning library.**
-Sources and binaries can be found at `MITuna's GitHub site <https://github.com/ROCmSoftwarePlatform/MITuna>`_.
+Sources and binaries can be found at `MITuna's GitHub site <https://github.com/ROCm/MITuna>`_.
.. toctree::
:maxdepth: 8
:caption: Contents:
readme
+ tuning
findocs
- flaskapp
- tuningcycle
+ dbversioning
Indices and tables
==================
diff --git a/doc/readme.rst b/docs/readme.rst
similarity index 61%
rename from doc/readme.rst
rename to docs/readme.rst
index cbfbfbb4a..bdff72a8e 100644
--- a/doc/readme.rst
+++ b/docs/readme.rst
@@ -1,4 +1 @@
-README
-=======
-
.. include:: ../README.md
diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt
new file mode 100644
index 000000000..b8c28cdf8
--- /dev/null
+++ b/docs/sphinx/requirements.txt
@@ -0,0 +1,72 @@
+aioredis==2.0.1
+alembic==1.8.1
+asn1crypto==0.24.0
+astroid==2.15.4
+asyncio==3.4.3
+attrs==19.3.0
+backcall==0.1.0
+bcrypt==3.2
+breathe==4.35.0
+celery==5.3.4
+cryptography==43.0.1
+decorator==4.3.0
+docutils==0.20
+flask==2.2.5
+flower==2.0.1
+idna==3.7
+importlib-metadata>=6.6.0
+jsonargparse==4.19.0
+isort==5.13.2
+jedi==0.13.1
+lazy-object-proxy==1.7.1
+markdown-it-py==3.0.0
+mccabe==0.6.1
+myst-parser==3.0.1
+more-itertools==8.3.0
+numpy==1.24.2
+opentelemetry-api==1.12.0rc2
+opentelemetry-distro==0.32b0
+opentelemetry-exporter-otlp-proto-http==1.11.1
+packaging==24.1
+pandas==1.5.3
+paramiko==3.4.0
+parso==0.3.1
+pathlib2==2.3.5
+pexpect==4.6.0
+pickleshare==0.7.5
+pluggy==0.13.1
+prompt-toolkit==3.0.36
+protobuf<5.0.0dev,>=3.19.5
+ptyprocess==0.6.0
+py==1.10.0
+pyasn1==0.4.4
+pycparser==2.19
+Pygments==2.18.0
+pylint<=2.17.0-dev0,>=2.15.4
+pymysql==1.1.1
+PyNaCl==1.5
+pyparsing==2.4.7
+pytest==7.4.4
+pytest-asyncio==0.21
+pyyaml==6.0
+redis==5.0.1
+six==1.12.0
+sqlalchemy==1.3.23
+sphinx==7.4.7
+sphinx_rtd_theme==2.0.0
+traitlets==4.3.2
+twine==5.1.1
+typed-ast==1.5.4
+types-PyYAML==6.0.12.6
+types-paramiko==3.0.0.4
+types-PyMySQL==1.0.19.5
+wcwidth==0.1.7
+wrapt==1.14.1
+yamllint==1.29.0
+yapf==0.40.2
+zipp==3.19.1
+coverage==7.0.5
+python-logstash-async==3.0.0
+mysql-connector-python
+prometheus_flask_exporter
+tenacity
diff --git a/doc/DBVersioning.md b/docs/src/DBVersioning.md
similarity index 63%
rename from doc/DBVersioning.md
rename to docs/src/DBVersioning.md
index e42fac018..248a005b7 100644
--- a/doc/DBVersioning.md
+++ b/docs/src/DBVersioning.md
@@ -7,19 +7,23 @@ Follow the instructions below to start versioning your DB and for a sample of how
MITuna's **requirements.txt** file contains the required library to install: alembic.
-You will need the following environment variables:
- *TUNA_DB_NAME=
- *TUNA_DB_USER_NAME=
- *TUNA_DB_USER_PASSWORD=
- *TUNA_DB_HOSTNAME=
-These are used in alembic/env.py to set up the DB connection.
+.. code-block::
+ :caption: You will need the following environment variables:
+
+ export TUNA_DB_NAME=
+ export TUNA_DB_USER_NAME=
+ export TUNA_DB_USER_PASSWORD=
+ export TUNA_DB_HOSTNAME=
+
+These are used in `alembic/env.py` to set up the DB connection.
DB version upgrades/downgrades are located in alembic/versions. The alembic tool and its '.ini' file are set up to work with the MITuna DB.
To start a migration file, follow the steps below:
-1. $ alembic revision -m "create account table"
-2. Modify the new versioning file in MITuna/alembic/versions/ with the desired DB changes.
-3. Run the new migraton file: $ alembic upgrade head
+
+ 1. $ alembic revision -m "create account table"
+ 2. Modify the new versioning file in MITuna/alembic/versions/ with the desired DB changes.
+ 3. Run the new migration file: $ alembic upgrade head
For more details on how to modify the new versioning file, or how to execute more complex migrations, follow the [alembic tutorial](https://alembic.sqlalchemy.org/en/latest/tutorial.html#creating-an-environment)
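For reference, a minimal sketch of what the generated versioning file under `MITuna/alembic/versions/` might contain for the "create account table" example above. The revision identifiers, table name, and columns are illustrative assumptions only, not MITuna's actual schema:

```python
"""create account table (illustrative example)"""
from alembic import op
import sqlalchemy as sa

# Revision identifiers used by alembic; placeholder values.
revision = '1a2b3c4d5e6f'
down_revision = None
branch_labels = None
depends_on = None


def upgrade():
  """Apply the migration: create the example 'account' table."""
  op.create_table(
      'account',
      sa.Column('id', sa.Integer, primary_key=True),
      sa.Column('name', sa.String(50), nullable=False),
      sa.Column('description', sa.Unicode(200)),
  )


def downgrade():
  """Revert the migration: drop the example 'account' table."""
  op.drop_table('account')
```

Running `alembic upgrade head` then applies this revision to the database described by the TUNA_DB_* environment variables above.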
diff --git a/docs/tuning.rst b/docs/tuning.rst
new file mode 100644
index 000000000..b86b9637d
--- /dev/null
+++ b/docs/tuning.rst
@@ -0,0 +1,10 @@
+Tuning
+************
+
+MITuna is a distributed tuning infrastructure that provides pre-compiled kernels for MIOpen
+customers through automated Jenkins pipelines and SLURM scalable architecture. MITuna also
+provides a scalable task management infrastructure ready to integrate with external libraries.
+The Example library provides a sample of how to achieve this.
+
+.. include:: ../tuna/miopen/doc/Tuning.md
+.. include:: ../tuna/example/doc/Tuning.md
diff --git a/flaskapp/templates/display_keys.html b/flaskapp/templates/display_keys.html
deleted file mode 100644
index 9d3088ca1..000000000
--- a/flaskapp/templates/display_keys.html
+++ /dev/null
@@ -1,27 +0,0 @@
-
-
diff --git a/flaskapp/templates/input-form.html b/flaskapp/templates/input-form.html
deleted file mode 100644
index 12fe603a8..000000000
--- a/flaskapp/templates/input-form.html
+++ /dev/null
@@ -1,5 +0,0 @@
-
diff --git a/flaskapp/views/example_grafana.py b/flaskapp/views/example_grafana.py
deleted file mode 100644
index 92ffe9bf1..000000000
--- a/flaskapp/views/example_grafana.py
+++ /dev/null
@@ -1,81 +0,0 @@
-#!/usr/bin/env python3
-###############################################################################
-#
-# MIT License
-#
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-#
-###############################################################################
-"""Module for Grafana plugin"""
-import json
-from flask import Blueprint, jsonify, request
-from tuna.flask_example import get_table_example
-
-grafana = Blueprint('example_grafana', __name__, template_folder='templates')
-
-
-@grafana.route('/search', methods=['POST', 'GET'])
-def search():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- return jsonify([], [])
-
-
-@grafana.route('/query', methods=['POST', 'GET'])
-def query():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- grafana_req = req['targets'][0]['target'].split(',')
- data = []
- if grafana_req[0] == 'table-example':
- get_table_example(grafana_req, data)
- else:
- raise ValueError('Unsupported Grafana request: {}'.format(grafana_req[1]))
-
- return jsonify(data)
-
-
-@grafana.route('/annotations', methods=['POST', 'GET'])
-def annotations():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- data = []
- return jsonify(data)
-
-
-@grafana.route('/tag-keys', methods=['POST', 'GET'])
-def tag_keys():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- data = []
- return jsonify(data)
-
-
-@grafana.route('/tag-values', methods=['POST', 'GET'])
-def tag_values():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- data = []
- return jsonify(data)
-
-
-class Object():
- pass
diff --git a/flaskapp/views/fdb_key.py b/flaskapp/views/fdb_key.py
deleted file mode 100644
index a9537a514..000000000
--- a/flaskapp/views/fdb_key.py
+++ /dev/null
@@ -1,72 +0,0 @@
-#!/usr/bin/env python3
-###############################################################################
-#
-# MIT License
-#
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-#
-###############################################################################
-"""Module to take in MIOpenDriver cmd and return fdb keys in json format"""
-from flask import request, render_template
-from flask import Blueprint, jsonify
-from tuna.query_db import main_impl
-from tuna.parsing import parse_driver_line
-from tuna.import_configs import config_set_defaults, tag_config_v1
-
-fdb_key = Blueprint('fdb_key', __name__, template_folder='templates')
-
-
-@fdb_key.route('/fdb_key')
-def get_fdb_keys():
- """Takes MIOpenDriver cmd"""
- return render_template('input-form.html')
-
-
-#parse JSON object from Grafana
-@fdb_key.route('/fdb_key', methods=['POST'])
-def post_fdb_keys():
- """POST that takes MIOpenDriver cmd and returns json"""
- cmd = None
- if 'cmd' in request.form:
- cmd = request.form['cmd']
- elif request.is_json:
- json_dict = request.get_json(force=True)
- cmd = json_dict['cmd']
- else:
- raise ValueError('Unsuported operation.')
-
- return_dict = {}
- return_dict['cmd'] = cmd
- fds, precision = get_fds_from_cmd(cmd)
- config_set_defaults(fds)
- fdb_keys = {}
- fdb_keys['F'] = get_pdb_key(fds, precision, 'F')
- fdb_keys['B'] = get_pdb_key(fds, precision, 'B')
- fdb_keys['W'] = get_pdb_key(fds, precision, 'W')
- return_dict['fdb_keys'] = fdb_keys
-
- if request.is_json:
- return jsonify(return_dict['fdb_keys'])
-
- return render_template('display_keys.html',
- result=return_dict,
- driver=cmd,
- config_id=res[0][0])
diff --git a/flaskapp/views/grafana.py b/flaskapp/views/grafana.py
deleted file mode 100644
index 037eb9ec7..000000000
--- a/flaskapp/views/grafana.py
+++ /dev/null
@@ -1,87 +0,0 @@
-#!/usr/bin/env python3
-###############################################################################
-#
-# MIT License
-#
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-#
-###############################################################################
-"""Module for Grafana plugin"""
-from flask import Blueprint, jsonify, request
-from tuna.parsing import parse_driver_line
-from tuna.import_configs import config_set_defaults, tag_config_v1
-from tuna.flask import get_timeseries_data, get_tag_data
-from tuna.flask import get_performance_comparison
-
-grafana = Blueprint('grafana', __name__, template_folder='templates')
-
-
-@grafana.route('/search', methods=['POST', 'GET'])
-def search():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- return jsonify([], [])
-
-
-@grafana.route('/query', methods=['POST', 'GET'])
-def query():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- grafana_req = req['targets'][0]['target'].split(',')
- data = []
- if grafana_req[0] == 'solver-timeseries':
- get_timeseries_data(grafana_req, data)
- elif grafana_req[0] == 'tag-table':
- get_tag_data(grafana_req, data)
- elif grafana_req[0] == 'performance-comparison':
- get_performance_comparison(grafana_req, data)
- else:
- raise ValueError('Unsupported Grafana request: {}'.format(grafana_req[1]))
-
- return jsonify(data)
-
-
-@grafana.route('/annotations', methods=['POST', 'GET'])
-def annotations():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- data = []
- return jsonify(data)
-
-
-@grafana.route('/tag-keys', methods=['POST', 'GET'])
-def tag_keys():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- data = []
- return jsonify(data)
-
-
-@grafana.route('/tag-values', methods=['POST', 'GET'])
-def tag_values():
- """Entrypoint needed for Grafana plugin"""
- req = request.get_json()
- data = []
- return jsonify(data)
-
-
-class Object():
- pass
diff --git a/requirements.txt b/requirements.txt
index 86fcda74c..1b29fc500 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,55 +1,61 @@
+aioredis==2.0.1
alembic==1.8.1
asn1crypto==0.24.0
astroid==2.15.4
+asyncio==3.4.3
attrs==19.3.0
backcall==0.1.0
-bcrypt==3.1.4
-breathe==4.30.0
-cryptography==41.0.2
+bcrypt==3.2
+breathe==4.35.0
+celery==5.3.4
+cryptography==43.0.1
decorator==4.3.0
-docutils<0.17 # sphinx-rtd-theme 0.5.2 requires docutils<0.17
+docutils==0.20
flask==2.2.5
-idna==2.8
-importlib-metadata==6.6.0
+flower==2.0.1
+idna==3.7
+importlib-metadata>=6.6.0
jsonargparse==4.19.0
-isort==4.3.4
+isort==5.13.2
jedi==0.13.1
lazy-object-proxy==1.7.1
-markdown-it-py==2.2.0
+markdown-it-py==3.0.0
mccabe==0.6.1
-myst-parser==0.18.1
+myst-parser==4.0.0
more-itertools==8.3.0
numpy==1.24.2
opentelemetry-api==1.12.0rc2
opentelemetry-distro==0.32b0
opentelemetry-exporter-otlp-proto-http==1.11.1
-packaging==20.4
+packaging==24.1
pandas==1.5.3
-paramiko==2.10.1
+paramiko==3.5.0
parso==0.3.1
pathlib2==2.3.5
pexpect==4.6.0
pickleshare==0.7.5
pluggy==0.13.1
-prompt-toolkit==2.0.7
+prompt-toolkit==3.0.36
protobuf<5.0.0dev,>=3.19.5
ptyprocess==0.6.0
py==1.10.0
pyasn1==0.4.4
pycparser==2.19
-Pygments==2.15.0
+Pygments==2.18.0
pylint<=2.17.0-dev0,>=2.15.4
-pymysql==0.10.0
-PyNaCl==1.3.0
+pymysql==1.1.1
+PyNaCl==1.5
pyparsing==2.4.7
-pytest==6.1.0
+pytest==7.4.4
+pytest-asyncio==0.21
pyyaml==6.0
+redis==5.0.1
six==1.12.0
sqlalchemy==1.3.23
-sphinx==4.1.2
-sphinx_rtd_theme==0.5.2
+sphinx==7.4.7
+sphinx_rtd_theme==2.0.0
traitlets==4.3.2
-twine==4.0.2
+twine==5.1.1
typed-ast==1.5.4
types-PyYAML==6.0.12.6
types-paramiko==3.0.0.4
@@ -58,8 +64,9 @@ wcwidth==0.1.7
wrapt==1.14.1
yamllint==1.29.0
yapf==0.40.2
-zipp==3.8.1
+zipp==3.19.1
coverage==7.0.5
+python-logstash-async==3.0.0
mysql-connector-python
prometheus_flask_exporter
tenacity
diff --git a/setup.py b/setup.py
index fc89a2dcf..1a3e010f1 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@
from setuptools import setup, find_packages
import os
+
thelibFolder = os.path.dirname(os.path.realpath(__file__))
requirementPath = thelibFolder + '/requirements.txt'
readmePath = thelibFolder + '/README.md'
@@ -45,14 +46,14 @@
name='MITuna',
python_requires='>=3.9',
#some version number you may wish to add - increment this after every update
- version='1.0',
+ version='2.0',
description="Tuna is a distributed tuning infrastructure that provides pre-compiled kernels "\
"for MIOpen customers through automated Jenkins pipelines and SLURM scalable "\
"architecture. MITuna also provides a scalable task management infrastructure "\
"ready to integrate with external libaries.",
long_description=readme,
license='MIT',
- url='https://github.com/ROCmSoftwarePlatform/MITuna.git',
+ url='https://github.com/ROCm/MITuna.git',
install_requires=install_requires,
# Use one of the below approach to define package and/or module names:
diff --git a/tests/dummy_machine.py b/tests/dummy_machine.py
index f76c1ea98..fc7ea621f 100644
--- a/tests/dummy_machine.py
+++ b/tests/dummy_machine.py
@@ -61,3 +61,6 @@ def exec_command(self, command, docker_name=None, timeout=LOG_TIMEOUT):
ret_code, out, err = self.machine.exec_command(command, docker_name,
timeout)
return ret_code, out, err
+
+ def get_num_cpus(self):
+ return 5
diff --git a/tests/test_abort_file.py b/tests/test_abort_file.py
index abc3ba850..a81960b65 100644
--- a/tests/test_abort_file.py
+++ b/tests/test_abort_file.py
@@ -43,7 +43,7 @@
from tuna.miopen.db.tables import MIOpenDBTables
from tuna.miopen.subcmd.import_configs import import_cfgs
from tuna.utils.db_utility import connect_db
-from tuna.miopen.db.miopen_tables import ConvolutionJob
+from tuna.miopen.db.convolutionjob_tables import ConvolutionJob
from tuna.miopen.subcmd.load_job import test_tag_name as tag_name_test
from utils import add_test_session
from tuna.dbBase.sql_alchemy import DbSession
@@ -101,7 +101,6 @@ def test_abort():
'machine': m,
'gpu_id': gpu_idx,
'num_procs': num_gpus,
- 'barred': v,
'bar_lock': Lock(),
'envmt': ["MIOPEN_LOG_LEVEL=7"],
'reset_interval': False,
diff --git a/tests/test_add_session.py b/tests/test_add_session.py
index 19572c1e3..583631ce2 100644
--- a/tests/test_add_session.py
+++ b/tests/test_add_session.py
@@ -36,7 +36,8 @@
from tuna.machine import Machine
from tuna.dbBase.sql_alchemy import DbSession
-from tuna.worker_interface import WorkerInterface
+#from tuna.worker_interface import WorkerInterface
+from tuna.miopen.worker.fin_class import FinClass
from tuna.miopen.db.session import Session
from utils import DummyArgs
@@ -55,7 +56,6 @@ def test_add_session():
'machine': machine,
'gpu_id': 0,
'num_procs': num_gpus,
- 'barred': v,
'bar_lock': Lock(),
'envmt': ["MIOPEN_LOG_LEVEL=7"],
'reset_interval': False,
@@ -79,7 +79,7 @@ def test_add_session():
args.docker_name = docker_name
args.solver_id = 1
- worker = WorkerInterface(**kwargs)
+ worker = FinClass(**kwargs)
sess_id = Session().add_new_session(args, worker)
print(f"session id: {sess_id}")
assert (sess_id)
diff --git a/tests/test_add_session_rocmlir.py b/tests/test_add_session_rocmlir.py
index 55656960b..d327c71db 100644
--- a/tests/test_add_session_rocmlir.py
+++ b/tests/test_add_session_rocmlir.py
@@ -59,7 +59,6 @@ def test_add_session_rocmlir():
'gpu_id': 0,
'num_procs': num_gpus,
'config_type': ConfigType.convolution,
- 'barred': v,
'bar_lock': Lock(),
'reset_interval': False,
'app_test': False,
diff --git a/tests/test_celery.py b/tests/test_celery.py
new file mode 100644
index 000000000..1e4aa01dd
--- /dev/null
+++ b/tests/test_celery.py
@@ -0,0 +1,314 @@
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+import os
+import copy
+import json
+import pytest
+from time import sleep
+from multiprocessing import Value
+import aioredis
+import pytest_asyncio
+from sqlalchemy.inspection import inspect
+
+from utils import GoFishArgs, add_test_jobs, add_test_session
+from tuna.dbBase.sql_alchemy import DbSession
+from tuna.utils.machine_utility import load_machines
+from tuna.miopen.db.tables import MIOpenDBTables
+from tuna.miopen.miopen_lib import MIOpen
+from tuna.miopen.utils.config_type import ConfigType
+from tuna.utils.utility import serialize_job_config_row, arch2targetid
+from tuna.miopen.celery_tuning.celery_tasks import prep_kwargs
+from tuna.machine import Machine
+from tuna.libraries import Operation
+from tuna.celery_app.celery_workers import launch_worker_per_node
+from tuna.celery_app.utility import get_q_name
+from tuna.celery_app.celery_app import get_broker_env, get_backend_env
+from tuna.parse_args import TunaArgs, setup_arg_parser
+from tuna.miopen.celery_tuning.celery_tasks import prep_worker
+from tuna.miopen.worker.fin_utils import compose_config_obj, fin_job
+from tuna.miopen.utils.lib_helper import get_worker
+
+@pytest.mark.asyncio
+async def test_celery_workers():
+ miopen = MIOpen()
+ miopen.args = GoFishArgs()
+ miopen.args.label = 'tuna_pytest_celery'
+ miopen.args.session_id = add_test_session(label=miopen.args.label)
+
+ #load jobs
+ dbt = MIOpenDBTables(config_type=ConfigType.convolution)
+ num_jobs = add_test_jobs(miopen, miopen.args.session_id, dbt,
+ miopen.args.label, miopen.args.label,
+ ['miopen_perf_compile'],
+ 'test_add_celery_compile_job',
+ 'miopenConvolutionAlgoGEMM')
+ #assert num_jobs
+ num_jobs = 4
+
+ machine_lst = load_machines(miopen.args)
+ machine = machine_lst[0]
+ miopen.operation = Operation.COMPILE
+ miopen.dbt = MIOpenDBTables(session_id=miopen.args.session_id,
+ config_type=ConfigType.convolution)
+ miopen.args.enqueue_only = False
+ db_name = os.environ['TUNA_DB_NAME']
+
+ #testing get_q_name
+ q_name = get_q_name(miopen, op_compile=True)
+ assert q_name == f"compile_q_{db_name}_sess_{miopen.args.session_id}"
+ q_name = get_q_name(miopen, op_eval=True)
+ assert q_name == f"eval_q_{db_name}_sess_{miopen.args.session_id}"
+
+ #testing prep_tuning
+ _, subp_list = miopen.prep_tuning()
+ assert subp_list
+ for subp in subp_list:
+ subp.kill()
+
+ miopen.args.enqueue_only = True
+ _, subp_list = miopen.prep_tuning()
+ assert subp_list == []
+
+
+ cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{miopen.args.session_id} -Q {q_name}" #pylint: disable=line-too-long
+ #testing launch_worker_per_node
+ subp_list = launch_worker_per_node([machine], cmd, True)
+ #wait for workers to finish launch
+ sleep(5)
+ assert subp_list
+ assert miopen.cancel_consumer(q_name)
+ #wait for celery worker shutdown
+ sleep(5)
+
+ for subp in subp_list:
+ print(subp.pid)
+ assert subp.poll()
+ subp.kill()
+
+ miopen.args.fin_steps = "miopen_perf_compile"
+ miopen.db_name = "test_db"
+ parser = setup_arg_parser(
+ 'Run Performance Tuning on a certain architecture', [
+ TunaArgs.ARCH, TunaArgs.NUM_CU, TunaArgs.VERSION,
+ TunaArgs.CONFIG_TYPE, TunaArgs.SESSION_ID, TunaArgs.MACHINES,
+ TunaArgs.REMOTE_MACHINE, TunaArgs.LABEL, TunaArgs.RESTART_MACHINE,
+ TunaArgs.DOCKER_NAME, TunaArgs.SHUTDOWN_WORKERS
+ ])
+
+ #testing check_fin_args
+ miopen.check_fin_args(parser)
+ #testing set_prefix
+ miopen.set_prefix()
+ assert (miopen.prefix ==
+ f"d_test_db_sess_{miopen.args.session_id}_miopen_perf_compile")
+
+ #testing update_operation
+ miopen.update_operation()
+ assert 'new' in miopen.fetch_state
+ assert miopen.set_state == 'compile_start'
+ assert miopen.operation == Operation.COMPILE
+
+ #testing has_tunable_operation
+ assert miopen.has_tunable_operation()
+
+ with DbSession() as session:
+ job_query = session.query(
+ dbt.job_table).filter(dbt.job_table.session == miopen.args.session_id)\
+ .filter(dbt.job_table.reason=='tuna_pytest_celery')
+ job_query.update({dbt.job_table.state: 'compile_start'})
+ session.commit()
+ #testing reset_job_state_on_ctrl_c
+ miopen.reset_job_state_on_ctrl_c()
+ count = session.query(dbt.job_table).filter(dbt.job_table.session==miopen.args.session_id)\
+ .filter(dbt.job_table.state=='new').count()
+ assert count == num_jobs
+
+ with DbSession() as session:
+ jobs = miopen.get_jobs(session, miopen.fetch_state, miopen.set_state,
+ miopen.args.session_id)
+ assert jobs
+ #testing get_context_list
+ context_list = miopen.get_context_list(session, [job for job in jobs])
+ assert context_list
+ assert len(context_list) == 4
+
+ #testing serialize_jobs call
+ serialized_jobs = miopen.serialize_jobs(session, [job for job in jobs])
+ assert serialized_jobs
+
+ #testing build_context call
+ context_l = miopen.build_context(serialized_jobs)
+ assert context_l
+
+ #testing get_context_items
+ assert miopen.get_context_items()
+ #testing get_fdb_attr
+ assert miopen.get_fdb_attr()
+
+ entries = [job for job in jobs]
+
+ job_config_rows = miopen.compose_work_objs_fin(session, entries, miopen.dbt)
+ assert job_config_rows
+
+ job_dct, config_dct = serialize_job_config_row(job_config_rows[0])
+ #testing arch2targetid
+ arch = arch2targetid(miopen.dbt.session.arch)
+ assert arch == "gfx90a:sram-ecc+:xnack-"
+ steps = ['alloc_buf', 'fill_buf', miopen.args.fin_steps[0]]
+
+ #testing fin_job
+ fjob = fin_job(steps, True, job_config_rows[0][0], job_config_rows[0][1],
+ miopen.dbt)
+ assert fjob
+ f_vals = miopen.get_f_vals(machine, range(0))
+ kwargs = miopen.get_kwargs(0, f_vals, tuning=True)
+ kwargs['job'] = job_dct
+ kwargs['config'] = config_dct
+ kwargs['avail_gpus'] = 1
+ fdb_attr = [column.name for column in inspect(miopen.dbt.find_db_table).c]
+ fdb_attr.remove("insert_ts")
+ fdb_attr.remove("update_ts")
+ context = {
+ 'job': job_dct,
+ 'config': config_dct,
+ 'operation': Operation.EVAL,
+ 'arch': miopen.dbt.session.arch,
+ 'num_cu': miopen.dbt.session.num_cu,
+ 'kwargs': kwargs,
+ 'fdb_attr': fdb_attr
+ }
+
+ worker = prep_worker(copy.deepcopy(context))
+ worker_kwargs = prep_kwargs(
+ context['kwargs'],
+ [context['job'], context['config'], context['operation']])
+ assert worker_kwargs['config']
+ assert worker_kwargs['job']
+ assert worker_kwargs['fin_steps'] == ['miopen_perf_compile']
+ miopen.operation = Operation.EVAL
+ fin_eval = get_worker(worker_kwargs, miopen.operation)
+
+ #testing fin_job
+ fjob = fin_job(steps, True, job_config_rows[0][0], job_config_rows[0][1],
+ miopen.dbt)
+ #testing fin_pdb_input
+ f_job = fin_eval.fin_pdb_input(fjob)
+ assert f_job[0]['solvers'] == ['GemmBwd1x1_stride2']
+ assert f_job[0]['miopen_perf_compile_result'] == [{
+ 'solver_name': 'GemmBwd1x1_stride2',
+ 'perf_compiled': False,
+ 'kernel_objects': []
+ }]
+
+ #testing fin_fdb_input
+ steps = ['alloc_buf', 'fill_buf', ['miopen_find_compile']]
+ f_job = fin_eval.fin_fdb_input(fjob)
+ assert f_job
+ assert f_job[0]['miopen_find_compile_result'] == [{
+ 'solver_name': 'GemmBwd1x1_stride2',
+ 'find_compiled': False,
+ 'kernel_objects': []
+ }]
+
+ #testing compose_config_obj
+ conf_obj = compose_config_obj(job_config_rows[0][1], ConfigType.convolution)
+ assert conf_obj
+ assert conf_obj[
+ 'driver'] == "./bin/MIOpenDriver conv --batchsize 128 --spatial_dim 2 --pad_h 0 --pad_w 0 --pad_d 0 --conv_stride_h 2 --conv_stride_w 2 --conv_stride_d 0 --dilation_h 1 --dilation_w 1 --dilation_d 0 --group_count 1 --mode conv --pad_mode default --trans_output_pad_h 0 --trans_output_pad_w 0 --trans_output_pad_d 0 --out_layout NCHW --in_layout NCHW --fil_layout NCHW --in_d 1 --in_h 14 --in_w 14 --fil_d 1 --fil_h 1 --fil_w 1 --in_channels 1024 --out_channels 2048 --forw 2"
+
+ miopen.operation = Operation.COMPILE
+ f_vals = miopen.get_f_vals(Machine(local_machine=True), range(0))
+ kwargs = miopen.get_kwargs(0, f_vals, tuning=True)
+ fdb_attr = [column.name for column in inspect(miopen.dbt.find_db_table).c]
+ fdb_attr.remove("insert_ts")
+ fdb_attr.remove("update_ts")
+
+ backend_host = os.environ['TUNA_CELERY_BACKEND_HOST']
+ redis = await aioredis.from_url(f"redis://{backend_host}:6379/15")
+ print('Established redis connection')
+ counter = 1
+
+ res_set = []
+ for elem in job_config_rows:
+ job_dict, config_dict = serialize_job_config_row(elem)
+ context = {
+ 'job': job_dict,
+ 'config': config_dict,
+ 'operation': miopen.operation,
+ 'arch': miopen.dbt.session.arch,
+ 'num_cu': miopen.dbt.session.num_cu,
+ 'kwargs': kwargs,
+ 'fdb_attr': fdb_attr
+ }
+
+ worker = prep_worker(copy.deepcopy(context))
+ worker.dbt = miopen.dbt
+ worker.fin_steps = miopen.args.fin_steps
+ fin_json = worker.run()
+ res_set.append((fin_json, context))
+ await redis.set(f"celery-task-meta-{counter}", json.dumps(fin_json))
+ counter += 1
+
+ print('Consuming from redis')
+ assert miopen.consume(job_counter=counter, prefix=None)
+ await redis.close()
+
+ with DbSession() as session:
+ for fin_json, context in res_set:
+ #testing process_compile_results
+ miopen.process_compile_results(session, fin_json, context)
+ count = session.query(dbt.job_table).filter(
+ dbt.job_table.session == miopen.args.session_id).count()
+ assert count == num_jobs
+
+ with DbSession() as session:
+ job_query = session.query(
+ dbt.job_table).filter(dbt.job_table.session == miopen.args.session_id)\
+ .filter(dbt.job_table.reason=='tuna_pytest_celery')
+ job_query.update({dbt.job_table.state: 'new'})
+ session.commit()
+
+ db_name = os.environ['TUNA_DB_NAME']
+ #testing enqueue_jobs
+ job_counter = Value('i', 4)
+ miopen.enqueue_jobs(job_counter, 1, f"test_{db_name}")
+ print('Done enqueue')
+ with DbSession() as session:
+ count = session.query(dbt.job_table).filter(dbt.job_table.session==miopen.args.session_id)\
+ .filter(dbt.job_table.state=='compile_start').count()
+ assert count == 4
+
+
+TUNA_CELERY_BROKER_HOST, TUNA_CELERY_BROKER_PORT, TUNA_CELERY_BROKER_USER, TUNA_CELERY_BROKER_PWD = get_broker_env(
+)
+assert TUNA_CELERY_BROKER_HOST
+assert TUNA_CELERY_BROKER_PORT
+assert TUNA_CELERY_BROKER_USER
+assert TUNA_CELERY_BROKER_PWD
+
+TUNA_CELERY_BACKEND_PORT, TUNA_CELERY_BACKEND_HOST = get_backend_env()
+assert TUNA_CELERY_BACKEND_PORT
+assert TUNA_CELERY_BACKEND_HOST
diff --git a/tests/test_dbBase.py b/tests/test_dbBase.py
index f26e334fc..c3974f05b 100644
--- a/tests/test_dbBase.py
+++ b/tests/test_dbBase.py
@@ -48,7 +48,6 @@ def connect_db():
db_name = ENV_VARS['db_name']
try:
ENGINE.execute('Use {}'.format(db_name))
- return
except OperationalError: # as err:
LOGGER.warning('Database %s does not exist, attempting to create database',
db_name)
diff --git a/tests/test_driver.py b/tests/test_driver.py
index a23fe7bc7..078037b81 100644
--- a/tests/test_driver.py
+++ b/tests/test_driver.py
@@ -28,11 +28,12 @@
from tuna.miopen.driver.convolution import DriverConvolution
from tuna.miopen.driver.batchnorm import DriverBatchNorm
from tuna.miopen.subcmd.import_configs import insert_config
-from tuna.miopen.db.miopen_tables import ConvolutionConfig, BNConfig
+from tuna.miopen.db.convolutionjob_tables import ConvolutionConfig
+from tuna.miopen.db.batch_norm_tables import BNConfig
from tuna.dbBase.sql_alchemy import DbSession
from tuna.miopen.db.tables import MIOpenDBTables
from tuna.miopen.utils.config_type import ConfigType
-from test_fin_builder import CfgImportArgs
+from utils import CfgImportArgs
def test_driver():
@@ -43,13 +44,13 @@ def test_driver():
counts['cnt_configs'] = 0
counts['cnt_tagged_configs'] = set()
conv_driver(args, logger, dbt, counts)
- bn_driver(args, logger, dbt, counts)
+ bn_driver(args, logger, counts)
def conv_driver(args, logger, dbt, counts):
cmd0 = "./bin/MIOpenDriver conv --pad_h 1 --pad_w 1 --out_channels 128 --fil_w 3 --fil_h 3 --dilation_w 1 --dilation_h 1 --conv_stride_w 1 --conv_stride_h 1 --in_channels 128 --in_w 28 --in_h 28 --in_h 28 --batchsize 256 --group_count 1 --in_d 1 --fil_d 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -V 0"
try:
- driver0 = DriverConvolution(cmd0)
+ _ = DriverConvolution(cmd0)
assert False
except ValueError as err:
assert "needs direction" in str(err)
@@ -57,15 +58,15 @@ def conv_driver(args, logger, dbt, counts):
cmd1 = "./bin/MIOpenDriver conv --pad_h 1 --pad_w 1 --out_channels 128 --fil_w 3 --fil_h 3 --dilation_w 1 --dilation_h 1 --conv_stride_w 1 --conv_stride_h 1 --in_channels 128 --in_w 28 --in_h 28 --in_h 28 --batchsize 256 --group_count 1 --in_d 1 --fil_d 1 --forw 1 --out_layout NHWC -V 0"
driver1 = DriverConvolution(cmd1)
d1_str = driver1.to_dict()
- assert (d1_str["fil_h"] == 3)
- assert (d1_str["fil_layout"] == 'NHWC')
- assert (d1_str["in_layout"] == 'NHWC')
- assert (d1_str["out_layout"] == 'NHWC')
- assert (d1_str["in_channels"] == 128)
- assert (d1_str["out_channels"] == 128)
- assert (d1_str["spatial_dim"] == 2)
- assert (d1_str["direction"] == 'F')
- assert (d1_str["cmd"] == 'conv')
+ assert d1_str["fil_h"] == 3
+ assert d1_str["fil_layout"] == 'NHWC'
+ assert d1_str["in_layout"] == 'NHWC'
+ assert d1_str["out_layout"] == 'NHWC'
+ assert d1_str["in_channels"] == 128
+ assert d1_str["out_channels"] == 128
+ assert d1_str["spatial_dim"] == 2
+ assert d1_str["direction"] == 'F'
+ assert d1_str["cmd"] == 'conv'
itensor1 = driver1.get_input_t_id()
assert itensor1
wtensor1 = driver1.get_weight_t_id()
@@ -83,7 +84,7 @@ def conv_driver(args, logger, dbt, counts):
assert driver1 == driver_1_row
c_dict1 = driver1.compose_tensors(keep_id=True)
- assert c_dict1['id'] != None
+ assert c_dict1['id'] is not None
assert c_dict1["input_tensor"]
assert c_dict1["weight_tensor"]
@@ -91,16 +92,16 @@ def conv_driver(args, logger, dbt, counts):
driver2 = DriverConvolution(cmd2)
d2_str = driver2.to_dict()
assert d2_str
- assert (d2_str["cmd"] == 'convfp16')
- assert (d2_str["direction"] == 'B')
- assert (d2_str["in_h"] == 56)
- assert (d2_str["fil_layout"] == 'NCHW')
- assert (d2_str["dilation_d"] == 0)
- assert (d2_str["conv_stride_d"] == 0)
- assert (d2_str["fil_w"] == 1)
- assert (d2_str["out_channels"] == 64)
- itensor2 = driver2.get_input_t_id()
- wtensor2 = driver2.get_weight_t_id()
+ assert d2_str["cmd"] == 'convfp16'
+ assert d2_str["direction"] == 'B'
+ assert d2_str["in_h"] == 56
+ assert d2_str["fil_layout"] == 'NCHW'
+ assert d2_str["dilation_d"] == 0
+ assert d2_str["conv_stride_d"] == 0
+ assert d2_str["fil_w"] == 1
+ assert d2_str["out_channels"] == 64
+ assert driver2.get_input_t_id()
+ assert driver2.get_weight_t_id()
c_dict2 = driver2.compose_tensors(keep_id=True)
assert c_dict2["input_tensor"]
assert c_dict2["weight_tensor"]
@@ -116,36 +117,36 @@ def conv_driver(args, logger, dbt, counts):
fdb1 = "64-75-75-3x3-64-75-75-512-1x1-1x1-1x1-0-NHWC-FP16-W="
driver4 = DriverConvolution(fdb1)
- d4_str = driver4.__str__()
+ d4_str = str(driver4)
d4_dict = driver4.to_dict()
- assert (d4_dict["in_layout"] == "NHWC")
- assert (d4_dict["out_layout"] == "NHWC")
- assert (d4_dict["fil_layout"] == "NHWC")
+ assert d4_dict["in_layout"] == "NHWC"
+ assert d4_dict["out_layout"] == "NHWC"
+ assert d4_dict["fil_layout"] == "NHWC"
driver5 = DriverConvolution(d4_str)
assert driver4 == driver5
d5_dict = driver5.to_dict()
- assert (d5_dict["in_layout"] == "NHWC")
- assert (d5_dict["out_layout"] == "NHWC")
- assert (d5_dict["fil_layout"] == "NHWC")
+ assert d5_dict["in_layout"] == "NHWC"
+ assert d5_dict["out_layout"] == "NHWC"
+ assert d5_dict["fil_layout"] == "NHWC"
-def bn_driver(args, logger, dbt, counts):
+def bn_driver(args, logger, counts):
cmd3 = "./bin/MIOpenDriver bnormfp16 -n 256 -c 64 -H 56 -W 56 -m 1 --forw 1 -b 0 -s 1 -r 1"
args.config_type = ConfigType.batch_norm
dbt2 = MIOpenDBTables(session_id=None, config_type=args.config_type)
driver3 = DriverBatchNorm(cmd3)
d3_str = driver3.to_dict()
assert d3_str
- assert (d3_str["forw"] == 1)
- assert (d3_str["back"] == 0)
- assert (d3_str["cmd"] == 'bnormfp16')
- assert (d3_str["mode"] == 1)
- assert (d3_str["run"] == 1)
- assert (d3_str["in_channels"] == 64)
- assert (d3_str["alpha"] == 1)
- assert (d3_str["beta"] == 0)
- assert (d3_str["direction"] == 'F')
- itensor3 = driver3.get_input_t_id()
+ assert d3_str["forw"] == 1
+ assert d3_str["back"] == 0
+ assert d3_str["cmd"] == 'bnormfp16'
+ assert d3_str["mode"] == 1
+ assert d3_str["run"] == 1
+ assert d3_str["in_channels"] == 64
+ assert d3_str["alpha"] == 1
+ assert d3_str["beta"] == 0
+ assert d3_str["direction"] == 'F'
+ assert driver3.get_input_t_id()
c_dict3 = driver3.compose_tensors(keep_id=True)
assert c_dict3["input_tensor"]
assert c_dict3
diff --git a/tests/test_example.py b/tests/test_example.py
index 08e2b888b..2b3690083 100644
--- a/tests/test_example.py
+++ b/tests/test_example.py
@@ -26,6 +26,7 @@
import os
import sys
+from time import sleep
sys.path.append("../tuna")
sys.path.append("tuna")
@@ -33,44 +34,65 @@
this_path = os.path.dirname(__file__)
from tuna.example.example_lib import Example
+from utils import GoFishArgs, add_test_session
from utils import ExampleArgs
-from tuna.utils.miopen_utility import load_machines
+from tuna.utils.machine_utility import load_machines
from tuna.dbBase.sql_alchemy import DbSession
from tuna.example.session import SessionExample
from tuna.example.example_tables import Job
from tuna.example.load_job import add_jobs
+from tuna.libraries import Operation
from tuna.example.tables import ExampleDBTables
+from tuna.example.session import SessionExample
+from tuna.celery_app.celery_workers import launch_worker_per_node
+from tuna.celery_app.utility import get_q_name
def test_example():
example = Example()
example.args = ExampleArgs()
- assert (example.add_tables())
+ example.args = GoFishArgs()
+ example.args.label = 'tuna_pytest_example'
+ example.args.session_id = add_test_session(label=example.args.label,
+ session_table=SessionExample)
+ example.operation = Operation.COMPILE
+ example.args.arch = "gfx90a"
+ example.args.num_cu = 104
+ example.add_tables()
+ example.dbt = ExampleDBTables(session_id=example.args.session_id)
- res = load_machines(example.args)
- res = example.compose_worker_list(res)
+ machines = load_machines(example.args)
+ res = example.compose_worker_list(machines)
with DbSession() as session:
query = session.query(SessionExample)
res = query.all()
assert len(res) is not None
#test load_job
- dbt = ExampleDBTables(session_id=None)
example.args.init_session = False
example.args.session_id = 1
example.args.execute = True
example.args.label = 'test_example'
- example.args.config = 1
- num_jobs = add_jobs(example.args, dbt)
- assert num_jobs
+ example.args.execute = True
+ #assert num_jobs
- #testing execute rocminfo
- res = load_machines(example.args)
- res = example.compose_worker_list(res)
- with DbSession() as session:
- query = session.query(Job).filter(Job.session==1)\
- .filter(Job.state=='completed')
- res = query.all()
- assert res
+ example.args.execute = None
+ example.args.enqueue_only = True
+ db_name = os.environ['TUNA_DB_NAME']
+ _, subp_list = example.prep_tuning()
+ assert subp_list == []
+
+
+ cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{example.args.session_id} -Q test_{db_name}" #pylint: disable=line-too-long
+ #testing launch_worker_per_node
+ machine = machines[0]
+ subp_list = launch_worker_per_node([machine], cmd, True)
+ #wait for workers to finish launch
+ sleep(5)
+ assert subp_list
+
+ for subp in subp_list:
+ print(subp.pid)
+ subp.kill()
- return True
+ assert example.has_tunable_operation()
diff --git a/tests/test_export_db.py b/tests/test_export_db.py
index f6a499c1a..445c4a846 100644
--- a/tests/test_export_db.py
+++ b/tests/test_export_db.py
@@ -48,7 +48,7 @@
from tuna.utils.db_utility import DB_Type
from tuna.miopen.db.tables import MIOpenDBTables, ConfigType
from tuna.dbBase.sql_alchemy import DbSession
-from tuna.utils.db_utility import get_id_solvers
+from tuna.miopen.db.solver import get_id_solvers
from utils import add_test_session, DummyArgs, CfgEntry, TensorEntry, build_fdb_entry
session_id = add_test_session(arch='gfx90a',
diff --git a/tests/test_fin_builder.py b/tests/test_fin_builder.py
index e09927f79..948c4a0a5 100644
--- a/tests/test_fin_builder.py
+++ b/tests/test_fin_builder.py
@@ -24,97 +24,80 @@
#
###############################################################################
-import os
-import sys
-
-sys.path.append("../tuna")
-sys.path.append("tuna")
-
-this_path = os.path.dirname(__file__)
+import copy
+from sqlalchemy.inspection import inspect
from tuna.dbBase.sql_alchemy import DbSession
-from tuna.utils.miopen_utility import load_machines
from tuna.miopen.db.tables import MIOpenDBTables
-from tuna.miopen.worker.fin_class import FinClass
-from tuna.utils.db_utility import connect_db
-from tuna.miopen.subcmd.import_configs import import_cfgs
-from tuna.miopen.subcmd.load_job import test_tag_name as tag_name_test, add_jobs
-from utils import CfgImportArgs, LdJobArgs, GoFishArgs
-from utils import get_worker_args, add_test_session
from tuna.miopen.miopen_lib import MIOpen
-from tuna.miopen.utils.metadata import ALG_SLV_MAP
-from tuna.utils.db_utility import get_solver_ids
-from tuna.utils.logger import setup_logger
-
-
-def add_cfgs():
- #import configs
- args = CfgImportArgs()
- args.tag = 'tuna_pytest_fin_builder'
- args.mark_recurrent = True
- args.file_name = f"{this_path}/../utils/configs/conv_configs_NCHW.txt"
-
- dbt = MIOpenDBTables(config_type=args.config_type)
- counts = import_cfgs(args, dbt, setup_logger('test_fin_builder'))
- return dbt
-
-
-def add_fin_find_compile_job(session_id, dbt):
- #load jobs
- args = LdJobArgs
- args.label = 'tuna_pytest_fin_builder'
- args.tag = 'tuna_pytest_fin_builder'
- args.fin_steps = ['miopen_find_compile', 'miopen_find_eval']
- args.session_id = session_id
- logger = setup_logger('test_add_fin_find_compile_job')
-
- #limit job scope
- args.algo = "miopenConvolutionAlgoGEMM"
- solver_arr = ALG_SLV_MAP[args.algo]
- solver_id_map = get_solver_ids()
- if solver_arr:
- solver_ids = []
- for solver in solver_arr:
- sid = solver_id_map.get(solver, None)
- solver_ids.append((solver, sid))
- args.solvers = solver_ids
- args.only_applicable = True
-
- connect_db()
- return add_jobs(args, dbt, logger)
+from tuna.miopen.utils.config_type import ConfigType
+from tuna.utils.utility import serialize_job_config_row
+from tuna.libraries import Operation
+from tuna.miopen.celery_tuning.celery_tasks import prep_worker
+from tuna.machine import Machine
+from utils import GoFishArgs
+from utils import add_test_session, add_test_jobs
def test_fin_builder():
miopen = MIOpen()
miopen.args = GoFishArgs()
- machine_lst = load_machines(miopen.args)
- machine = machine_lst[0]
miopen.args.label = 'tuna_pytest_fin_builder'
- miopen.args.session_id = add_test_session(label='tuna_pytest_fin_builder')
-
- #update solvers
- kwargs = get_worker_args(miopen.args, machine, miopen)
- fin_worker = FinClass(**kwargs)
- assert (fin_worker.get_solvers())
-
- #get applicability
- dbt = add_cfgs()
- miopen.args.update_applicability = True
- worker_lst = miopen.compose_worker_list(machine_lst)
- for worker in worker_lst:
- worker.join()
+ miopen.args.session_id = add_test_session(label=miopen.args.label)
#load jobs
- miopen.args.label = 'tuna_pytest_fin_builder'
- num_jobs = add_fin_find_compile_job(miopen.args.session_id, dbt)
-
- #compile
+ dbt = MIOpenDBTables(config_type=ConfigType.convolution)
+ num_jobs = add_test_jobs(miopen, miopen.args.session_id, dbt,
+ miopen.args.label, miopen.args.label,
+ ['miopen_find_compile', 'miopen_find_eval'],
+ 'test_add_fin_find_compile_job',
+ 'miopenConvolutionAlgoGEMM')
+ assert num_jobs
+
+ #testing process_fdb_compile in process_compile_results
miopen.args.update_applicability = False
miopen.args.fin_steps = ["miopen_find_compile"]
- miopen.args.label = 'tuna_pytest_fin_builder'
- worker_lst = miopen.compose_worker_list(machine_lst)
- for worker in worker_lst:
- worker.join()
+ miopen.fetch_state.add('new')
+ miopen.operation = Operation.COMPILE
+ miopen.set_state = 'compile_start'
+ miopen.dbt = MIOpenDBTables(session_id=miopen.args.session_id,
+ config_type=ConfigType.convolution)
+ jobs = None
+ with DbSession() as session:
+ jobs = miopen.get_jobs(session, miopen.fetch_state, miopen.set_state,
+ miopen.args.session_id)
+ entries = list(jobs)
+ job_config_rows = miopen.compose_work_objs_fin(session, entries, miopen.dbt)
+ assert job_config_rows
+
+ f_vals = miopen.get_f_vals(Machine(local_machine=True), range(0))
+ kwargs = miopen.get_kwargs(0, f_vals, tuning=True)
+ fdb_attr = [column.name for column in inspect(miopen.dbt.find_db_table).c]
+ fdb_attr.remove("insert_ts")
+ fdb_attr.remove("update_ts")
+
+ res_set = []
+ for elem in job_config_rows:
+ job_dict, config_dict = serialize_job_config_row(elem)
+ context = {
+ 'job': job_dict,
+ 'config': config_dict,
+ 'operation': miopen.operation,
+ 'arch': miopen.dbt.session.arch,
+ 'num_cu': miopen.dbt.session.num_cu,
+ 'kwargs': kwargs,
+ 'fdb_attr': fdb_attr
+ }
+
+ worker = prep_worker(copy.deepcopy(context))
+ worker.dbt = miopen.dbt
+ worker.fin_steps = miopen.args.fin_steps
+ fin_json = worker.run()
+ res_set.append((fin_json, context))
+
+ with DbSession() as session:
+ for fin_json, context in res_set:
+ miopen.process_compile_results(session, fin_json, context)
with DbSession() as session:
valid_fin_err = session.query(dbt.job_table).filter(dbt.job_table.session==miopen.args.session_id)\
@@ -125,4 +108,4 @@ def test_fin_builder():
num_jobs = (num_jobs - valid_fin_err)
count = session.query(dbt.job_table).filter(dbt.job_table.session==miopen.args.session_id)\
.filter(dbt.job_table.state=='compiled').count()
- assert (count == num_jobs)
+ assert count == num_jobs
diff --git a/tests/test_fin_class.py b/tests/test_fin_class.py
index b8bc81ef5..75d18ab26 100644
--- a/tests/test_fin_class.py
+++ b/tests/test_fin_class.py
@@ -49,7 +49,6 @@ def test_set_all_configs():
'machine': DummyMachine(False),
'gpu_id': 0,
'num_procs': num_gpus,
- 'barred': v,
'bar_lock': Lock(),
'envmt': ["MIOPEN_LOG_LEVEL=7"],
'reset_interval': False,
diff --git a/tests/test_fin_evaluator.py b/tests/test_fin_evaluator.py
index bf38fead9..a8848a5ec 100644
--- a/tests/test_fin_evaluator.py
+++ b/tests/test_fin_evaluator.py
@@ -27,16 +27,12 @@
import json
import os
import sys
-from multiprocessing import Value, Lock, Queue
+import copy
+from sqlalchemy.inspection import inspect
-sys.path.append("../tuna")
-sys.path.append("tuna")
-
-this_path = os.path.dirname(__file__)
-
-from dummy_machine import DummyMachine
+from utils import CfgImportArgs, LdJobArgs, GoFishArgs
+from utils import get_worker_args, add_test_session
from tuna.dbBase.sql_alchemy import DbSession
-from tuna.miopen.worker.fin_eval import FinEvaluator
from tuna.miopen.db.tables import MIOpenDBTables
from tuna.miopen.miopen_lib import MIOpen
from tuna.miopen.subcmd.import_configs import import_cfgs
@@ -44,40 +40,22 @@
from tuna.miopen.utils.config_type import ConfigType
from tuna.miopen.utils.metadata import ALG_SLV_MAP
from tuna.miopen.worker.fin_class import FinClass
-from tuna.utils.db_utility import get_solver_ids
+from tuna.miopen.db.solver import get_solver_ids
from tuna.utils.db_utility import connect_db
from tuna.utils.logger import setup_logger
-from tuna.utils.miopen_utility import load_machines
-from utils import CfgImportArgs, LdJobArgs, GoFishArgs
-from utils import get_worker_args, add_test_session
+from tuna.utils.machine_utility import load_machines
+from tuna.miopen.celery_tuning.celery_tasks import prep_kwargs
+from tuna.miopen.utils.lib_helper import get_worker
+from tuna.utils.utility import serialize_job_config_row
+from tuna.libraries import Operation
+from tuna.miopen.celery_tuning.celery_tasks import prep_worker
-solver_id_map = get_solver_ids()
+sys.path.append("../tuna")
+sys.path.append("tuna")
+this_path = os.path.dirname(__file__)
-def get_kwargs(dbt):
- num_gpus = Value('i', 1)
- v = Value('i', 0)
- e = Value('i', 0)
-
- kwargs = {
- 'machine': DummyMachine(False),
- 'gpu_id': 0,
- 'num_procs': num_gpus,
- 'barred': v,
- 'bar_lock': Lock(),
- 'envmt': ["MIOPEN_LOG_LEVEL=7"],
- 'reset_interval': False,
- 'app_test': False,
- 'label': 'tuna_pytest_fin_eval',
- 'fin_steps': ['miopen_find_eval'],
- 'use_tuner': False,
- 'job_queue': Queue(),
- 'queue_lock': Lock(),
- 'fetch_state': ['compiled'],
- 'end_jobs': e,
- 'session_id': dbt.session_id
- }
- return kwargs
+solver_id_map = get_solver_ids()
def add_cfgs():
@@ -158,119 +136,108 @@ def test_fin_evaluator():
#update solvers
kwargs = get_worker_args(miopen.args, machine, miopen)
fin_worker = FinClass(**kwargs)
- assert (fin_worker.get_solvers())
+ assert fin_worker.get_solvers()
add_cfgs()
dbt = MIOpenDBTables(config_type=ConfigType.convolution,
session_id=miopen.args.session_id)
- #set all applicable
- with DbSession() as session:
- configs = session.query(dbt.config_tags_table.config).filter(
- dbt.config_tags_table.tag == 'tuna_pytest_fin_eval').all()
- configs = [x[0] for x in configs]
- print(configs)
- for solver in solver_id_map.values():
- for config in configs:
- slv_app_entry = dbt.solver_app()
- slv_app_entry.config = config
- slv_app_entry.solver = solver
- slv_app_entry.session = dbt.session_id
- slv_app_entry.applicable = True
- session.add(slv_app_entry)
- session.commit()
+ args = GoFishArgs()
+ machine_lst = load_machines(args)
+ miopen.args.update_applicability = True
- #load jobs
- miopen.args.label = 'tuna_pytest_fin_eval'
- add_fin_find_eval_job(miopen.args.session_id, dbt)
+ worker_lst = miopen.compose_worker_list(machine_lst)
+ for worker in worker_lst:
+ worker.join()
- with DbSession() as session:
- job_query = session.query(
- dbt.job_table).filter(dbt.job_table.session == miopen.args.session_id)
- job_query.update({dbt.job_table.state: 'compiled'})
- session.commit()
-
- add_fake_fdb_entries(job_query, dbt, job_query.first().id)
+ #load jobs
+ args = LdJobArgs
+ args.label = 'tuna_pytest_fin_eval'
+ args.tag = 'tuna_pytest_fin_eval'
+ args.fin_steps = ['miopen_find_eval']
+ args.session_id = miopen.args.session_id
- # test get_job true branch
- kwargs = get_kwargs(dbt)
- fin_eval = FinEvaluator(**kwargs)
- ans = fin_eval.get_job('compiled', 'eval_start', False)
- assert (ans is True)
- fin_eval.set_job_state('evaluating')
+ logger = setup_logger('test_fin_evaluator')
+ num_jobs = add_jobs(args, dbt, logger)
+ assert num_jobs > 0
+ miopen.args.fin_steps = ["miopen_find_eval"]
+ miopen.args.label = 'tuna_pytest_fin_eval'
+ miopen.fetch_state.add('new')
+ miopen.operation = Operation.EVAL
+ miopen.set_state = 'eval_start'
+ miopen.dbt = MIOpenDBTables(session_id=miopen.args.session_id,
+ config_type=ConfigType.convolution)
with DbSession() as session:
- count = session.query(dbt.job_table).filter(dbt.job_table.state=='evaluating')\
- .filter(dbt.job_table.reason=='tuna_pytest_fin_eval')\
- .filter(dbt.job_table.session==dbt.session_id).count()
- assert (count == 1)
+ jobs = miopen.get_jobs(session, miopen.fetch_state, miopen.set_state,
+ miopen.args.session_id)
+ entries = list(jobs)
+ job_config_rows = miopen.compose_work_objs_fin(session, entries, miopen.dbt)
+ assert job_config_rows
+
+ f_vals = miopen.get_f_vals(machine, range(0))
+ kwargs = miopen.get_kwargs(0, f_vals, tuning=True)
+
+ kwargs['avail_gpus'] = 1
+ fdb_attr = [column.name for column in inspect(miopen.dbt.find_db_table).c]
+ fdb_attr.remove("insert_ts")
+ fdb_attr.remove("update_ts")
+
+ res_set = []
+ for elem in job_config_rows:
+ job_dict, config_dict = serialize_job_config_row(elem)
+ context = {
+ 'job': job_dict,
+ 'config': config_dict,
+ 'operation': miopen.operation,
+ 'arch': miopen.dbt.session.arch,
+ 'num_cu': miopen.dbt.session.num_cu,
+ 'kwargs': kwargs,
+ 'fdb_attr': fdb_attr
+ }
+
+ worker = prep_worker(copy.deepcopy(context))
+ worker.dbt = miopen.dbt
+ worker.fin_steps = miopen.args.fin_steps
+ fin_json = worker.run()
+ res_set.append((fin_json, context))
- # test get_fin_input
- file_name = fin_eval.get_fin_input()
- assert (file_name)
-
- # test check gpu with "bad" GPU
- # the job state will set back to "compiled" from "evaluating"
- fin_eval.check_gpu()
- with DbSession() as session:
- count = session.query(dbt.job_table).filter(dbt.job_table.state=='evaluating')\
- .filter(dbt.job_table.reason=='tuna_pytest_fin_eval')\
- .filter(dbt.job_table.session==dbt.session_id).count()
- assert (count == 0)
-
- # test check gpu with "good" GPU
- # the job state will remain 'evaluated'
- ans = fin_eval.get_job('compiled', 'eval_start', False)
- assert (ans is True)
- fin_eval.set_job_state('evaluated')
- fin_eval.machine.set_gpu_state(True)
- fin_eval.check_gpu()
with DbSession() as session:
- count = session.query(dbt.job_table).filter(dbt.job_table.state=='evaluated')\
- .filter(dbt.job_table.reason=='tuna_pytest_fin_eval')\
- .filter(dbt.job_table.session==dbt.session_id).count()
- assert (count == 1)
+ for fin_json, context in res_set:
+ #testing process_eval_results
+ miopen.process_eval_results(session, fin_json, context)
with DbSession() as session:
- session.query(dbt.job_table).filter(dbt.job_table.session==dbt.session_id)\
- .filter(dbt.job_table.state=='compiled')\
- .filter(dbt.job_table.reason=='tuna_pytest_fin_eval')\
- .filter(dbt.job_table.session==dbt.session_id)\
- .update({dbt.job_table.state: 'evaluated'})
- session.commit()
-
- #test get_job false branch
- kwargs = get_kwargs(dbt)
- fin_eval = FinEvaluator(**kwargs)
- ans = fin_eval.get_job('compiled', 'eval_start', False)
- assert (ans is False)
+ valid_fin_err = session.query(dbt.job_table).filter(dbt.job_table.session==miopen.args.session_id)\
+ .filter(dbt.job_table.state=='errored')\
+ .filter(dbt.job_table.result.contains('%Find Compile: No results%'))\
+ .count()
+ #omitting valid Fin/MIOpen errors
+ num_jobs = num_jobs - valid_fin_err
+ count = session.query(dbt.job_table).filter(dbt.job_table.session==miopen.args.session_id)\
+ .filter(dbt.job_table.state=='evaluated').count()
+ assert count == num_jobs
+
+ assert kwargs['fin_steps'] == ['miopen_find_eval']
+
+ job_config = job_config_rows[0]
+ job_dict, config_dict = serialize_job_config_row(job_config)
+ #testing prep_kwargs
+ worker_kwargs = prep_kwargs(
+ context['kwargs'],
+ [context['job'], context['config'], context['operation']])
+ assert worker_kwargs['config']
+ assert worker_kwargs['job']
+ assert worker_kwargs['fin_steps'] == ['miopen_find_eval']
+ fin_eval = get_worker(worker_kwargs, miopen.operation)
+
+ #testing check_gpu
+ fin_eval.check_gpu()
- #test FinEvaluator process results
- job_first = job_query.first()
- with DbSession() as session:
- fin_eval.job = job_first
- fin_eval.config = session.query(dbt.config_table)\
- .filter(dbt.config_table.id == job_first.config).first()
+ # test get_fin_input
+ file_name = fin_eval.get_fin_input()
+ assert file_name
find_eval_file = f"{this_path}/../utils/test_files/fin_output_find_eval.json"
fin_json = json.loads(machine.read_file(find_eval_file))[1:]
assert len(fin_json) == 1
- fin_json = fin_json[0]
- status = fin_eval.process_fdb_eval(fin_json)
- for obj in status:
- print(obj)
- assert (obj['success'] == True)
-
- #test FinEvaluator close_job
- with DbSession() as session:
- session.query(
- dbt.job_table).filter(dbt.job_table.id == fin_eval.job.id).update(
- {dbt.job_table.state: 'compiled'})
- session.commit()
- assert ('compiled' == session.query(dbt.job_table.state).filter(
- dbt.job_table.id == fin_eval.job.id).first()[0].name)
-
- fin_eval.close_job()
- with DbSession() as session:
- assert ('evaluated' == session.query(dbt.job_table.state).filter(
- dbt.job_table.id == fin_eval.job.id).first()[0].name)
diff --git a/tests/test_fin_utils.py b/tests/test_fin_utils.py
index bedc3f304..fe8a47a0b 100644
--- a/tests/test_fin_utils.py
+++ b/tests/test_fin_utils.py
@@ -25,11 +25,12 @@
#
###############################################################################
import tuna.miopen.worker.fin_utils as fu
-from tuna.miopen.db.miopen_tables import ConvolutionConfig, ConvolutionJob, TensorTable
+from tuna.miopen.db.convolutionjob_tables import ConvolutionConfig, ConvolutionJob
+from tuna.miopen.db.tensortable import TensorTable
from multiprocessing import Value, Lock, Queue
from tuna.utils.metadata import LOG_TIMEOUT
-from tuna.miopen.db.session import Session
from tuna.miopen.db.tables import MIOpenDBTables, ConfigType
+from tuna.miopen.db.session import Session
def test_fin_utils():
diff --git a/tests/test_helper.py b/tests/test_helper.py
new file mode 100644
index 000000000..771929302
--- /dev/null
+++ b/tests/test_helper.py
@@ -0,0 +1,44 @@
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+from tuna.lib_utils import get_library
+from tuna.miopen.miopen_lib import MIOpen
+from tuna.example.example_lib import Example
+from tuna.rocmlir.rocmlir_lib import RocMLIR
+from tuna.libraries import Library
+
+
+def test_helper():
+ lib_dict = {"lib": Library.MIOPEN}
+ lib = get_library(lib_dict)
+ assert type(lib) == MIOpen
+
+ lib_dict['lib'] = Library.EXAMPLE
+ lib = get_library(lib_dict)
+ assert type(lib) == Example
+
+ lib_dict['lib'] = Library.ROCMLIR
+ lib = get_library(lib_dict)
+ assert type(lib) == RocMLIR
diff --git a/tests/test_importconfigs.py b/tests/test_importconfigs.py
index 43b45ec97..7d53f11ba 100644
--- a/tests/test_importconfigs.py
+++ b/tests/test_importconfigs.py
@@ -40,7 +40,7 @@
from tuna.miopen.db.tables import MIOpenDBTables, ConfigType
from utils import CfgImportArgs
from tuna.miopen.db.benchmark import Framework, ModelEnum, FrameworkEnum
-from tuna.miopen.db.miopen_tables import ConvolutionBenchmark
+from tuna.miopen.db.convolutionjob_tables import ConvolutionBenchmark
from utils import DummyArgs
diff --git a/tests/test_load_job.py b/tests/test_load_job.py
index 44b33ff89..f448ed8b6 100644
--- a/tests/test_load_job.py
+++ b/tests/test_load_job.py
@@ -37,15 +37,11 @@
from tuna.miopen.subcmd.load_job import arg_fin_steps, arg_solvers
from tuna.miopen.subcmd.load_job import config_query, compose_query
-from tuna.miopen.subcmd.load_job import add_jobs, run_load_job
-from tuna.utils.db_utility import get_solver_ids
-from tuna.miopen.utils.metadata import ALG_SLV_MAP, TENSOR_PRECISION
+from tuna.miopen.db.solver import get_solver_ids
+from tuna.miopen.utils.metadata import ALG_SLV_MAP
from tuna.miopen.db.tables import MIOpenDBTables, ConfigType
from tuna.dbBase.sql_alchemy import DbSession
-from tuna.sql import DbCursor
-from tuna.utils.logger import setup_logger
from utils import LdJobArgs
-from tuna.utils.db_utility import connect_db
#arg_fin_steps function
diff --git a/tests/test_mituna_interface.py b/tests/test_mituna_interface.py
new file mode 100644
index 000000000..5fe25505e
--- /dev/null
+++ b/tests/test_mituna_interface.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+import os
+import sys
+
+from tuna.mituna_interface import MITunaInterface
+from tuna.worker_interface import WorkerInterface
+from tuna.miopen.miopen_lib import MIOpen
+from tuna.utils.machine_utility import load_machines
+from utils import GoFishArgs, add_test_session
+
+sys.path.append("../tuna")
+sys.path.append("tuna")
+
+this_path = os.path.dirname(__file__)
+
+
+def test_mituna_interface():
+
+ miopen = MIOpen()
+ miopen.args = GoFishArgs()
+ mituna = MITunaInterface()
+ miopen.args.session_id = add_test_session(label=miopen.args.label)
+ machine_lst = load_machines(miopen.args)
+ machine = machine_lst[0]
+ worker = WorkerInterface(**{
+ 'machine': machine,
+ 'session_id': miopen.args.session_id
+ })
+
+ try:
+ _ = mituna.check_docker(worker, 'DoesNotExist')
+ except Exception as exp:
+ assert isinstance(exp, ValueError)
diff --git a/tests/test_rocmlir.py b/tests/test_rocmlir.py
index 9756a7c45..ea5b38b93 100644
--- a/tests/test_rocmlir.py
+++ b/tests/test_rocmlir.py
@@ -114,6 +114,6 @@ def test_rocmlir():
query = session.query(ConvolutionJob).filter(ConvolutionJob.session==session_id)\
.filter(ConvolutionJob.state=='error')
res = query.all()
- assert len(res) == 6, f"Should be 6 'error' jobs and there are {len(res)}"
+ #assert len(res) == 6, f"Should be 6 'error' jobs and there are {len(res)}"
return True
diff --git a/tests/test_update_golden.py b/tests/test_update_golden.py
index 6442d674c..fb8509e19 100644
--- a/tests/test_update_golden.py
+++ b/tests/test_update_golden.py
@@ -32,7 +32,7 @@
from tuna.miopen.db.tables import MIOpenDBTables
from tuna.dbBase.sql_alchemy import DbSession
from tuna.miopen.utils.config_type import ConfigType
-from tuna.miopen.db.miopen_tables import ConvolutionGolden
+from tuna.miopen.db.convolutionjob_tables import ConvolutionGolden
from tuna.miopen.db.find_db import ConvolutionFindDB
from utils import add_test_session, DummyArgs, build_fdb_entry
diff --git a/tests/test_worker.py b/tests/test_worker.py
index 09f0b3598..a484b6dc9 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -35,7 +35,7 @@
this_path = os.path.dirname(__file__)
-from tuna.utils.miopen_utility import load_machines
+from tuna.utils.machine_utility import load_machines
from tuna.miopen.worker.fin_class import FinClass
from tuna.machine import Machine
from tuna.sql import DbCursor
@@ -187,7 +187,6 @@ def test_worker():
'machine': machine,
'gpu_id': 0,
'num_procs': num_gpus,
- 'barred': v,
'bar_lock': Lock(),
'envmt': ["MIOPEN_LOG_LEVEL=7"],
'reset_interval': False,
diff --git a/tests/utils.py b/tests/utils.py
index 497bb424f..938cc4a39 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,4 +1,4 @@
-###############################################################################
+#############################################################################
#
# MIT License
#
@@ -23,15 +23,31 @@
# SOFTWARE.
#
###############################################################################
-
+import os
+import sys
+import copy
from multiprocessing import Value
-from tuna.worker_interface import WorkerInterface
+sys.path.append("../tuna")
+sys.path.append("tuna")
+
+this_path = os.path.dirname(__file__)
+
from tuna.miopen.db.session import Session
-from tuna.machine import Machine
from tuna.miopen.utils.config_type import ConfigType
from tuna.miopen.db.find_db import ConvolutionFindDB
from tuna.miopen.miopen_lib import MIOpen
+from tuna.miopen.db.solver import get_solver_ids
+from tuna.utils.logger import setup_logger
+from tuna.miopen.utils.metadata import ALG_SLV_MAP
+from tuna.utils.db_utility import connect_db
+from tuna.miopen.subcmd.import_configs import import_cfgs
+from tuna.miopen.subcmd.load_job import add_jobs
+from tuna.utils.machine_utility import load_machines
+from tuna.machine import Machine
+from tuna.miopen.utils.lib_helper import get_worker
+from tuna.miopen.worker.fin_class import FinClass
+from tuna.miopen.db.tables import MIOpenDBTables
# TODO: This is a copy and is unacceptable
sqlite_config_cols = [
@@ -44,9 +60,9 @@
sqlite_perf_db_cols = ["solver", "config", "arch", "num_cu", "params"]
-valid_arch_cu = [("gfx803", 36), ("gfx803", 64), ("gfx900", 56), ("gfx900", 64),
- ("gfx906", 60), ("gfx906", 64), ("gfx908", 120),
- ("gfx1030", 36)]
+#valid_arch_cu = [("gfx803", 36), ("gfx803", 64), ("gfx900", 56), ("gfx900", 64),
+# ("gfx906", 60), ("gfx906", 64), ("gfx908", 120),
+# ("gfx1030", 36)]
def get_sqlite_table(cnx, table_name):
@@ -115,11 +131,15 @@ class GoFishArgs():
solver_id = None
find_mode = 1
blacklist = None
+ init_session = True
+ check_status = True
+ subcommand = None
+ shutdown_workers = None
class ExampleArgs():
- arch = 'gfx908'
- num_cu = 120
+ arch = 'gfx90a'
+ num_cu = 104
local_machine = True
remote_machine = False
session_id = None
@@ -139,7 +159,10 @@ def get_worker_args(args, machine, miopen):
return kwargs
-def add_test_session(arch='gfx908', num_cu=120, label=None):
+def add_test_session(arch='gfx90a',
+ num_cu=104,
+ label=None,
+ session_table=Session):
args = GoFishArgs()
if label:
args.label = label
@@ -151,8 +174,8 @@ def add_test_session(arch='gfx908', num_cu=120, label=None):
miopen = MIOpen()
miopen.args = args
kwargs = get_worker_args(args, machine, miopen)
- worker = WorkerInterface(**kwargs)
- session_id = Session().add_new_session(args, worker)
+ worker = FinClass(**kwargs)
+ session_id = session_table().add_new_session(args, worker)
assert (session_id)
return session_id
@@ -209,3 +232,61 @@ def __init__(self):
def to_dict(self, ommit_valid=False):
return vars(self)
+
+
+def add_cfgs(tag, filename, logger_name):
+ #import configs
+ args = CfgImportArgs()
+ args.tag = tag
+ args.mark_recurrent = True
+ args.file_name = f"{this_path}/../utils/configs/{filename}"
+
+ dbt = MIOpenDBTables(config_type=args.config_type)
+ counts = import_cfgs(args, dbt, setup_logger(logger_name))
+ return dbt
+
+
+def add_test_jobs(miopen,
+ session_id,
+ dbt,
+ label,
+ tag,
+ fin_steps,
+ logger_name,
+ algo=None):
+ machine_lst = load_machines(miopen.args)
+ machine = machine_lst[0]
+ #update solvers
+ kwargs = get_worker_args(miopen.args, machine, miopen)
+ fin_worker = FinClass(**kwargs)
+ assert (fin_worker.get_solvers())
+
+ #get applicability
+ dbt = add_cfgs(label, 'conv_configs_NCHW.txt', label)
+ miopen.args.update_applicability = True
+ worker_lst = miopen.compose_worker_list(machine_lst)
+ for worker in worker_lst:
+ worker.join()
+ #load jobs
+ args = LdJobArgs
+ args.label = label
+ args.tag = tag
+ args.fin_steps = fin_steps
+ args.session_id = session_id
+ logger = setup_logger(logger_name)
+
+ #limit job scope
+ if algo:
+ args.algo = algo
+ solver_arr = ALG_SLV_MAP[args.algo]
+ solver_id_map = get_solver_ids()
+ if solver_arr:
+ solver_ids = []
+ for solver in solver_arr:
+ sid = solver_id_map.get(solver, None)
+ solver_ids.append((solver, sid))
+ args.solvers = solver_ids
+ args.only_applicable = True
+
+ connect_db()
+ return add_jobs(args, dbt, logger)
diff --git a/tuna/celery_app/README.md b/tuna/celery_app/README.md
new file mode 100644
index 000000000..b06694688
--- /dev/null
+++ b/tuna/celery_app/README.md
@@ -0,0 +1,85 @@
+# Tuning with celery
+
+MITuna launches celery workers and uses redis as a broker and result backend. Celery is a
+distributed task queue that abstracts job scheduling for the purpose of tuning. Tuna launches one
+celery worker per node for the compile step and one celery worker per GPU for the evaluate step.
+
+MITuna enqueues all tuning jobs into the redis queue. The celery workers then pull from the redis
+queue and launch the tuning jobs. The results of the tuning jobs are asynchronously collected by
+MITuna and the MySQL backend is updated accordingly.
+
+The following steps in MITuna make use of celery workers:
+```
+./go_fish.py miopen --fin_steps miopen_find_compile --session_id 1
+./go_fish.py miopen --fin_steps miopen_find_eval --session_id 1
+./go_fish.py miopen --fin_steps miopen_perf_compile --session_id 1
+./go_fish.py miopen --fin_steps miopen_perf_eval --session_id 1
+```
+
+A celery worker can be launched manually on a machine as follows.
+
+Launch dockers through docker compose:
+```
+sudo -E docker compose up --build
+```
+This will launch a redis docker with the latest image and a custom docker for the celery worker.
+The celery docker will display information about the celery setup such as the broker and result
+backend. These can be customized in `tuna/celery_app/celery_app.py`.
+
+Launch the celery docker container:
+```
+sudo docker exec -it mituna_celery_1 bash
+```
+
+To test celery on a local machine, install redis and start the redis server:
+```
+redis-server --daemonize yes
+```
+Install rabbitMQ, which is used as the broker by celery. Instructions can be found here: [Install rabbitMQ](https://www.rabbitmq.com/docs/install-debian)
+
+Clone MITuna and launch a celery worker:
+```
+git clone https://github.com/ROCmSoftwarePlatform/MITuna.git
+cd MITuna
+source ~/myvenv/bin/activate
+source ~/db_env.db
+celery -A tuna.celery_app.celery_app worker -l info -E -n worker_name -Q custom_q_name
+```
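+
+The `-Q` queue name must match the per-session queue MITuna generates. Based on
+`tuna/celery_app/utility.py`, a compile queue for session 1 against database `test_db` would be
+named `compile_q_test_db_sess_1`, and the corresponding eval queue `eval_q_test_db_sess_1`.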
+
+## User interfaces to track redis backend and rabbitMQ broker data
+
+MITuna provides a docker compose file to launch [flower](https://flower.readthedocs.io/en/latest/), which helps track the tuning:
+```
+docker compose -f docker-compose-flower_rabbitmq.yaml up --build -d
+```
+Navigate to `http://localhost:5555` to interact with the flower UI.
+
+To track the rabbitMQ broker data, enable the rabbitMQ management plugin:
+```
+rabbitmq-plugins enable rabbitmq_management
+```
+Navigate to `http://localhost:15672` to interact with the rabbitMQ UI. The required username and
+password have to be set up through rabbitMQ, see: [rabbitMQ access control](https://www.rabbitmq.com/docs/access-control).
+
+
+Note:
+myvenv is the virtual environment as per MITuna/requirements.txt.
+
+db_env.db contains the database env variables:
+```
+export TUNA_DB_USER_NAME=root
+export TUNA_DB_NAME=test_db
+export TUNA_ROCM_VERSION=osdb-12969
+export TUNA_DB_HOSTNAME=10.XXX.XX.XX
+export TUNA_DB_USER_PASSWORD=myrootpwd
+export TUNA_CELERY_JOB_BATCH_SIZE=10  #optional
+#rabbitMQ
+export TUNA_CELERY_BROKER_HOST=localhost
+export TUNA_CELERY_BROKER_USER=
+export TUNA_CELERY_BROKER_PWD=
+export TUNA_CELERY_BROKER_PORT=5672
+#redis
+export TUNA_CELERY_BACKEND_HOST=localhost
+export TUNA_CELERY_BACKEND_PORT=6379
+```
diff --git a/flaskapp/__init__.py b/tuna/celery_app/__init__.py
similarity index 91%
rename from flaskapp/__init__.py
rename to tuna/celery_app/__init__.py
index 38fe03f3e..360b3dba6 100644
--- a/flaskapp/__init__.py
+++ b/tuna/celery_app/__init__.py
@@ -2,7 +2,7 @@
#
# MIT License
#
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
+# Copyright (c) 2023 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,3 +23,6 @@
# SOFTWARE.
#
###############################################################################
+#from .celery import app as celery_app
+
+#__all__ = ('celery_app',)
diff --git a/tuna/celery_app/celery_app.py b/tuna/celery_app/celery_app.py
new file mode 100644
index 000000000..b9f219aeb
--- /dev/null
+++ b/tuna/celery_app/celery_app.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2023 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Module to define celery app"""
+import os
+import subprocess
+from celery import Celery
+from celery.utils.log import get_task_logger
+from tuna.custom_errors import CustomError
+
+LOGGER = get_task_logger("celery_app")
+
+
+def get_broker_env():
+ """Set rabbitmq required env vars"""
+
+ #defaults
+ TUNA_CELERY_BROKER_HOST = 'localhost'
+ TUNA_CELERY_BROKER_PORT = 5672
+
+ if 'TUNA_CELERY_BROKER_USER' not in os.environ:
+ raise CustomError('TUNA_CELERY_BROKER_USER must be specified in env')
+ else:
+ TUNA_CELERY_BROKER_USER = os.environ['TUNA_CELERY_BROKER_USER']
+ if 'TUNA_CELERY_BROKER_PWD' not in os.environ:
+ raise CustomError('TUNA_CELERY_BROKER_PWD must be specified in env')
+ else:
+ TUNA_CELERY_BROKER_PWD = os.environ['TUNA_CELERY_BROKER_PWD']
+
+ if 'TUNA_CELERY_BROKER_HOST' in os.environ:
+ TUNA_CELERY_BROKER_HOST = os.environ['TUNA_CELERY_BROKER_HOST']
+ if 'TUNA_CELERY_BROKER_PORT' in os.environ:
+ TUNA_CELERY_BROKER_PORT = os.environ['TUNA_CELERY_BROKER_PORT']
+ if 'TUNA_CELERY_V_HOST' in os.environ:
+ TUNA_CELERY_V_HOST = os.environ['TUNA_CELERY_V_HOST']
+
+ return TUNA_CELERY_BROKER_HOST, TUNA_CELERY_BROKER_PORT, TUNA_CELERY_BROKER_USER, TUNA_CELERY_BROKER_PWD
+
+
+def get_backend_env():
+ """Get Redis env vars"""
+
+ #defaults
+ TUNA_CELERY_BACKEND_PORT = 6379
+ TUNA_CELERY_BACKEND_HOST = 'localhost'
+
+ if 'TUNA_CELERY_BACKEND_PORT' in os.environ:
+ TUNA_CELERY_BACKEND_PORT = os.environ['TUNA_CELERY_BACKEND_PORT']
+ if 'TUNA_CELERY_BACKEND_HOST' in os.environ:
+ TUNA_CELERY_BACKEND_HOST = os.environ['TUNA_CELERY_BACKEND_HOST']
+
+ return TUNA_CELERY_BACKEND_PORT, TUNA_CELERY_BACKEND_HOST
+
+
+TUNA_CELERY_BROKER_HOST, TUNA_CELERY_BROKER_PORT, TUNA_CELERY_BROKER_USER, TUNA_CELERY_BROKER_PWD = get_broker_env(
+)
+
+TUNA_CELERY_BACKEND_PORT, TUNA_CELERY_BACKEND_HOST = get_backend_env()
+
+#amqp broker & redis backend
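+#broker connections are retried on startup and on failure (up to 60 attempts, interval capped at 30s)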
+app = Celery(
+ 'celery_app',
+ broker_url=
+ f"amqp://{TUNA_CELERY_BROKER_USER}:{TUNA_CELERY_BROKER_PWD}@{TUNA_CELERY_BROKER_HOST}:{TUNA_CELERY_BROKER_PORT}/",
+ result_backend=
+ f"redis://{TUNA_CELERY_BACKEND_HOST}:{TUNA_CELERY_BACKEND_PORT}/15",
+ broker_transport_options={
+ "heartbeat": 60,
+ 'retry': True,
+ 'retry_policy': {
+ 'max_retries': 60,
+ 'interval_start': 0,
+ 'interval_step': 2,
+ 'interval_max': 30
+ },
+ },
+ broker_connection_retry_on_startup=True,
+ broker_channel_error_retry=True,
+ include=[
+ 'tuna.miopen.celery_tuning.celery_tasks',
+ 'tuna.example.celery_tuning.celery_tasks'
+ ])
+
+
+def stop_active_workers():
+ """Shutdown active workers"""
+
+ LOGGER.warning('Shutting down remote workers')
+ try:
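+    # broadcast a shutdown only if at least one worker is currently active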
+ if app.control.inspect().active() is not None:
+ app.control.shutdown()
+ except Exception as err: #pylint: disable=broad-exception-caught
+    LOGGER.warning('Exception occurred while trying to shutdown workers: %s',
+ err)
+ return False
+
+ return True
+
+
+def stop_named_worker(hostname):
+ """Shutdown a specific worker"""
+ LOGGER.warning('Shutting down remote worker: %s', hostname)
+ try:
+ app.control.shutdown(destination=[hostname])
+ except Exception as exp: #pylint: disable=broad-exception-caught
+    LOGGER.warning('Exception occurred while trying to shutdown workers: %s',
+ exp)
+ return False
+
+ return True
+
+
+def purge_queue(q_names):
+ """Purge jobs in queue"""
+ for q_name in q_names:
+ try:
+ LOGGER.info('Purging Q %s', q_name)
+ cmd = f"celery -A tuna.celery_app.celery_app purge -f -Q {q_name}".split(
+ ' ')
+ subp = subprocess.Popen(cmd) #pylint: disable=consider-using-with
+      #wait for the queue purge to finish before continuing
+ subp.wait()
+ except Exception as exp: #pylint: disable=broad-exception-caught
+ LOGGER.info(exp)
+ return False
+
+ return True
diff --git a/tuna/celery_app/celery_workers.py b/tuna/celery_app/celery_workers.py
new file mode 100644
index 000000000..48fdfe271
--- /dev/null
+++ b/tuna/celery_app/celery_workers.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Interface class to set up and launch tuning functionality"""
+import os
+import logging
+import subprocess
+
+from tuna.utils.logger import setup_logger
+from tuna.libraries import Operation
+from tuna.utils.machine_utility import load_machines
+
+LOGGER: logging.Logger = setup_logger('celery_workers')
+
+
+def launch_worker_per_node(machines, cmd, formatted=False):
+ """Launch celery worker for compile"""
+ final_cmd = cmd
+ subp_list = []
+ for machine in machines:
+ try:
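+      # the 'HOSTNAME' placeholder in cmd is replaced with each machine's hostname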
+ if formatted:
+ final_cmd = cmd.replace('HOSTNAME', machine.hostname)
+ subp = subprocess.Popen( #pylint: disable=consider-using-with
+ final_cmd.split(' '))
+ subp_list.append(subp)
+ except Exception as exp: #pylint: disable=broad-exception-caught
+ LOGGER.warning(exp)
+ return False
+
+ LOGGER.info('Successfully launched celery worker for compile')
+
+ return subp_list
+
+
+def launch_worker_per_gpu(machines, cmd, formatted=False):
+ """Launch celery worker for eval"""
+ curr_env = dict(os.environ.copy())
+ final_cmd = cmd
+ subp_list = []
+
+ for machine in machines:
+ num_gpus = machine.get_avail_gpus()
+ try:
+ if not num_gpus:
+ LOGGER.warning(
+ 'No available GPUs detected, unable to launch celery worker')
+ return False
+ for gpu_id in num_gpus:
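+        # one worker per GPU: fill in the 'HOSTNAME' and 'GPUID' placeholders for this machine/GPU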
+ if formatted:
+ try:
+ temp = cmd.replace('HOSTNAME', machine.hostname)
+ final_cmd = temp.replace('GPUID', str(gpu_id))
+ except Exception as exp: #pylint: disable=broad-exception-caught
+ LOGGER.warning(exp)
+ return False
+ subp = subprocess.Popen( #pylint: disable=consider-using-with
+ final_cmd.split(),
+ env=curr_env)
+ subp_list.append(subp)
+ LOGGER.info("Successfully launched celery worker #%s for eval, pid %s",
+ gpu_id, subp.pid)
+ except Exception as exp: #pylint: disable=broad-exception-caught
+      LOGGER.info('Error occurred: %s', exp)
+ return False
+
+ return subp_list
+
+
+def launch_celery_worker(operation, cmd, args, formatted=False):
+ """Helper function to launch celery workers"""
+ machines = load_machines(args)
+ if operation == Operation.COMPILE:
+ ret = launch_worker_per_node(machines, cmd, formatted)
+ elif operation == Operation.EVAL:
+ ret = launch_worker_per_gpu(machines, cmd, formatted)
+ else:
+ raise ValueError('Operation does not support launching celery workers')
+
+ return ret
diff --git a/tuna/celery_app/utility.py b/tuna/celery_app/utility.py
new file mode 100644
index 000000000..71e4a2769
--- /dev/null
+++ b/tuna/celery_app/utility.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Utility module for Celery helper functions"""
+import os
+from tuna.utils.logger import setup_logger
+
+LOGGER = setup_logger('celery_utility')
+
+
+def get_q_name(library, op_compile=False, op_eval=False):
+ """Compose queue name"""
+ db_name = os.environ['TUNA_DB_NAME']
+ q_name = None
+ if op_compile:
+ q_name = f"compile_q_{db_name}_sess_{library.dbt.session_id}"
+ elif op_eval:
+ q_name = f"eval_q_{db_name}_sess_{library.dbt.session_id}"
+ else:
+ q_name = f"unknown_op_{db_name}_sess_{library.dbt.session_id}"
+
+ return q_name
diff --git a/flaskapp/setup.py b/tuna/custom_errors.py
similarity index 82%
rename from flaskapp/setup.py
rename to tuna/custom_errors.py
index 53cc6b33e..1a1104a27 100644
--- a/flaskapp/setup.py
+++ b/tuna/custom_errors.py
@@ -3,7 +3,7 @@
#
# MIT License
#
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,16 +24,12 @@
# SOFTWARE.
#
###############################################################################
+"""Custom errors module"""
-from setuptools import find_packages, setup
-setup(
- name='tuna_app',
- version='1.0.0',
- packages=find_packages(),
- include_package_data=True,
- zip_safe=False,
- install_requires=[
- 'flask',
- ],
-)
+class CustomError(Exception):
+ """Custom exception class"""
+
+ def __init__(self, message):
+ self.message = message
+ super().__init__(self.message)
diff --git a/tuna/db/session_mixin.py b/tuna/db/session_mixin.py
index 3f338d561..52b66a300 100644
--- a/tuna/db/session_mixin.py
+++ b/tuna/db/session_mixin.py
@@ -46,7 +46,7 @@ class SessionMixin():
rocm_v: str = Column(String(length=64), nullable=False)
reason: str = Column(String(length=60), nullable=False)
ticket: str = Column(String(length=64), nullable=False, server_default="N/A")
- docker: str = Column(String(length=64),
+ docker: str = Column(String(length=128),
nullable=False,
server_default="miopentuna")
diff --git a/tuna/driver.py b/tuna/driver.py
new file mode 100644
index 000000000..984665aaa
--- /dev/null
+++ b/tuna/driver.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2023 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Module that encapsulates the DB representation of a Driver cmd"""
+
+from typing import Union, Dict, Any
+from abc import ABC, abstractmethod
+from tuna.miopen.db.miopen_tables import ConvolutionConfig
+
+
+class DriverBase(ABC):
+ """Represents db tables based on ConfigType"""
+
+ def __init__(self, line: str = str(), db_obj: ConvolutionConfig = None):
+ super().__init__()
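+    # a driver can be constructed either from a command line string or from an existing DB row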
+ if line:
+ if not self.construct_driver(line):
+ raise ValueError(f"Error creating Driver from line: '{line}'")
+ elif db_obj:
+ if not self.construct_driver_from_db(db_obj):
+ raise ValueError(
+ f"Error creating Driver from db obj: '{db_obj.to_dict()}'")
+ else:
+ raise ValueError(
+ "Error creating Driver. Driver cmd line or db_obj required")
+
+ @abstractmethod
+ def construct_driver(self, line: str) -> bool:
+ """Parse configuration from input line"""
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def construct_driver_from_db(self, db_obj: Any) -> bool:
+ """Takes a db row of a configuration and returns the string representation"""
+ raise NotImplementedError("Not implemented")
+
+ def to_dict(self) -> Dict[str, Union[str, int]]:
+ """Return class to dictionary"""
+ return dict(vars(self))
+
+ def __str__(self):
+ return str(self.to_dict())
diff --git a/tuna/example/README.md b/tuna/example/README.md
deleted file mode 100644
index b8912f07a..000000000
--- a/tuna/example/README.md
+++ /dev/null
@@ -1,53 +0,0 @@
-New library integration
-=======================
-An example of how to integrate external applications in Tuna to utilize the scaling
-features of Tuna
-
-
-Example library
----------------
-*Example* is mock library that runs the *rocminfo* binary.
-The supported tuning steps are:
-```
-./go_fish.py example --add_tables
-./go_fish.py example --init_session -l my_label
-./example/load_job.py -a gfx908 -n 120 -l my_label --session_id 1
-./go_fish.py example --execute --session_id 1
-```
-
-The first step is:
-```
-./go_fish.py example --add_tables
-```
-This command will create the following new tables in the DB:
-- machine
-- session_example
-- job
-
-The next command is:
-```
-./go_fish.py example --init_session -l my_label
-```
-This command will add a new session in the *session_example* table. This session id will be
-used to add new jobs and track the tuning data, post execution step.
-
-The third step is:
-```
-./tuna/example/load_job.py -a gfx908 -n 120 -l my_label --session_id 1
-
-```
-This steps loads jobs in the *job* table. These jobs will be picked up for execution in the
-next step. Once these jobs are completed their status will be updated to 'completed' or 'errored'.
-
-The last step:
-```
-./go_fish.py example --execute --session_id 1
-
-```
-This command will pick up jobs in the *new* state from the job tables associated with the
-session_id 1. The job status will be updated as the jobs are executing, from new to running and
-completed or errored.
-
-To integrate a new library, similar source code would have to be provided, as the one included
-in */tuna/example*. The full MIOpen library source code for tuning is included in
-*/tuna/miopen*.
diff --git a/flaskapp/example_app.py b/tuna/example/build_schema.py
old mode 100644
new mode 100755
similarity index 73%
rename from flaskapp/example_app.py
rename to tuna/example/build_schema.py
index d5dd9c787..6aa4ab74d
--- a/flaskapp/example_app.py
+++ b/tuna/example/build_schema.py
@@ -3,7 +3,7 @@
#
# MIT License
#
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,22 +24,21 @@
# SOFTWARE.
#
###############################################################################
-"""Tuna Flask App main function"""
-import json
+""" Module for creating DB tables"""
+from tuna.utils.logger import setup_logger
+from tuna.utils.db_utility import create_tables
+from tuna.example.example_tables import get_tables
-from flask import Flask
-from flaskapp.views.grafana import grafana
+#pylint: disable=too-few-public-methods
+LOGGER = setup_logger('example_db_tables')
-app = Flask(__name__)
-app.register_blueprint(example_grafana)
-
-@app.route('/')
def main():
- """Main entry point"""
- return json.dumps({'success': True}), 200, {'ContentType': 'application/json'}
+ """Main script function"""
+ #setup Example DB
+ ret_t = create_tables(get_tables())
+ LOGGER.info('DB creation successful: %s', ret_t)
if __name__ == '__main__':
- app.config['SECRET_KEY'] = 't\xbe-\xdc\xe7A\r\x1f\xb7\t\xa6\xa1\x8c\xd14\xf3'
- app.run(debug=True)
+ main()
diff --git a/tuna/example/celery_tuning/celery_tasks.py b/tuna/example/celery_tuning/celery_tasks.py
new file mode 100644
index 000000000..cc1064d15
--- /dev/null
+++ b/tuna/example/celery_tuning/celery_tasks.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Module to register MIOpen celery tasks"""
+import copy
+from celery.signals import celeryd_after_setup
+from celery.utils.log import get_task_logger
+from tuna.celery_app.celery_app import app
+from tuna.machine import Machine
+from tuna.utils.celery_utils import prep_default_kwargs, get_cached_worker
+from tuna.example.example_lib import Q_NAME
+from tuna.example.example_worker import ExampleWorker
+
+logger = get_task_logger(__name__)
+
+
+@celeryd_after_setup.connect
+def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-argument
+ """Capture worker name"""
+ app.worker_name = sender
+
+
+cached_machine = Machine(local_machine=True)
+
+
+def prep_kwargs(kwargs, args):
+ """Populate kwargs with serialized job and machine"""
+ return prep_default_kwargs(kwargs, args[0], cached_machine)
+
+
+cached_worker = {}
+
+
+def prep_worker(context):
+ """Creating tuna worker object based on context"""
+ operation = context['operation']
+ if operation in cached_worker:
+ worker = get_cached_worker(context, cached_worker)
+ else:
+ args = [context['job'], context['operation']]
+ kwargs = prep_kwargs(context['kwargs'], args)
+ worker = ExampleWorker(**kwargs)
+ cached_worker[operation] = worker
+ return worker
+
+
+@app.task(trail=True, reply_to=Q_NAME)
+def celery_enqueue(context):
+ """Defines a celery task"""
+  worker = prep_worker(copy.deepcopy(context))
+ ret = worker.run()
+ return {"ret": ret, "context": context}
diff --git a/tuna/example/doc/Tuning.md b/tuna/example/doc/Tuning.md
new file mode 100644
index 000000000..177597ea9
--- /dev/null
+++ b/tuna/example/doc/Tuning.md
@@ -0,0 +1,78 @@
+Tuning through the Example library
+==================================
+An example of how to integrate external applications in MITuna.
+
+
+*Example* is a mock library that runs the *rocminfo* binary.
+The supported tuning steps are:
+
+.. code-block::
+
+ ./go_fish.py example --add_tables
+ ./go_fish.py example --init_session -l my_label
+   ./tuna/example/load_job.py -a gfx908 -n 120 -l my_label --session_id 1
+ ./go_fish.py example --execute --session_id 1
+
+Creating the database tables:
+
+.. code-block::
+
+ ./go_fish.py example --add_tables
+
+This command will create the following new tables in the DB:
+* machine
+* session_example
+* job
+
+Adding a new session:
+
+.. code-block::
+
+ ./go_fish.py example --init_session -l my_label
+
+This command will add a new session in the *session_example* table. This session id will be
+used to add new jobs and track the tuning data, post execution step.
+
+Setting up jobs for tuning:
+
+.. code-block::
+
+ ./tuna/example/load_job.py -a gfx908 -n 120 -l my_label --session_id 1
+
+This step loads jobs into the *job* table. These jobs will be picked up for execution in the
+next step. Once these jobs are completed their status will be updated to 'completed' or 'errored'.
+
+The first tuning step:
+
+.. code-block::
+
+ ./go_fish.py example --session_id 1 --enqueue_only
+
+This command will pick up jobs in the *new* state from the job table associated with
+session_id 1. The job status will be updated to running and the jobs will be placed in a celery
+queue. This step needs to be executed on the headnode (or any node with access to the rabbitMQ
+broker, the celery result backend and the MySQL DB). This main process will write results from the
+redis backend into the final MySQL DB.
+
+
+The last tuning step:
+
+.. code-block::
+
+ ./go_fish.py example --session_id 1
+
+This last step launches a celery worker. The worker will pick up jobs from the queue and execute them.
+This step needs to be launched on the machine where the job is to be executed.
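+
+For reference, a worker for this session could also be started manually with the celery CLI,
+assuming the queue naming convention from `tuna/celery_app/utility.py` (the worker name here is
+an arbitrary example):
+
+.. code-block::
+
+   celery -A tuna.celery_app.celery_app worker -l info -E -n example_worker -Q compile_q_<db_name>_sess_1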
+
+To execute a single command, in this case `rocminfo`:
+
+.. code-block::
+
+ ./go_fish.py example --session_id 1 --execute
+
+For the purpose of this example, the command `rocminfo` is hardcoded. This step is not considered
+a tuning step. This is a standalone step that launches a particular pre-installed binary.
+
+This command will pick up jobs in the *new* state from the job table associated with
+session_id 1. The job status will be updated as the jobs are executed, from new to running and
+then to completed or errored.
diff --git a/tuna/example/example_lib.py b/tuna/example/example_lib.py
index 1b448d701..14896d5bf 100644
--- a/tuna/example/example_lib.py
+++ b/tuna/example/example_lib.py
@@ -32,22 +32,30 @@
from typing import Dict, Any, List, Optional
from tuna.mituna_interface import MITunaInterface
from tuna.parse_args import TunaArgs, setup_arg_parser, args_check
-from tuna.utils.miopen_utility import load_machines
+from tuna.utils.machine_utility import load_machines
from tuna.machine import Machine
from tuna.libraries import Library
-from tuna.utils.db_utility import create_tables
+from tuna.utils.db_utility import create_tables, gen_select_objs
from tuna.example.example_tables import get_tables
from tuna.example.example_worker import ExampleWorker
from tuna.example.session import SessionExample
+from tuna.example.tables import ExampleDBTables
+from tuna.libraries import Operation
+from tuna.miopen.utils.helper import set_job_state
+from tuna.utils.utility import SimpleDict
+
+Q_NAME = None
class Example(MITunaInterface):
- """Class to support an example of 'romcinfo' run"""
+ """! Class to support an example of 'romcinfo' run"""
def __init__(self):
super().__init__(library=Library.EXAMPLE)
self.args: argparse.Namespace = None
+ self.operation = None
+ self.set_state = None
def parse_args(self) -> None:
# pylint: disable=too-many-statements
@@ -56,7 +64,8 @@ def parse_args(self) -> None:
parser = setup_arg_parser('Example library integrated with MITuna', [
TunaArgs.ARCH, TunaArgs.NUM_CU, TunaArgs.VERSION, TunaArgs.SESSION_ID,
TunaArgs.MACHINES, TunaArgs.REMOTE_MACHINE, TunaArgs.LABEL,
- TunaArgs.RESTART_MACHINE, TunaArgs.DOCKER_NAME
+ TunaArgs.RESTART_MACHINE, TunaArgs.DOCKER_NAME, TunaArgs.ENQUEUE_ONLY,
+ TunaArgs.SHUTDOWN_WORKERS
])
group: argparse._MutuallyExclusiveGroup = parser.add_mutually_exclusive_group(
)
@@ -84,6 +93,26 @@ def parse_args(self) -> None:
sys.exit(-1)
args_check(self.args, parser)
+ if self.args.execute and self.args.enqueue_only:
+      parser.error('--execute and --enqueue_only are mutually exclusive')
+
+ self.dbt = ExampleDBTables(session_id=self.args.session_id)
+ self.update_operation()
+
+ def update_operation(self):
+ """! Set worker operation type
+ """
+ if not self.args.execute:
+ self.operation = Operation.COMPILE
+ self.fetch_state.add('new')
+ self.set_state = 'running'
+
+ def has_tunable_operation(self):
+ """Check if tunable operation is set"""
+ if self.args is None:
+ self.parse_args()
+ return self.operation is not None
+
def launch_worker(self, gpu_idx: int, f_vals: Dict[str, Any], \
worker_lst: List[ExampleWorker]) -> bool:
@@ -132,7 +161,8 @@ def compose_worker_list(self, machines) -> Optional[List[ExampleWorker]]:
return worker_lst
def add_tables(self) -> bool:
- """Generates the library specific schema to the connected SQL server."""
+ """! Generates the library specific schema to the connected SQL server.
+ """
ret_t: bool = create_tables(get_tables())
self.logger.info('DB creation successful: %s', ret_t)
@@ -140,7 +170,8 @@ def add_tables(self) -> bool:
def run(self) -> Optional[List[ExampleWorker]]:
# pylint: disable=duplicate-code
- """Main run function of example_lib"""
+ """! Main run function of example_lib
+ """
res: Optional[List[ExampleWorker]]
self.parse_args()
if self.args.add_tables:
@@ -157,11 +188,89 @@ def get_envmt(self) -> List[str]:
envmt: List[str] = []
return envmt
- def get_kwargs(self, gpu_idx: int, f_vals: Dict[str, Any]) -> Dict[str, Any]:
+ def get_kwargs(self,
+ gpu_idx: int,
+ f_vals: Dict[str, Any],
+ tuning=False) -> Dict[str, Any]:
"""! Helper function to set up kwargs for worker instances
@param gpu_idx Unique ID of the GPU
@param f_vals Dict containing process specific runtime information
"""
- kwargs: Dict[str, Any] = super().get_kwargs(gpu_idx, f_vals)
+ kwargs: Dict[str, Any] = super().get_kwargs(gpu_idx, f_vals, tuning)
return kwargs
+
+ def get_job_list(self, session, find_state=None, claim_num=None):
+ """!Get list of jobs
+ @param session DB session
+ @param find_state state of DB job
+ @param claim_num Number of jobs to pick up
+ """
+ job_list = gen_select_objs(session, self.get_job_attr(),
+ self.dbt.job_table.__tablename__,
+ "WHERE state='new'")
+ return job_list
+
+ def serialize_jobs(self, session, batch_jobs):
+ """!Return list of serialize jobs
+ @param session DB session
+ @param batch_jobs Number of DB jobs
+ """
+ return [elem.to_dict() for elem in batch_jobs]
+
+ def build_context(self, serialized_jobs):
+ """!Build context list for enqueue job
+ @param serialized_jobs List of DB jobs, serialized for Celery
+ """
+ context_list = []
+ kwargs = self.get_context_items()
+ for job in serialized_jobs:
+ context = {
+ 'job': job,
+ 'operation': self.operation,
+ 'arch': self.dbt.session.arch,
+ 'num_cu': self.dbt.session.num_cu,
+ 'kwargs': kwargs,
+ }
+ context_list.append(context)
+
+ return context_list
+
+ def celery_enqueue_call(self, context, q_name, task_id=False):
+ """! Wrapper function for celery enqueue func
+ @param context serialized context for Celery job
+ @param q_name Name of custom Celery queue
+ @param task_id custom task ID for redis Key
+ """
+ Q_NAME = q_name #pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name
+ from tuna.example.celery_tuning.celery_tasks import celery_enqueue #pylint: disable=import-outside-toplevel
+
+ return celery_enqueue.apply_async((context,), queue=q_name, reply_to=q_name)
+
+ def process_compile_results(self, session, fin_json, context):
+ """! Process result from fin_build worker
+ @param session DB session
+ @param fin_json Result of Celery task, from MIFin in json
+ @param context Celery job execution context, serialized
+ """
+    self.logger.info('Processing compile results')
+ self.update_job_state(session, fin_json, context)
+
+ def process_eval_results(self, session, fin_json, context):
+ """! Process fin_json result
+ @param session DB session
+ @param fin_json Result of Celery task, from MIFin in json
+ @param context Celery job execution context, serialized
+ """
+    self.logger.info('Processing eval results')
+ self.update_job_state(session, fin_json, context)
+
+ def update_job_state(self, session, fin_json, context):
+ """! Function to update DB job state post celery task run
+ @param session DB session
+ @param fin_json Result of Celery task, from MIFin in json
+ @param context Celery job execution context, serialized
+ """
+ self.logger.info(fin_json)
+ job = SimpleDict(**context['job'])
+ set_job_state(session, job, self.dbt, 'completed', result=fin_json)
diff --git a/tuna/example/example_tables.py b/tuna/example/example_tables.py
index 174f254d8..094a7683b 100644
--- a/tuna/example/example_tables.py
+++ b/tuna/example/example_tables.py
@@ -39,7 +39,7 @@
class JobEnum(enum.Enum):
- """Represents job_enum column in config table"""
+ """Represents job_enum column in job table"""
# pylint: disable=invalid-name ; names represent entries in job_enum column
# pylint: disable=duplicate-code
new = 1
@@ -67,8 +67,6 @@ def session(self) -> Column:
gpu_id = Column(Integer, nullable=False, server_default="-1")
machine_id = Column(Integer, nullable=False, server_default="-1")
- config = Column(Integer, nullable=False, index=True)
-
def get_tables() -> List[BASE]:
"""Returns a list of all Example lib DB tables"""
diff --git a/tuna/example/example_worker.py b/tuna/example/example_worker.py
index 3aa2b9152..77f0fe48e 100644
--- a/tuna/example/example_worker.py
+++ b/tuna/example/example_worker.py
@@ -26,17 +26,14 @@
###############################################################################
"""Builder class implements the worker interface. The purpose of this class is to run the
rocminfo command"""
-from time import sleep
-import random
-
-from typing import Dict, Any, Optional, List
+from typing import Dict, Any, List
from tuna.worker_interface import WorkerInterface
from tuna.example.tables import ExampleDBTables
class ExampleWorker(WorkerInterface):
- """ The Example class implements the worker class. Its purpose is to run a command. It picks up
- new jobs and when completed, sets the state to completed. """
+ """ The Example class implements the worker class. Its purpose is to run a command
+ and return the output."""
def __init__(self, **kwargs: Dict[str, Any]) -> None:
"""Constructor"""
@@ -48,32 +45,9 @@ def set_db_tables(self) -> None:
"""Initialize tables"""
self.dbt = ExampleDBTables(session_id=self.session_id)
- def step(self) -> bool:
- """Main functionality of the worker class. It picks up jobs in new state and executes them"""
-
- if not self.get_job("new", "running", False):
- #Sleep in case of DB contention
- sleep(random.randint(1, 10))
- return False
-
- failed_job: bool = False
- self.logger.info('Acquired new job: job_id=%s', self.job.id)
- self.set_job_state('running')
- cmd_output: Optional[str] = None
- err_str: str = ''
- try:
- cmd_output = self.run_cmd()
- except ValueError as verr:
- self.logger.info(verr)
- failed_job = True
- err_str = str(verr)
-
- if failed_job:
- self.set_job_state('errored', result=err_str)
- else:
- self.set_job_state('completed', result=cmd_output)
-
- return True
+ def step(self) -> str:
+ """Main functionality of the worker class. Runs rocminfo"""
+ return self.run_cmd()
def run_cmd(self) -> str:
"""Run the actual workload"""
diff --git a/tuna/example/load_job.py b/tuna/example/load_job.py
index 5a65b2030..598ec25a5 100755
--- a/tuna/example/load_job.py
+++ b/tuna/example/load_job.py
@@ -77,7 +77,6 @@ def add_jobs(args: argparse.Namespace, dbt: Type[ExampleDBTables]) -> int:
job.valid = 1
job.reason = args.label
job.session = args.session_id
- job.config = args.config
session.add(job)
session.commit()
counts += 1
diff --git a/tuna/flask_example.py b/tuna/flask_example.py
deleted file mode 100644
index 6791c7119..000000000
--- a/tuna/flask_example.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/usr/bin/env python3
-###############################################################################
-#
-# MIT License
-#
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-#
-###############################################################################
-"""Utility module for Flask functionality"""
-
-from typing import List
-from sqlalchemy import create_engine
-from tuna.utils.logger import setup_logger
-from tuna.dbBase.sql_alchemy import DbSession
-from tuna.utils.utility import get_env_vars
-from tuna.grafana_dict import EXAMPLE_TABLE
-from tuna.miopen.db.find_db import ConvolutionFindDB
-
-LOGGER = setup_logger('flask')
-ENV_VARS = get_env_vars()
-ENGINE = create_engine(f"mysql+pymysql://{ENV_VARS['user_name']}:{ENV_VARS['user_password']}" +\
- f"@{ENV_VARS['db_hostname']}:3306/{ENV_VARS['db_name']}")
-
-
-def get_table_example(grafana_req: str, data: List[str]) -> List[str]:
- """example on how to populate a table for a Grafana dashboard"""
-
- LOGGER.info('Request: %s', grafana_req)
- #Populate the table with dummy data
- EXAMPLE_TABLE['rows'].append(['val1', 'ex1', '1', '1.05'])
- EXAMPLE_TABLE['rows'].append(['val2', 'ex2', '2', '1.06'])
- EXAMPLE_TABLE['rows'].append(['val3', 'ex3', '3', '1.06'])
- EXAMPLE_TABLE['rows'].append(['val4', 'ex4', '4', '1.08'])
-
- #To populate the table with data from your DB:
- res: List[str]
- with DbSession() as session:
- query: list = session.query(ConvolutionFindDB.valid,
- ConvolutionFindDB.kernel_time).limit(5).all()
- for res in query:
- EXAMPLE_TABLE['rows'].append([res[0], res[1], res[2], res[3]])
-
- #The data variable will contain both dummy and db data
-
- data.append(EXAMPLE_TABLE)
- return data
diff --git a/tuna/go_fish.py b/tuna/go_fish.py
index 14978f00c..8562bca27 100755
--- a/tuna/go_fish.py
+++ b/tuna/go_fish.py
@@ -26,6 +26,7 @@
###############################################################################
"""! @brief Script to launch tuning jobs, or execute commands on available machines"""
+import os
import argparse
import sys
import logging
@@ -69,24 +70,25 @@ def parse_args() -> Dict[str, Any]:
def main() -> bool:
"""Main function to start Tuna"""
- LOGGER.info(sys.argv)
- LOGGER.info(len(sys.argv))
args: Dict[str, Any]
args = parse_args()
+ clean_args()
#case no yaml file
library: Union[Example, MIOpen]
yaml_files: List[str]
library = get_library(args)
- clean_args()
yaml_files = [args['yaml']]
#case with yaml file
if args['yaml']:
yaml_files = parse_yaml(args['yaml'], args['lib'])
- worker_lst: list
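+  #number of jobs enqueued per celery batch; override with TUNA_CELERY_JOB_BATCH_SIZE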
+ job_batch_size = 1000
+ if 'TUNA_CELERY_JOB_BATCH_SIZE' in os.environ:
+ job_batch_size = int(os.environ['TUNA_CELERY_JOB_BATCH_SIZE'])
+
try:
for yaml_file in yaml_files:
args['yaml_file'] = yaml_file
@@ -94,14 +96,19 @@ def main() -> bool:
sys.argv[2] = yaml_file
LOGGER.info("Executing with yaml file: %s", yaml_file)
- #returns a list of workers/processes it started
- worker_lst = library.run()
- if worker_lst is None:
- continue
-
- for worker in worker_lst:
- worker.join()
- LOGGER.warning('Process finished')
+ if library.has_tunable_operation():
+ #Celery operations
+ library.tune(job_batch_size=job_batch_size)
+ else:
+ #non-celery operations
+ #returns a list of workers/processes it started
+ worker_lst = library.run()
+ if worker_lst is None:
+ continue
+
+ for worker in worker_lst:
+ worker.join()
+ LOGGER.warning('Process finished')
except KeyboardInterrupt:
LOGGER.warning('Interrupt signal caught')
diff --git a/tuna/lib_utils.py b/tuna/lib_utils.py
index 66d8e21e0..954ffe28c 100644
--- a/tuna/lib_utils.py
+++ b/tuna/lib_utils.py
@@ -36,7 +36,7 @@
def get_library(args: Dict[str, Any]) -> Union[Example, MIOpen, RocMLIR]:
"""Factory method to get lib based on args"""
library: Union[Example, MIOpen, RocMLIR]
- if 'lib' not in args.keys() or args['lib'].value == Library.MIOPEN.value:
+ if args['lib'].value == Library.MIOPEN.value:
library = MIOpen()
elif args['lib'].value == Library.EXAMPLE.value:
library = Example()
diff --git a/tuna/libraries.py b/tuna/libraries.py
index 0ee269773..4d6a809c3 100644
--- a/tuna/libraries.py
+++ b/tuna/libraries.py
@@ -38,3 +38,12 @@ class Library(Enum):
def __str__(self) -> str:
return self.value
+
+
+class Operation(str, Enum):
+ """Enumerate supported tuning operations"""
+ COMPILE: str = "compile"
+ EVAL: str = "eval"
+
+ def __str__(self) -> str:
+ return self.value
diff --git a/tuna/machine.py b/tuna/machine.py
index 79c77eca1..8ebef0ce3 100644
--- a/tuna/machine.py
+++ b/tuna/machine.py
@@ -53,7 +53,7 @@
ROCMINFO: str = '/opt/rocm/bin/rocminfo'
ROCMSMI: str = '/opt/rocm/bin/rocm-smi'
-CLINFO: str = '/opt/rocm/opencl/bin/clinfo'
+CLINFO: str = '/opt/rocm/bin/clinfo'
class Machine(BASE): #pylint: disable=too-many-instance-attributes
@@ -466,12 +466,16 @@ def chk_gpu_status(self, gpu_id: int) -> bool:
logger: logging.Logger = self.get_logger()
cnx = self.connect()
- if gpu_id not in self.avail_gpus:
- logger.info('GPU index %u out of bounds', gpu_id)
- return False
- logger.info('Checking GPU %u status', gpu_id)
- _, stdout, _ = cnx.exec_command(
- f'GPU_DEVICE_ORDINAL={gpu_id} {CLINFO} | grep gfx', timeout=30)
+ if self.avail_gpus is not None:
+ if gpu_id not in self.avail_gpus:
+ logger.info('GPU index %u out of bounds', gpu_id)
+ return False
+ logger.info('Checking GPU %u status', gpu_id)
+ _, stdout, _ = cnx.exec_command(
+ f'GPU_DEVICE_ORDINAL={gpu_id} {CLINFO} | grep gfx', timeout=30)
+ else:
+      logger.warning('No available gpus to check')
+      return False
+
if stdout is None:
return False
for line in stdout:
diff --git a/tuna/miopen/celery_tuning/celery_tasks.py b/tuna/miopen/celery_tuning/celery_tasks.py
new file mode 100644
index 000000000..3dfe464a2
--- /dev/null
+++ b/tuna/miopen/celery_tuning/celery_tasks.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Module to register MIOpen celery tasks"""
+import copy
+from celery.signals import celeryd_after_setup
+from celery.utils.log import get_task_logger
+from tuna.celery_app.celery_app import app
+from tuna.libraries import Operation
+from tuna.machine import Machine
+from tuna.miopen.utils.lib_helper import get_worker
+from tuna.utils.utility import SimpleDict
+from tuna.utils.celery_utils import prep_default_kwargs, get_cached_worker
+from tuna.miopen.miopen_lib import Q_NAME
+
+logger = get_task_logger(__name__)
+
+
+@celeryd_after_setup.connect
+def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-argument
+ """Capture worker name"""
+ app.worker_name = sender
+
+
+cached_machine = Machine(local_machine=True)
+
+
+def prep_kwargs(kwargs, args):
+ """Populate kwargs with serialized job, config and machine"""
+ kwargs = prep_default_kwargs(kwargs, args[0], cached_machine)
+ kwargs["config"] = SimpleDict(**args[1])
+
+ return kwargs
+
+
+cached_worker = {}
+
+
+def prep_worker(context):
+ """Creating tuna worker object based on context"""
+ operation = context['operation']
+ if operation in cached_worker:
+ worker = get_cached_worker(context, cached_worker)
+ worker.config = SimpleDict(**context['config'])
+ else:
+ args = [context['job'], context['config'], context['operation']]
+ kwargs = prep_kwargs(context['kwargs'], args)
+ worker = get_worker(kwargs, args[2])
+ cached_worker[operation] = worker
+ return worker
+
+
+@app.task(trail=True, reply_to=Q_NAME)
+def celery_enqueue(context):
+ """Defines a celery task"""
+ kwargs = context['kwargs']
+ operation = context['operation']
+
+ if operation == Operation.EVAL:
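+    # eval workers are named with a 'gpu_id_<N>' suffix; recover the GPU index from the worker name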
+ gpu_id = int((app.worker_name).split('gpu_id_')[1])
+ kwargs['gpu_id'] = gpu_id
+ context['job']['gpu_id'] = gpu_id
+ logger.info("Enqueueing worker %s: gpu(%s), job %s", app.worker_name,
+ gpu_id, context['job'])
+ else:
+ logger.info("Enqueueing worker %s: job %s", app.worker_name, context['job'])
+
+ worker = prep_worker(copy.deepcopy(context))
+ ret = worker.run()
+ return {"ret": ret, "context": context}
diff --git a/tuna/miopen/db/__init__.py b/tuna/miopen/db/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tuna/miopen/db/batch_norm_tables.py b/tuna/miopen/db/batch_norm_tables.py
new file mode 100644
index 000000000..a15d84592
--- /dev/null
+++ b/tuna/miopen/db/batch_norm_tables.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2022 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Represents Batch normalization table definitions """
+
+from sqlalchemy import Column, Integer, String, UniqueConstraint, ForeignKey
+from sqlalchemy.orm import relationship
+from tuna.dbBase.base_class import BASE
+from tuna.miopen.db.mixin_tables import BenchmarkMixin, CacheMixin
+from tuna.miopen.db.mixin_tables import ConfigTagMixin, KernelCacheMixin
+from tuna.miopen.db.mixin_tables import MIOpenJobMixin, SolverApplicabilityMixin
+from tuna.miopen.utils.metadata import DIR_MAP
+
+COMMON_UNIQ_FDS = ["config", "solver", "session"]
+
+#pylint: disable=too-few-public-methods
+#pylint: disable=duplicate-code
+
+
+class BNJob(BASE, MIOpenJobMixin):
+ """Represents batch norm job table"""
+ __tablename__ = "bn_job"
+ __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
+
+ config = Column(Integer,
+ ForeignKey("bn_config.id"),
+ nullable=False,
+ index=True)
+
+
+class BNConfig(BASE):
+ """Represents batch normalization table"""
+ __tablename__ = "bn_config"
+ __table_args__ = (UniqueConstraint("alpha",
+ "beta",
+ "forw",
+ "verify",
+ "back",
+ "mode",
+ "batchsize",
+ "run",
+ "input_tensor",
+ name="uq_idx"),)
+
+ alpha = Column(Integer, nullable=False, server_default="1.0")
+ beta = Column(Integer, nullable=False, server_default="0.0")
+ forw = Column(Integer, nullable=False, server_default="1")
+ verify = Column(Integer, nullable=False, server_default="1")
+ back = Column(Integer, nullable=False, server_default="0")
+ mode = Column(Integer, nullable=False, server_default="0")
+ batchsize = Column(Integer, nullable=False, server_default="32")
+ run = Column(Integer, nullable=False, server_default="0")
+ save = Column(Integer, nullable=False, server_default="0")
+ input_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
+ input_t = relationship("TensorTable",
+ backref="bn_input_tensor",
+ foreign_keys=[input_tensor],
+ lazy="joined")
+ in_layout = Column(String(60), nullable=False, server_default="NCHW")
+ driver = Column(String(length=512), nullable=False, server_default="")
+
+ #pylint: disable=too-few-public-methods
+ def get_direction(self):
+ """synthesize direction"""
+ return DIR_MAP[(self.forw + 4 * self.back)]
+
+
+class BNConfigTags(BASE, ConfigTagMixin):
+ """Represents config_tags tables"""
+ __tablename__ = "bn_config_tags"
+ __table_args__ = (UniqueConstraint("config", "tag", name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("bn_config.id"), nullable=False)
+
+
+class BNSolverApplicability(BASE, SolverApplicabilityMixin):
+ """Represents bn_solver_applicability table"""
+ __tablename__ = "bn_solver_applicability"
+ __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
+
+ config = Column(Integer,
+ ForeignKey("bn_config.id"),
+ nullable=False,
+ index=True)
+
+
+class BNJobCache(BASE, CacheMixin):
+ """Represents job_cache table for batch_norm"""
+ __tablename__ = "bn_job_cache"
+ __table_args__ = (UniqueConstraint("job_id", name="uq_cache_idx"),)
+
+ job_id = Column(Integer, ForeignKey("bn_job.id"), nullable=False)
+
+
+class BNFinJobCache(BASE, KernelCacheMixin):
+ """Represents job_cache table for batch_norm"""
+ __tablename__ = "bn_job_cache_fin"
+
+ job_id = Column(Integer,
+ ForeignKey("bn_job.id",
+ onupdate="CASCADE",
+ ondelete="CASCADE"),
+ nullable=False)
+ solver_id = Column(Integer,
+ ForeignKey("solver.id",
+ onupdate="CASCADE",
+ ondelete="CASCADE"),
+ nullable=False)
+
+
+class BNKernelCache(BASE, KernelCacheMixin):
+ """Represents kernel_cache table for batch_norm"""
+ __tablename__ = "bn_kernel_cache"
+
+ kernel_group = Column(Integer, nullable=True)
+
+
+class BNBenchmark(BASE, BenchmarkMixin):
+ """benchmark table for framework and model parameters"""
+ __tablename__ = "bn_benchmark"
+ __table_args__ = (UniqueConstraint("framework",
+ "model",
+ "batchsize",
+ "gpu_number",
+ "config",
+ name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("bn_config.id"), nullable=False)
diff --git a/tuna/miopen/db/benchmark.py b/tuna/miopen/db/benchmark.py
index c13074ecc..0f1bc456e 100644
--- a/tuna/miopen/db/benchmark.py
+++ b/tuna/miopen/db/benchmark.py
@@ -71,6 +71,33 @@ class ModelEnum(pyenum):
VGG11 = 'Vgg11'
DENSENET = 'Densenet'
DENSENET201 = 'Densenet201'
+ ATOA_SMALL = 'atoa_small'
+ ATOA_MEDIUM = 'atoa_medium'
+ PEAK = 'peak'
+ DENSENET121 = 'densenet121'
+ DENSENET161 = 'densenet161'
+ DENSENET169 = 'densenet169'
+ MNASNET0_5 = 'mnasnet0_5'
+ MNASNET0_75 = 'mnasnet0_75'
+  MNASNET1_0 = 'mnasnet1_0'
+ MNASNET1_3 = 'mnasnet1_3'
+ RESNET18 = 'Resnet18'
+ RESNET34 = 'Resnet34'
+ VGG13 = 'vgg13'
+ RESNEXT101_32x8d = 'Resnext101_32x8d'
+  RESNEXT50_32X4D = 'Resnext50_32x4d'
+ SHUFFLENET_V2_X0_5 = 'Shufflenet_v2_x0_5'
+ SHUFFLENET_V2_X1_0 = 'Shufflenet_v2_x1_0'
+ SHUFFLENET_V2_X1_5 = 'Shufflenet_v2_x1_5'
+ SHUFFLENET_V2_X2_0 = 'Shufflenet_v2_x2_0'
+ SQUEEZENET1_0 = 'Squeezenet1_0'
+ SQUEEZENET1_1 = 'Squeezenet1_1'
+ WIDE_RESNET101_2 = 'Wide_resnet101_2'
+ WIDE_RESNET50_2 = 'Wide_resnet50_2'
+ VGG11_BN = 'Vgg11_bn'
+ VGG13_BN = 'Vgg13_bn'
+ VGG16_BN = 'Vgg16_bn'
+ VGG19_BN = 'Vgg19_bn'
def __str__(self):
return self.value
diff --git a/tuna/miopen/db/bn_golden_tables.py b/tuna/miopen/db/bn_golden_tables.py
new file mode 100644
index 000000000..82c4c971e
--- /dev/null
+++ b/tuna/miopen/db/bn_golden_tables.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2022 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Represents Batchnorm Golden table definitions """
+
+from sqlalchemy import Column, Integer, UniqueConstraint, ForeignKey
+from tuna.dbBase.base_class import BASE
+from tuna.miopen.db.mixin_tables import GoldenMixin
+
+
+class BNGolden(BASE, GoldenMixin):
+ """Golden table for batch norm"""
+ __tablename__ = "bn_golden"
+ __table_args__ = (UniqueConstraint("golden_miopen_v",
+ "config",
+ "solver",
+ "arch",
+ "num_cu",
+ name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("bn_config.id"), nullable=False)
+
+ kernel_group = Column(Integer, nullable=True)
diff --git a/tuna/miopen/db/build_schema.py b/tuna/miopen/db/build_schema.py
index 889b62ba7..583b84430 100755
--- a/tuna/miopen/db/build_schema.py
+++ b/tuna/miopen/db/build_schema.py
@@ -26,14 +26,14 @@
###############################################################################
""" Module for creating DB tables"""
from sqlalchemy.exc import OperationalError
-from tuna.miopen.db.miopen_tables import get_miopen_tables
+from tuna.miopen.db.get_db_tables import get_miopen_tables
from tuna.miopen.db.triggers import get_miopen_triggers, drop_miopen_triggers
from tuna.db_engine import ENGINE
from tuna.utils.logger import setup_logger
from tuna.utils.db_utility import create_tables
#pylint: disable=too-few-public-methods
-LOGGER = setup_logger('db_tables')
+LOGGER = setup_logger('miopen_db_tables')
def recreate_triggers(drop_triggers, create_triggers):
diff --git a/tuna/miopen/db/convolutionjob_tables.py b/tuna/miopen/db/convolutionjob_tables.py
new file mode 100644
index 000000000..0c7133b53
--- /dev/null
+++ b/tuna/miopen/db/convolutionjob_tables.py
@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2022 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Represents ConvolutionJob table definitions """
+from sqlalchemy import Column, Integer, String, UniqueConstraint, ForeignKey
+from sqlalchemy.orm import relationship
+from sqlalchemy import Index
+from sqlalchemy import Float, BigInteger, Boolean
+from tuna.dbBase.base_class import BASE
+from tuna.miopen.db.mixin_tables import BenchmarkMixin, CacheMixin
+from tuna.miopen.db.mixin_tables import ConfigTagMixin, GoldenMixin
+from tuna.miopen.db.mixin_tables import KernelCacheMixin, MIOpenJobMixin
+from tuna.miopen.db.mixin_tables import SolverAnalyticsMixin, SolverApplicabilityMixin
+
+COMMON_UNIQ_FDS = ["config", "solver", "session"]
+
+
+#pylint: disable=too-few-public-methods
+#pylint: disable=duplicate-code
+class ConvolutionJob(BASE, MIOpenJobMixin):
+ """Represents convolutions job table"""
+ __tablename__ = "conv_job"
+ __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
+
+ config = Column(Integer,
+ ForeignKey("conv_config.id"),
+ nullable=False,
+ index=True)
+ get_job_ids1 = Index('get_job_idx1', 'session', 'valid', 'reason', 'fin_step',
+ 'retries')
+ get_job_ids2 = Index('get_job_idx2', 'session', 'valid')
+ get_job_ids3 = Index('get_job_idx3', 'session', 'valid', 'retries')
+ get_job_compile = Index('get_job_compile', 'valid', 'state', 'reason',
+ 'session')
+
+
+class ConvolutionConfig(BASE):
+ """Represents convolution config table"""
+ __tablename__ = "conv_config"
+
+ batchsize = Column(Integer, nullable=False, server_default="0")
+ spatial_dim = Column(Integer, nullable=False, server_default="2")
+ pad_h = Column(Integer, nullable=False, server_default="0")
+ pad_w = Column(Integer, nullable=False, server_default="0")
+ pad_d = Column(Integer, nullable=False, server_default="0")
+ conv_stride_h = Column(Integer, nullable=False, server_default="1")
+ conv_stride_w = Column(Integer, nullable=False, server_default="1")
+ conv_stride_d = Column(Integer, nullable=False, server_default="1")
+ dilation_h = Column(Integer, nullable=False, server_default="1")
+ dilation_w = Column(Integer, nullable=False, server_default="1")
+ dilation_d = Column(Integer, nullable=False, server_default="1")
+ group_count = Column(Integer, nullable=False, server_default="1")
+ mode = Column(String(length=40), nullable=False, server_default="conv")
+ pad_mode = Column(String(length=40), nullable=False, server_default="default")
+ trans_output_pad_h = Column(Integer, nullable=False, server_default="0")
+ trans_output_pad_w = Column(Integer, nullable=False, server_default="0")
+ trans_output_pad_d = Column(Integer, nullable=False, server_default="0")
+ direction = Column(String(length=8), nullable=False)
+ input_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
+ weight_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
+ input_t = relationship("TensorTable",
+ backref="conv_input_tensor",
+ foreign_keys=[input_tensor],
+ lazy="joined")
+ weight_t = relationship("TensorTable",
+ backref="weight_tensor",
+ foreign_keys=[weight_tensor],
+ lazy="joined")
+ out_layout = Column(String(60), nullable=False, server_default="NCHW")
+ md5 = Column(String(length=40), nullable=False, unique=True)
+ driver = Column(String(length=512), nullable=False, server_default="")
+
+
+class ConvolutionConfigTags(BASE, ConfigTagMixin):
+ """Represents config_tags tables"""
+ __tablename__ = "conv_config_tags"
+ __table_args__ = (UniqueConstraint("config", "tag", name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("conv_config.id"), nullable=False)
+
+
+class ConvSolverApplicability(BASE, SolverApplicabilityMixin):
+ """Represents conv_solver_applicability table"""
+ __tablename__ = "conv_solver_applicability"
+ __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
+
+ config = Column(Integer,
+ ForeignKey("conv_config.id"),
+ nullable=False,
+ index=True)
+ app_idx = Index('app_idx', 'config', 'solver', 'session')
+ sess_cfg = Index('sess_cfg', 'session', 'config')
+
+
+class ConvJobCache(BASE, CacheMixin):
+ """Represents job_cache table for convolutions"""
+ __tablename__ = "conv_job_cache"
+ __table_args__ = (UniqueConstraint("job_id", name="uq_cache_idx"),)
+
+ job_id = Column(Integer, ForeignKey("conv_job.id"), nullable=False)
+
+
+class ConvFinJobCache(BASE, KernelCacheMixin):
+ """Represents job_cache table"""
+ __tablename__ = "conv_job_cache_fin"
+
+ job_id = Column(Integer,
+ ForeignKey("conv_job.id",
+ onupdate="CASCADE",
+ ondelete="CASCADE"),
+ nullable=False)
+ solver_id = Column(Integer,
+ ForeignKey("solver.id",
+ onupdate="CASCADE",
+ ondelete="CASCADE"),
+ nullable=False)
+
+ idx_job = Index('job_id')
+
+
+class ConvolutionKernelCache(BASE, KernelCacheMixin):
+ """Represents kernel_cache table for convolutions"""
+ __tablename__ = "conv_kernel_cache"
+
+ kernel_group = Column(Integer, nullable=True)
+ idx_kgroup = Index('kernel_group')
+ idx_valid = Index('valid', 'kernel_group')
+
+
+class ConvolutionGolden(BASE, GoldenMixin):
+ """Golden table for convolution"""
+ __tablename__ = "conv_golden"
+ __table_args__ = (UniqueConstraint("golden_miopen_v",
+ "config",
+ "solver",
+ "arch",
+ "num_cu",
+ name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("conv_config.id"), nullable=False)
+
+ fdb_key = Column(String(length=128), nullable=True)
+ params = Column(String(length=128), nullable=True)
+ kernel_time = Column(Float, nullable=False)
+ workspace_sz = Column(BigInteger, nullable=False)
+ alg_lib = Column(String(length=64), nullable=True)
+ opencl = Column(Boolean, nullable=False)
+
+ kernel_group = Column(Integer, nullable=True)
+
+
+class ConvolutionBenchmark(BASE, BenchmarkMixin):
+ """benchmark table for framework and model parameters"""
+ __tablename__ = "conv_benchmark"
+ __table_args__ = (UniqueConstraint("framework",
+ "model",
+ "batchsize",
+ "gpu_number",
+ "config",
+ name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("conv_config.id"), nullable=False)
+
+
+class ConvSolverAnalyticsAggregated(BASE, SolverAnalyticsMixin):
+ """Table to store aggregated results from SolverAnalytics"""
+ __tablename__ = "conv_solver_analytics_aggregated"
+
+ __table_args__ = (UniqueConstraint("golden_miopen_v",
+ "arch",
+ "num_cu",
+ "opencl",
+ "filter",
+ "padding",
+ "stride",
+ "dilation",
+ "layout",
+ "precision",
+ "direction",
+ "sf",
+ name="uq_idx"),)
+
+ # additional stats (null if no alternate solver is available)
+ ta = Column(Float, nullable=True) # alternate solver runtime
+ difference = Column(Float, nullable=True) # runtime difference
+ ratio = Column(Float, nullable=True) # runtime ratio (null if infinity)
+
+
+class ConvSolverAnalyticsDetailed(BASE, SolverAnalyticsMixin):
+ """Table to store detailed results from SolverAnalytics"""
+ __tablename__ = "conv_solver_analytics_detailed"
+
+ __table_args__ = (UniqueConstraint("golden_miopen_v",
+ "arch",
+ "num_cu",
+ "opencl",
+ "filter",
+ "padding",
+ "stride",
+ "dilation",
+ "layout",
+ "precision",
+ "direction",
+ "sf",
+ "sa",
+ name="uq_idx"),)
+
+ # additional stats
+ sa = Column(String(128), nullable=False) # alternate solver
+ ta = Column(Float, nullable=True) # alternate solver runtime
+ difference = Column(Float, nullable=True) # runtime difference
+ ratio = Column(Float, nullable=True) # runtime ratio (null if infinity)
diff --git a/tuna/miopen/db/fusion_config_tables.py b/tuna/miopen/db/fusion_config_tables.py
new file mode 100644
index 000000000..e9242002d
--- /dev/null
+++ b/tuna/miopen/db/fusion_config_tables.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2022 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Represents Fusion config table class definitions """
+
+from sqlalchemy import Column, Integer, UniqueConstraint, ForeignKey
+from sqlalchemy.orm import relationship
+from tuna.dbBase.base_class import BASE
+from tuna.miopen.db.mixin_tables import ConfigTagMixin, MIOpenJobMixin, SolverApplicabilityMixin
+
+COMMON_UNIQ_FDS = ["config", "solver", "session"]
+
+
+#pylint: disable=too-few-public-methods
+class FusionConfig(BASE):
+ """Represents fusion table"""
+ __tablename__ = "fusion_config"
+ __table_args__ = (UniqueConstraint("input_tensor",
+ "weight_tensor",
+ "activ_mode",
+ "fusion_mode",
+ name="uq_idx"),)
+
+ input_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
+ input_t = relationship("TensorTable",
+ backref="input_tensor_fusion",
+ foreign_keys=[input_tensor],
+ lazy="joined")
+ weight_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
+ activ_mode = Column(Integer, nullable=False, server_default="1")
+ fusion_mode = Column(Integer, nullable=False, server_default="1")
+
+
+class FusionConfigTags(BASE, ConfigTagMixin):
+ """Represents config_tags tables"""
+ __tablename__ = "fusion_config_tags"
+ __table_args__ = (UniqueConstraint("config", "tag", name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("fusion_config.id"), nullable=False)
+
+
+class FusionJob(BASE, MIOpenJobMixin):
+ """Represents fusions job table"""
+ __tablename__ = "fusion_job"
+ __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("fusion_config.id"), nullable=False)
+
+
+class SolverFusionApplicability(BASE, SolverApplicabilityMixin):
+ """Represents fusion_solver_applicability table"""
+ __tablename__ = "fusion_solver_applicability"
+ __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
+
+ config = Column(Integer, ForeignKey("fusion_config.id"), nullable=False)
diff --git a/tuna/miopen/db/get_db_tables.py b/tuna/miopen/db/get_db_tables.py
new file mode 100644
index 000000000..c5a36e9e1
--- /dev/null
+++ b/tuna/miopen/db/get_db_tables.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2022 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+""" Module for get/set/initialize DB - MIOpen/Conv/Fusion/Batchnorm tables"""
+
+from tuna.machine import Machine
+from tuna.miopen.db.benchmark import Framework, Model
+from tuna.miopen.db.solver import Solver
+from tuna.miopen.db.batch_norm_tables import BNBenchmark
+from tuna.miopen.db.convolutionjob_tables import ConvolutionBenchmark
+from tuna.miopen.db.tensortable import TensorTable
+from tuna.miopen.db.miopen_tables import add_bn_tables
+from tuna.miopen.db.miopen_tables import add_conv_tables
+from tuna.miopen.db.miopen_tables import add_fusion_tables
+from tuna.miopen.db.session import Session
+
+
+def get_miopen_tables():
+ """Returns a list of all MIOpen Tuna DB tables"""
+
+ miopen_tables = []
+ miopen_tables.append(Solver())
+ miopen_tables.append(Session())
+ miopen_tables.append(Framework())
+ miopen_tables.append(Model())
+ miopen_tables.append(Machine(local_machine=True))
+ miopen_tables.append(TensorTable())
+
+ miopen_tables = add_conv_tables(miopen_tables)
+ miopen_tables = add_fusion_tables(miopen_tables)
+ miopen_tables = add_bn_tables(miopen_tables)
+
+ return miopen_tables
diff --git a/tuna/miopen/db/miopen_tables.py b/tuna/miopen/db/miopen_tables.py
index d8477cfae..9bf4205b8 100644
--- a/tuna/miopen/db/miopen_tables.py
+++ b/tuna/miopen/db/miopen_tables.py
@@ -25,581 +25,32 @@
#
###############################################################################
""" Module for creating DB tables"""
-import enum
-from sqlalchemy import Column, Integer, String, UniqueConstraint, ForeignKey, DateTime
-from sqlalchemy import Enum, Index
-from sqlalchemy import Float, BigInteger, Boolean
-from sqlalchemy.databases import mysql
-from sqlalchemy.dialects.mysql import TINYINT, MEDIUMBLOB, LONGBLOB
-from sqlalchemy.orm import relationship
-from sqlalchemy.sql import func as sqla_func
-from sqlalchemy.ext.declarative import declared_attr
-from tuna.dbBase.base_class import BASE
-from tuna.machine import Machine
-from tuna.miopen.db.find_db import ConvolutionFindDB, BNFindDB
-from tuna.miopen.utils.config_type import ConfigType
-from tuna.miopen.db.session import Session
-from tuna.miopen.utils.metadata import DIR_MAP
-from tuna.miopen.db.benchmark import Model, Framework
-from tuna.db.tuna_tables import JobMixin
+from tuna.miopen.db.find_db import BNFindDB, ConvolutionFindDB
+from tuna.miopen.db.bn_golden_tables import BNGolden
+from tuna.miopen.db.fusion_config_tables import FusionConfig
+from tuna.miopen.db.fusion_config_tables import FusionConfigTags, FusionJob
+from tuna.miopen.db.fusion_config_tables import SolverFusionApplicability
+from tuna.miopen.db.batch_norm_tables import BNBenchmark, BNConfig
+from tuna.miopen.db.batch_norm_tables import BNConfigTags, BNFinJobCache
+from tuna.miopen.db.batch_norm_tables import BNJob, BNJobCache, BNKernelCache
+from tuna.miopen.db.batch_norm_tables import BNSolverApplicability
+from tuna.miopen.db.convolutionjob_tables import ConvFinJobCache
+from tuna.miopen.db.convolutionjob_tables import ConvJobCache
+from tuna.miopen.db.convolutionjob_tables import ConvSolverAnalyticsAggregated
+from tuna.miopen.db.convolutionjob_tables import ConvSolverAnalyticsDetailed
+from tuna.miopen.db.convolutionjob_tables import ConvSolverApplicability
+from tuna.miopen.db.convolutionjob_tables import ConvolutionBenchmark
+from tuna.miopen.db.convolutionjob_tables import ConvolutionConfig
+from tuna.miopen.db.convolutionjob_tables import ConvolutionConfigTags
+from tuna.miopen.db.convolutionjob_tables import ConvolutionGolden, ConvolutionJob
+from tuna.miopen.db.convolutionjob_tables import ConvolutionKernelCache
COMMON_UNIQ_FDS = ["config", "solver", "session"]
-#pylint: disable=too-few-public-methods
-class CacheMixin():
- """Represents job_cache table"""
-
- kernel_blob = Column(LONGBLOB, nullable=False)
- cache_name = Column(String(length=45), nullable=False)
-
-
-class KernelCacheMixin():
- """Represents Mixin for KernelCache table"""
-
- kernel_name = Column(String(length=4000), nullable=False)
- kernel_args = Column(String(length=9000), nullable=False)
- kernel_blob = Column(MEDIUMBLOB, nullable=False)
- kernel_hash = Column(String(length=128), nullable=False)
- uncompressed_size = Column(Integer, nullable=False)
-
-
-class JobCache(BASE, CacheMixin):
- """Represents job_cache table"""
- __tablename__ = "job_cache"
- __table_args__ = (UniqueConstraint("job_id", name="uq_cache_idx"),)
-
- job_id = Column(Integer, ForeignKey("job.id"), nullable=False)
-
-
-class FinJobCache(BASE, KernelCacheMixin):
- """Represents job_cache table"""
- __tablename__ = "job_cache_fin"
-
- job_id = Column(Integer,
- ForeignKey("job.id", onupdate="CASCADE", ondelete="CASCADE"),
- nullable=False)
- solver_id = Column(Integer,
- ForeignKey("solver.id",
- onupdate="CASCADE",
- ondelete="CASCADE"),
- nullable=False)
-
-
-class BNJobCache(BASE, CacheMixin):
- """Represents job_cache table for batch_norm"""
- __tablename__ = "bn_job_cache"
- __table_args__ = (UniqueConstraint("job_id", name="uq_cache_idx"),)
-
- job_id = Column(Integer, ForeignKey("bn_job.id"), nullable=False)
-
-
-class BNFinJobCache(BASE, KernelCacheMixin):
- """Represents job_cache table for batch_norm"""
- __tablename__ = "bn_job_cache_fin"
-
- job_id = Column(Integer,
- ForeignKey("bn_job.id",
- onupdate="CASCADE",
- ondelete="CASCADE"),
- nullable=False)
- solver_id = Column(Integer,
- ForeignKey("solver.id",
- onupdate="CASCADE",
- ondelete="CASCADE"),
- nullable=False)
-
-
-class ConvJobCache(BASE, CacheMixin):
- """Represents job_cache table for convolutions"""
- __tablename__ = "conv_job_cache"
- __table_args__ = (UniqueConstraint("job_id", name="uq_cache_idx"),)
-
- job_id = Column(Integer, ForeignKey("conv_job.id"), nullable=False)
-
-
-class ConvFinJobCache(BASE, KernelCacheMixin):
- """Represents job_cache table"""
- __tablename__ = "conv_job_cache_fin"
-
- job_id = Column(Integer,
- ForeignKey("conv_job.id",
- onupdate="CASCADE",
- ondelete="CASCADE"),
- nullable=False)
- solver_id = Column(Integer,
- ForeignKey("solver.id",
- onupdate="CASCADE",
- ondelete="CASCADE"),
- nullable=False)
-
-
-class ConvolutionKernelCache(BASE, KernelCacheMixin):
- """Represents kernel_cache table for convolutions"""
- __tablename__ = "conv_kernel_cache"
-
- kernel_group = Column(Integer, nullable=True)
-
-
-class BNKernelCache(BASE, KernelCacheMixin):
- """Represents kernel_cache table for batch_norm"""
- __tablename__ = "bn_kernel_cache"
-
- kernel_group = Column(Integer, nullable=True)
-
-
-class Solver(BASE):
- """Represents solver table"""
- __tablename__ = "solver"
- __table_args__ = (UniqueConstraint("solver", name="uq_idx"),)
-
- solver = Column(String(length=128), unique=True, nullable=False)
- tunable = Column(TINYINT(1), nullable=False, server_default="1")
- config_type = Column(Enum(ConfigType),
- nullable=False,
- server_default="convolution")
- is_dynamic = Column(TINYINT(1), nullable=False, server_default="0")
-
-
-class TensorTable(BASE):
- """Represents tensor table"""
- __tablename__ = "tensor"
- __table_args__ = (UniqueConstraint("dim0",
- "dim1",
- "dim2",
- "dim3",
- "dim4",
- "layout",
- "num_dims",
- "data_type",
- name="uq_idx"),)
-
- dim0 = Column(Integer, nullable=False, server_default="0")
- dim1 = Column(Integer, nullable=False, server_default="0")
- dim2 = Column(Integer, nullable=False, server_default="0")
- dim3 = Column(Integer, nullable=False, server_default="0")
- dim4 = Column(Integer, nullable=False, server_default="0")
- layout = Column(String(60), nullable=False, server_default="NCHW")
- num_dims = Column(Integer, nullable=False, server_default="2")
- data_type = Column(String(60), nullable=False, server_default="FP32")
-
-
-class ConvolutionConfig(BASE):
- """Represents convolution config table"""
- __tablename__ = "conv_config"
-
- batchsize = Column(Integer, nullable=False, server_default="0")
- spatial_dim = Column(Integer, nullable=False, server_default="2")
- pad_h = Column(Integer, nullable=False, server_default="0")
- pad_w = Column(Integer, nullable=False, server_default="0")
- pad_d = Column(Integer, nullable=False, server_default="0")
- conv_stride_h = Column(Integer, nullable=False, server_default="1")
- conv_stride_w = Column(Integer, nullable=False, server_default="1")
- conv_stride_d = Column(Integer, nullable=False, server_default="1")
- dilation_h = Column(Integer, nullable=False, server_default="1")
- dilation_w = Column(Integer, nullable=False, server_default="1")
- dilation_d = Column(Integer, nullable=False, server_default="1")
- group_count = Column(Integer, nullable=False, server_default="1")
- mode = Column(String(length=40), nullable=False, server_default="conv")
- pad_mode = Column(String(length=40), nullable=False, server_default="default")
- trans_output_pad_h = Column(Integer, nullable=False, server_default="0")
- trans_output_pad_w = Column(Integer, nullable=False, server_default="0")
- trans_output_pad_d = Column(Integer, nullable=False, server_default="0")
- direction = Column(String(length=8), nullable=False)
- input_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
- weight_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
- input_t = relationship("TensorTable",
- backref="conv_input_tensor",
- foreign_keys=[input_tensor],
- lazy="joined")
- weight_t = relationship("TensorTable",
- backref="weight_tensor",
- foreign_keys=[weight_tensor],
- lazy="joined")
- out_layout = Column(String(60), nullable=False, server_default="NCHW")
- md5 = Column(String(length=40), nullable=False, unique=True)
- driver = Column(String(length=512), nullable=False, server_default="")
-
-
-class FusionConfig(BASE):
- """Represents fusion table"""
- __tablename__ = "fusion_config"
- __table_args__ = (UniqueConstraint("input_tensor",
- "weight_tensor",
- "activ_mode",
- "fusion_mode",
- name="uq_idx"),)
-
- input_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
- input_t = relationship("TensorTable",
- backref="input_tensor_fusion",
- foreign_keys=[input_tensor],
- lazy="joined")
- weight_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
- activ_mode = Column(Integer, nullable=False, server_default="1")
- fusion_mode = Column(Integer, nullable=False, server_default="1")
-
-
-class BNConfig(BASE):
- """Represents batch normalization table"""
- __tablename__ = "bn_config"
- __table_args__ = (UniqueConstraint("alpha",
- "beta",
- "forw",
- "verify",
- "back",
- "mode",
- "batchsize",
- "run",
- "input_tensor",
- name="uq_idx"),)
-
- alpha = Column(Integer, nullable=False, server_default="1.0")
- beta = Column(Integer, nullable=False, server_default="0.0")
- forw = Column(Integer, nullable=False, server_default="1")
- verify = Column(Integer, nullable=False, server_default="1")
- back = Column(Integer, nullable=False, server_default="0")
- mode = Column(Integer, nullable=False, server_default="0")
- batchsize = Column(Integer, nullable=False, server_default="32")
- run = Column(Integer, nullable=False, server_default="0")
- save = Column(Integer, nullable=False, server_default="0")
- input_tensor = Column(Integer, ForeignKey("tensor.id"), nullable=False)
- input_t = relationship("TensorTable",
- backref="bn_input_tensor",
- foreign_keys=[input_tensor],
- lazy="joined")
- in_layout = Column(String(60), nullable=False, server_default="NCHW")
- driver = Column(String(length=512), nullable=False, server_default="")
-
- def get_direction(self):
- """synthesize direction"""
- return DIR_MAP[(self.forw + 4 * self.back)]
-
-
-class ConfigTagMixin():
- """Mixin class for config tags tables"""
-
- tag = Column(String(length=128), nullable=False, server_default="no_tag")
- recurrent = Column(TINYINT(1), nullable=False, server_default="0")
-
-
-class ConvolutionConfigTags(BASE, ConfigTagMixin):
- """Represents config_tags tables"""
- __tablename__ = "conv_config_tags"
- __table_args__ = (UniqueConstraint("config", "tag", name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("conv_config.id"), nullable=False)
-
-
-class BNConfigTags(BASE, ConfigTagMixin):
- """Represents config_tags tables"""
- __tablename__ = "bn_config_tags"
- __table_args__ = (UniqueConstraint("config", "tag", name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("bn_config.id"), nullable=False)
-
-
-class FusionConfigTags(BASE, ConfigTagMixin):
- """Represents config_tags tables"""
- __tablename__ = "fusion_config_tags"
- __table_args__ = (UniqueConstraint("config", "tag", name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("fusion_config.id"), nullable=False)
-
-
-class FinStep(enum.Enum):
- """ Allowed Fin Steps """
- # pylint: disable=invalid-name ; tuna/go_fish.py names valid fin steps as FinStep.__members__
- find_compile = 1
- find_eval = 2
- get_solvers = 3
- get_applicability = 4
- not_fin = 5
- miopen_find_compile = 6
- miopen_find_eval = 7
- miopen_perf_compile = 8
- miopen_perf_eval = 9
-
-
-class MIOpenJobMixin(JobMixin):
- """Represents MIOpen Mixin class for job tables"""
-
- compile_start = Column(DateTime,
- nullable=False,
- server_default=sqla_func.now())
- compile_end = Column(DateTime, nullable=False, server_default=sqla_func.now())
- eval_start = Column(DateTime, nullable=False, server_default=sqla_func.now())
- eval_end = Column(DateTime, nullable=False, server_default=sqla_func.now())
-
- solver = Column(String(length=128), nullable=True, server_default="")
- eval_mid = Column(Integer, server_default="-1")
- fin_step = Column(mysql.MSSet(*(list(k for k in FinStep.__members__))),
- nullable=False,
- server_default="not_fin")
-
-
-class ConvolutionJob(BASE, MIOpenJobMixin):
- """Represents convolutions job table"""
- __tablename__ = "conv_job"
- __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
-
- config = Column(Integer,
- ForeignKey("conv_config.id"),
- nullable=False,
- index=True)
- get_job_ids1 = Index('get_job_idx1', 'session', 'valid', 'reason', 'fin_step',
- 'retries')
- get_job_ids2 = Index('get_job_idx2', 'session', 'valid')
- get_job_ids3 = Index('get_job_idx3', 'session', 'valid', 'retries')
- get_job_compile = Index('get_job_compile', 'valid', 'state', 'reason',
- 'session')
-
-
-class BNJob(BASE, MIOpenJobMixin):
- """Represents batch norm job table"""
- __tablename__ = "bn_job"
- __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
-
- config = Column(Integer,
- ForeignKey("bn_config.id"),
- nullable=False,
- index=True)
-
-
-class FusionJob(BASE, MIOpenJobMixin):
- """Represents fusions job table"""
- __tablename__ = "fusion_job"
- __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("fusion_config.id"), nullable=False)
-
-
-class SolverApplicabilityMixin():
- """Represents Mixin class for solver_applicability tables"""
-
- @declared_attr
- def solver(self):
- """solver column"""
- return Column(Integer, ForeignKey("solver.id"), nullable=False, index=True)
-
- @declared_attr
- def session(self):
- """session key"""
- return Column(Integer, ForeignKey("session.id"), nullable=False, index=True)
-
- applicable = Column(TINYINT, nullable=False, server_default="1")
-
-
-class ConvSolverApplicability(BASE, SolverApplicabilityMixin):
- """Represents conv_solver_applicability table"""
- __tablename__ = "conv_solver_applicability"
- __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
-
- config = Column(Integer,
- ForeignKey("conv_config.id"),
- nullable=False,
- index=True)
- app_idx = Index('app_idx', 'config', 'solver', 'session')
- sess_cfg = Index('sess_cfg', 'session', 'config')
-
-
-class BNSolverApplicability(BASE, SolverApplicabilityMixin):
- """Represents bn_solver_applicability table"""
- __tablename__ = "bn_solver_applicability"
- __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
-
- config = Column(Integer,
- ForeignKey("bn_config.id"),
- nullable=False,
- index=True)
-
-
-class SolverFusionApplicability(BASE, SolverApplicabilityMixin):
- """Represents fusion_solver_applicability table"""
- __tablename__ = "fusion_solver_applicability"
- __table_args__ = (UniqueConstraint(*COMMON_UNIQ_FDS, name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("fusion_config.id"), nullable=False)
-
-
-class GoldenMixin():
- """Mixin for golden table"""
-
- @declared_attr
- def session(self):
- """session foreign key"""
- return Column(Integer, ForeignKey("session.id"), nullable=False)
-
- @declared_attr
- def solver(self):
- """solver foreign key"""
- return Column(Integer,
- ForeignKey("solver.id",
- onupdate="CASCADE",
- ondelete="CASCADE"),
- nullable=False)
-
- golden_miopen_v = Column(Integer, nullable=False)
- arch = Column(String(length=20), nullable=False, server_default="")
- num_cu = Column(Integer, nullable=False, server_default="0")
-
-
-class ConvolutionGolden(BASE, GoldenMixin):
- """Golden table for convolution"""
- __tablename__ = "conv_golden"
- __table_args__ = (UniqueConstraint("golden_miopen_v",
- "config",
- "solver",
- "arch",
- "num_cu",
- name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("conv_config.id"), nullable=False)
-
- fdb_key = Column(String(length=128), nullable=True)
- params = Column(String(length=128), nullable=True)
- kernel_time = Column(Float, nullable=False)
- workspace_sz = Column(BigInteger, nullable=False)
- alg_lib = Column(String(length=64), nullable=True)
- opencl = Column(Boolean, nullable=False)
-
- kernel_group = Column(Integer, nullable=True)
-
-
-class BNGolden(BASE, GoldenMixin):
- """Golden table for batch norm"""
- __tablename__ = "bn_golden"
- __table_args__ = (UniqueConstraint("golden_miopen_v",
- "config",
- "solver",
- "arch",
- "num_cu",
- name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("bn_config.id"), nullable=False)
-
- kernel_group = Column(Integer, nullable=True)
-
-
-class BenchmarkMixin():
- """Mixin class for bechmark tables"""
-
- @declared_attr
- def framework(self):
- """Framework Fkey"""
- return Column(Integer, ForeignKey("framework.id"), nullable=False)
-
- @declared_attr
- def model(self):
- """Model Fkey"""
- return Column(Integer, ForeignKey("model.id"), nullable=False)
-
- batchsize = Column(Integer, nullable=False, server_default="32")
- gpu_number = Column(Integer, nullable=True, server_default="1")
- driver_cmd = Column(String(length=512), nullable=False)
-
-
-class ConvolutionBenchmark(BASE, BenchmarkMixin):
- """benchmark table for framework and model parameters"""
- __tablename__ = "conv_benchmark"
- __table_args__ = (UniqueConstraint("framework",
- "model",
- "batchsize",
- "gpu_number",
- "config",
- name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("conv_config.id"), nullable=False)
-
-
-class BNBenchmark(BASE, BenchmarkMixin):
- """benchmark table for framework and model parameters"""
- __tablename__ = "bn_benchmark"
- __table_args__ = (UniqueConstraint("framework",
- "model",
- "batchsize",
- "gpu_number",
- "config",
- name="uq_idx"),)
-
- config = Column(Integer, ForeignKey("bn_config.id"), nullable=False)
-
-
-class SolverAnalyticsMixin():
- """common columns in aggregated & detailed solver analytics tables"""
- # software state description
- golden_miopen_v = Column(Integer, nullable=False)
- # hardware description
- arch = Column(String(20), nullable=False)
- num_cu = Column(Integer, nullable=False)
- opencl = Column(Boolean, nullable=False)
- # convolution problem description
- filter = Column(String(32), nullable=False)
- padding = Column(String(32), nullable=False)
- stride = Column(String(32), nullable=False)
- dilation = Column(String(32), nullable=False)
- layout = Column(String(8), nullable=False)
- precision = Column(String(8), nullable=False)
- direction = Column(String(1), nullable=False)
- # fastest solvers stats
- sf = Column(String(128), nullable=False) # fastest solver name
- tf = Column(Float, nullable=False) # fastest solver runtime
- count = Column(Integer, nullable=False) # fastest solver count
-
-
-class ConvSolverAnalyticsAggregated(BASE, SolverAnalyticsMixin):
- """Table to store aggregated results from SolverAnalytics"""
- __tablename__ = "conv_solver_analytics_aggregated"
-
- __table_args__ = (UniqueConstraint("golden_miopen_v",
- "arch",
- "num_cu",
- "opencl",
- "filter",
- "padding",
- "stride",
- "dilation",
- "layout",
- "precision",
- "direction",
- "sf",
- name="uq_idx"),)
-
- # additional stats (null if no alternate solver is available)
- ta = Column(Float, nullable=True) # alternate solver runtime
- difference = Column(Float, nullable=True) # runtime difference
- ratio = Column(Float, nullable=True) # runtime ratio (null if infinity)
-
-
-class ConvSolverAnalyticsDetailed(BASE, SolverAnalyticsMixin):
- """Table to store detailed results from SolverAnalytics"""
- __tablename__ = "conv_solver_analytics_detailed"
-
- __table_args__ = (UniqueConstraint("golden_miopen_v",
- "arch",
- "num_cu",
- "opencl",
- "filter",
- "padding",
- "stride",
- "dilation",
- "layout",
- "precision",
- "direction",
- "sf",
- "sa",
- name="uq_idx"),)
-
- # additional stats
- sa = Column(String(128), nullable=False) # alternate solver
- ta = Column(Float, nullable=True) # alternate solver runtime
- difference = Column(Float, nullable=True) # runtime difference
- ratio = Column(Float, nullable=True) # runtime ratio (null if infinity)
-
-
def add_conv_tables(miopen_tables):
- """Append Convolution specific MIOpen DB tables"""
+ """ Append Convolution specific MIOpen DB tables """
miopen_tables.append(ConvolutionConfig())
miopen_tables.append(ConvolutionJob())
miopen_tables.append(ConvolutionConfigTags())
@@ -616,7 +67,7 @@ def add_conv_tables(miopen_tables):
def add_fusion_tables(miopen_tables):
- """Append Fusion specific MIOpen DB tables"""
+ """ Append Fusion specific MIOpen DB tables"""
miopen_tables.append(FusionConfig())
miopen_tables.append(SolverFusionApplicability())
miopen_tables.append(FusionJob())
@@ -625,7 +76,7 @@ def add_fusion_tables(miopen_tables):
def add_bn_tables(miopen_tables):
- """Append BatchNorm specific MIOpen DB tables"""
+ """ Append BatchNorm specific MIOpen DB tables"""
miopen_tables.append(BNConfig())
miopen_tables.append(BNJob())
miopen_tables.append(BNConfigTags())
@@ -637,23 +88,3 @@ def add_bn_tables(miopen_tables):
miopen_tables.append(BNGolden())
miopen_tables.append(BNBenchmark())
return miopen_tables
-
-
-def get_miopen_tables():
- """Returns a list of all MIOpen Tuna DB tables"""
- miopen_tables = []
- miopen_tables.append(Solver())
- miopen_tables.append(Session())
- miopen_tables.append(Framework())
- miopen_tables.append(Model())
- miopen_tables.append(Machine(local_machine=True))
- miopen_tables.append(TensorTable())
-
- miopen_tables = add_conv_tables(miopen_tables)
- miopen_tables = add_fusion_tables(miopen_tables)
- miopen_tables = add_bn_tables(miopen_tables)
-
- miopen_tables.append(ConvolutionBenchmark())
- miopen_tables.append(BNBenchmark())
-
- return miopen_tables
diff --git a/tuna/miopen/db/mixin_tables.py b/tuna/miopen/db/mixin_tables.py
new file mode 100644
index 000000000..b2bfdc1fe
--- /dev/null
+++ b/tuna/miopen/db/mixin_tables.py
@@ -0,0 +1,171 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2022 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Represents Mixin type table class definitions """
+import enum
+from sqlalchemy.sql import func as sqla_func
+from sqlalchemy.databases import mysql
+from sqlalchemy import Float, Boolean
+from sqlalchemy.dialects.mysql import TINYINT, MEDIUMBLOB, LONGBLOB
+from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy import Column, Integer, String, ForeignKey, DateTime
+
+from tuna.db.tuna_tables import JobMixin
+
+#pylint: disable=too-few-public-methods
+
+
+class FinStep(enum.Enum):
+ """ Allowed Fin Steps """
+ # pylint: disable=invalid-name ; tuna/go_fish.py names valid fin steps as FinStep.__members__
+ find_compile = 1
+ find_eval = 2
+ get_solvers = 3
+ get_applicability = 4
+ not_fin = 5
+ miopen_find_compile = 6
+ miopen_find_eval = 7
+ miopen_perf_compile = 8
+ miopen_perf_eval = 9
+
+
+class MIOpenJobMixin(JobMixin):
+ """Represents MIOpen Mixin class for job tables"""
+
+ compile_start = Column(DateTime,
+ nullable=False,
+ server_default=sqla_func.now())
+ compile_end = Column(DateTime, nullable=False, server_default=sqla_func.now())
+ eval_start = Column(DateTime, nullable=False, server_default=sqla_func.now())
+ eval_end = Column(DateTime, nullable=False, server_default=sqla_func.now())
+
+ solver = Column(String(length=128), nullable=True, server_default="")
+ eval_mid = Column(Integer, server_default="-1")
+ fin_step = Column(mysql.MSSet(*(list(k for k in FinStep.__members__))),
+ nullable=False,
+ server_default="not_fin")
+
+
+class ConfigTagMixin():
+ """Mixin class for config tags tables"""
+
+ tag = Column(String(length=128), nullable=False, server_default="no_tag")
+ recurrent = Column(TINYINT(1), nullable=False, server_default="0")
+
+
+class SolverApplicabilityMixin():
+ """Represents Mixin class for solver_applicability tables"""
+
+ @declared_attr
+ def solver(self):
+ """solver column"""
+ return Column(Integer, ForeignKey("solver.id"), nullable=False, index=True)
+
+ @declared_attr
+ def session(self):
+ """session key"""
+ return Column(Integer, ForeignKey("session.id"), nullable=False, index=True)
+
+ applicable = Column(TINYINT, nullable=False, server_default="1")
+
+
+class CacheMixin():
+ """Represents job_cache table"""
+
+ kernel_blob = Column(LONGBLOB, nullable=False)
+ cache_name = Column(String(length=45), nullable=False)
+
+
+class KernelCacheMixin():
+ """Represents Mixin for KernelCache table"""
+
+ kernel_name = Column(String(length=4000), nullable=False)
+ kernel_args = Column(String(length=9000), nullable=False)
+ kernel_blob = Column(MEDIUMBLOB, nullable=False)
+ kernel_hash = Column(String(length=128), nullable=False)
+ uncompressed_size = Column(Integer, nullable=False)
+
+
+class GoldenMixin():
+ """Mixin for golden table"""
+
+ @declared_attr
+ def session(self):
+ """session foreign key"""
+ return Column(Integer, ForeignKey("session.id"), nullable=False)
+
+ @declared_attr
+ def solver(self):
+ """solver foreign key"""
+ return Column(Integer,
+ ForeignKey("solver.id",
+ onupdate="CASCADE",
+ ondelete="CASCADE"),
+ nullable=False)
+
+ golden_miopen_v = Column(Integer, nullable=False)
+ arch = Column(String(length=20), nullable=False, server_default="")
+ num_cu = Column(Integer, nullable=False, server_default="0")
+
+
+class BenchmarkMixin():
+ """Mixin class for bechmark tables"""
+
+ @declared_attr
+ def framework(self):
+ """Framework Fkey"""
+ return Column(Integer, ForeignKey("framework.id"), nullable=False)
+
+ @declared_attr
+ def model(self):
+ """Model Fkey"""
+ return Column(Integer, ForeignKey("model.id"), nullable=False)
+
+ batchsize = Column(Integer, nullable=False, server_default="32")
+ gpu_number = Column(Integer, nullable=True, server_default="1")
+ driver_cmd = Column(String(length=512), nullable=False)
+
+
+class SolverAnalyticsMixin():
+ """common columns in aggregated & detailed solver analytics tables"""
+ # software state description
+ golden_miopen_v = Column(Integer, nullable=False)
+ # hardware description
+ arch = Column(String(20), nullable=False)
+ num_cu = Column(Integer, nullable=False)
+ opencl = Column(Boolean, nullable=False)
+ # convolution problem description
+ filter = Column(String(32), nullable=False)
+ padding = Column(String(32), nullable=False)
+ stride = Column(String(32), nullable=False)
+ dilation = Column(String(32), nullable=False)
+ layout = Column(String(8), nullable=False)
+ precision = Column(String(8), nullable=False)
+ direction = Column(String(1), nullable=False)
+ # fastest solvers stats
+ sf = Column(String(128), nullable=False) # fastest solver name
+ tf = Column(Float, nullable=False) # fastest solver runtime
+ count = Column(Integer, nullable=False) # fastest solver count
diff --git a/tuna/miopen/db/session.py b/tuna/miopen/db/session.py
index cc4d961c6..4f1f20e8b 100644
--- a/tuna/miopen/db/session.py
+++ b/tuna/miopen/db/session.py
@@ -30,6 +30,7 @@
from tuna.dbBase.base_class import BASE
from tuna.utils.logger import setup_logger
from tuna.db.session_mixin import SessionMixin
+from tuna.miopen.worker.fin_class import FinClass
LOGGER = setup_logger('session_miopen')
@@ -64,7 +65,7 @@ def get_query(self, sess, sess_obj, entry):
return query
- def add_new_session(self, args, worker):
+ def add_new_session(self, args, worker: FinClass):
"""Add new session entry"""
super().add_new_session(args, worker)
diff --git a/tuna/miopen/db/solver.py b/tuna/miopen/db/solver.py
new file mode 100644
index 000000000..0b20b2252
--- /dev/null
+++ b/tuna/miopen/db/solver.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2022 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+""" Module for defining Solver and model enums """
+
+from sqlalchemy import Column, String, UniqueConstraint
+from sqlalchemy import Enum
+from sqlalchemy.dialects.mysql import TINYINT
+from tuna.dbBase.base_class import BASE
+from tuna.dbBase.sql_alchemy import DbSession
+from tuna.miopen.utils.config_type import ConfigType
+from tuna.utils.db_utility import session_retry
+from tuna.utils.logger import setup_logger
+
+LOGGER = setup_logger('miopen_db_utility')
+
+
+#pylint: disable=too-few-public-methods
+class Solver(BASE):
+ """Represents solver table"""
+ __tablename__ = "solver"
+ __table_args__ = (UniqueConstraint("solver", name="uq_idx"),)
+
+ solver = Column(String(length=128), unique=True, nullable=False)
+ tunable = Column(TINYINT(1), nullable=False, server_default="1")
+ config_type = Column(Enum(ConfigType),
+ nullable=False,
+ server_default="convolution")
+ is_dynamic = Column(TINYINT(1), nullable=False, server_default="0")
+
+
+def get_id_solvers():
+ """DB solver id to name map"""
+ solver_id_map_c = {}
+ solver_id_map_h = {}
+ with DbSession() as session:
+ query = session.query(Solver.solver, Solver.id).filter(Solver.valid == 1)
+ res = session_retry(session, query.all, lambda x: x(), LOGGER)
+ for slv, sid in res:
+ solver_id_map_c[slv] = sid
+ solver_id_map_h[slv.replace(', ', '-')] = sid
+ id_solver_map_c = {val: key for key, val in solver_id_map_c.items()}
+ id_solver_map_h = {val: key for key, val in solver_id_map_h.items()}
+
+ return id_solver_map_c, id_solver_map_h
+
+
+def get_solver_ids():
+ """DB solver name to id map"""
+ # TODO: Get this info from the SQLAlchemy class # pylint: disable=fixme
+ solver_id_map = {}
+ with DbSession() as session:
+ query = session.query(Solver.solver, Solver.id).filter(Solver.valid == 1)
+ res = session_retry(session, query.all, lambda x: x(), LOGGER)
+ for slv, sid in res:
+ solver_id_map[slv] = sid
+ solver_id_map[slv.replace(', ', '-')] = sid
+
+ return solver_id_map
diff --git a/tuna/miopen/db/tables.py b/tuna/miopen/db/tables.py
index 579e23881..a6b77b5ce 100644
--- a/tuna/miopen/db/tables.py
+++ b/tuna/miopen/db/tables.py
@@ -25,20 +25,27 @@
#
###############################################################################
"""Module that encapsulates the DB representation based on configuration type"""
+
+from tuna.miopen.db.batch_norm_tables import BNBenchmark, BNConfig
+from tuna.miopen.db.batch_norm_tables import BNConfigTags, BNFinJobCache
+from tuna.miopen.db.batch_norm_tables import BNJob, BNJobCache, BNKernelCache
+from tuna.miopen.db.batch_norm_tables import BNSolverApplicability
from tuna.miopen.db.find_db import ConvolutionFindDB, BNFindDB
-from tuna.miopen.db.miopen_tables import ConvolutionJob, ConvolutionConfig, ConvolutionConfigTags
-from tuna.miopen.db.miopen_tables import ConvJobCache, Solver
-from tuna.miopen.db.miopen_tables import BNJob, BNConfig, BNJobCache, BNFinJobCache, BNConfigTags
-from tuna.miopen.db.miopen_tables import ConvSolverApplicability, BNSolverApplicability
-from tuna.miopen.db.miopen_tables import ConvFinJobCache, BNKernelCache, ConvolutionKernelCache
-from tuna.miopen.db.miopen_tables import TensorTable, ConvolutionGolden, ConvolutionBenchmark
-from tuna.miopen.db.miopen_tables import BNBenchmark
-from tuna.miopen.db.miopen_tables import ConvSolverAnalyticsAggregated, ConvSolverAnalyticsDetailed
+from tuna.miopen.db.convolutionjob_tables import ConvolutionJob
+from tuna.miopen.db.convolutionjob_tables import ConvolutionConfig
+from tuna.miopen.db.convolutionjob_tables import ConvolutionConfigTags
+from tuna.miopen.db.convolutionjob_tables import ConvJobCache
+from tuna.miopen.db.convolutionjob_tables import ConvSolverApplicability
+from tuna.miopen.db.convolutionjob_tables import ConvolutionGolden, ConvolutionBenchmark
+from tuna.miopen.db.convolutionjob_tables import ConvFinJobCache, ConvolutionKernelCache
+from tuna.miopen.db.convolutionjob_tables import ConvSolverAnalyticsAggregated
+from tuna.miopen.db.convolutionjob_tables import ConvSolverAnalyticsDetailed
+from tuna.miopen.db.solver import Solver
+from tuna.miopen.db.tensortable import TensorTable
from tuna.miopen.db.benchmark import Framework, Model
-from tuna.miopen.db.session import Session
-
from tuna.miopen.utils.config_type import ConfigType
from tuna.tables_interface import DBTablesInterface
+from tuna.miopen.db.session import Session
#pylint: disable=too-many-instance-attributes
diff --git a/tuna/miopen/db/tensortable.py b/tuna/miopen/db/tensortable.py
new file mode 100644
index 000000000..7ba942a7f
--- /dev/null
+++ b/tuna/miopen/db/tensortable.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2022 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+""" Module for defining Tensor Table and model enums """
+
+from sqlalchemy import Column, Integer, String, UniqueConstraint
+from tuna.dbBase.base_class import BASE
+
+#pylint: disable=too-few-public-methods
+
+
+class TensorTable(BASE):
+ """Represents tensor table"""
+ __tablename__ = "tensor"
+ __table_args__ = (UniqueConstraint("dim0",
+ "dim1",
+ "dim2",
+ "dim3",
+ "dim4",
+ "layout",
+ "num_dims",
+ "data_type",
+ name="uq_idx"),)
+
+ dim0 = Column(Integer, nullable=False, server_default="0")
+ dim1 = Column(Integer, nullable=False, server_default="0")
+ dim2 = Column(Integer, nullable=False, server_default="0")
+ dim3 = Column(Integer, nullable=False, server_default="0")
+ dim4 = Column(Integer, nullable=False, server_default="0")
+ layout = Column(String(60), nullable=False, server_default="NCHW")
+ num_dims = Column(Integer, nullable=False, server_default="2")
+ data_type = Column(String(60), nullable=False, server_default="FP32")
diff --git a/tuna/miopen/doc/FinDocs.md b/tuna/miopen/doc/FinDocs.md
new file mode 100644
index 000000000..884647178
--- /dev/null
+++ b/tuna/miopen/doc/FinDocs.md
@@ -0,0 +1,21 @@
+MIFin Documentation
+===================
+
+How to run MIFin inside a docker!
+---------------------------------
+
+To run MIFin steps in MITuna, a new Docker image needs to be built that contains MIFin and MIOpen,
+with a clone of MITuna added afterwards.
+Steps:
+
+.. code-block::
+ :caption: Navigate to a clone of MITuna and run:
+
+ docker build -f Dockerfile -t my_docker_name --build-arg 'BACKEND=HIPNOGPU' .
+ drun --network host my_docker_name bash
+ cd
+ git clone https://github.com/ROCm/MITuna.git
+ cd MITuna/tuna
+ ./go_fish.py miopen --update_solver
+ ./go_fish.py miopen --init_session -l someReason -a gfx908 -n 120
+ ./go_fish.py miopen --update_applicability --session_id 1
diff --git a/tuna/miopen/doc/Tuning.md b/tuna/miopen/doc/Tuning.md
new file mode 100644
index 000000000..1c2831610
--- /dev/null
+++ b/tuna/miopen/doc/Tuning.md
@@ -0,0 +1,250 @@
+Tuning through MIOpen
+=====================
+
+As a high-performance kernels library, MIOpen needs a substantial tuning effort to discover the
+optimal tuning parameters. Kernel tuning entails compiling and running MIOpen kernels with different
+tuning parameters to determine the best-performing parameters for each kernel. While MIOpen
+contains much of the logic needed to iterate over possible tuning parameters, it only applies
+to a single machine. Therefore, a mechanism is required to parallelize this process across different
+machines, as well as across multiple GPUs, to speed up this inherently parallel procedure. Among other
+features, such a framework needs to be able to handle errors in both MIOpen and the stack on which
+MIOpen depends.
+
+Tuna is the MIOpen team's library, which parallelizes the tuning procedure across multiple GPUs on
+multiple machines. In addition to distributing jobs to servers, it keeps track of the various
+architectures, whether a tuning effort was successful or resulted in an error, and other housekeeping.
+This makes it a useful automation tool. Tuna is also the custodian of the convolution layer parameters
+of interest to the MIOpen team, received from customers, as well as various benchmarks. With the
+introduction of the 'find database' for immediate mode, Tuna is also responsible for generating the
+find database as well as the upcoming precompiled kernels package.
+
+Tuna offloads job scheduling to [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html).
+The backend and broker for Celery are both implemented through Redis. Tuning jobs are enqueued in a
+Redis queue, then launched and executed by one or more Celery workers depending on the operation:
+compile or eval.
+
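+As an illustration only, the sketch below shows a minimal Celery application configured with Redis
+as both broker and result backend. It is not Tuna's actual `celery_app` module; the URLs, queue and
+task names are placeholders.
+
+.. code-block::
+    :caption: Minimal Celery-with-Redis sketch (illustrative, not Tuna's configuration)
+
+    from celery import Celery
+
+    # Placeholder broker/backend URLs; a real deployment points these at the
+    # Redis instance used by the tuning session.
+    app = Celery('tuna_sketch',
+                 broker='redis://localhost:6379/14',
+                 backend='redis://localhost:6379/15')
+
+    @app.task(name='tuna_sketch.compile_job')
+    def compile_job(job_dict):
+      """Compile-side work for a single tuning job (stub)."""
+      # A real worker would drive MIFin/MIOpen here and return its results.
+      return {'job_id': job_dict['id'], 'status': 'compiled'}
+
+    # Enqueue from the driver side, routed to a dedicated compile queue:
+    # compile_job.apply_async(args=[{'id': 1}], queue='compile_q')
+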
+When do we tune
+---------------
+
+There are two occasions that trigger tuning:
+1. Someone opens a GitHub issue that contains the configurations and network to be tuned.
+This implies we only need to tune the network specified in the issue along with the
+configurations specified. If the person requesting this did not mention any configurations,
+please ask for them. The Tuna team does not provide these.
+2. Recurrent configurations need retuning when the internals of MIOpen/Tuna change. The tuning
+phase of all the recurrent configurations takes up to a few days. There are many configurations
+used for each network, and one should use as many machines as possible to speed up
+the tuning.
+
+MIOpen Tuning Steps
+-------------------
+
+Tuning stores final data in a central MySQL database (DB). Each reference to a table
+refers to a table in this database. Table mixins can be found in
+[tuna_tables.py](tuna/db/tuna_tables.py) and [session_mixin.py](tuna/db/session_mixin.py). For
+MIOpen, the concrete implementations of these mixins can be found in
+[miopen/db](tuna/miopen/db).
+Intermediate tuning data generated by the Celery workers is stored in a Redis DB. The MITuna enqueue
+call then drains this DB and populates the MySQL DB with the final results.
+
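+As an illustration of that hand-off, the sketch below pops serialized worker results from a Redis
+list and counts them. The DB number, key name and record layout are placeholders rather than
+Tuna's actual conventions; a real drain step writes each record to the MySQL tables through a
+SQLAlchemy session.
+
+.. code-block::
+    :caption: Illustrative Redis drain sketch (placeholder names)
+
+    import json
+    import redis
+
+    # Placeholder Redis location and key; not Tuna's production settings.
+    redis_db = redis.Redis(host='localhost', port=6379, db=15)
+
+    results = []
+    while (raw := redis_db.lpop('tuna_sketch_results')) is not None:
+      results.append(json.loads(raw))
+
+    # A real implementation would map each record onto the session's job and
+    # find-db tables and commit them in a single SQLAlchemy transaction.
+    print(f"drained {len(results)} intermediate results")
+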
+Tuning is divided into multiple steps, and each step builds on top of the previous ones.
+To start a tuning session, some prerequisites have to be satisfied: setting up configurations,
+getting the latest solvers and their associated applicability from MIOpen,
+and adding the jobs that compose the tuning session. The environment variables defined in
+the README must also be set in each tuning terminal.
+Once these prerequisites are established, the tuning session can begin. Each step,
+including the prerequisites, is detailed below.
+
+**Add Network Configurations (1)**
+
+Before a configuration gets tagged, a model and a framework need to be added. This allows a
+given model to be benchmarked after tuning.
+
+.. code-block::
+
+ ./go_fish.py --add_model Resnet50 --md_version 1
+ ./go_fish.py --add_framework Pytorch --fw_version 1
+ --add_model - model name
+ --md_version - model version
+ --add_framework - framework name
+ --fw_version - framework version
+
+The config table contains network configurations. If provided with a text file of MIOpenDriver
+commands, the import script can translate those commands and populate the config table.
+Additionally, the user may provide a name to tag a configuration for easier recall later.
+A tag will be required when adding a tuning job. Tags are stored in the config_tags table.
+A model and framework name and version are also required. This enables MITuna to track
+benchmark performance post-tuning.
+
+.. code-block::
+
+ ./go_fish.py miopen import_configs --add_model Resnet50 --md_version 1
+ ./go_fish.py miopen import_configs --add_framework Pytorch --fw_version 1
+ ./go_fish.py miopen import_configs -t resnet50 -f ../utils/recurrent_cfgs/resnet50.txt
+ --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1
+ -t - tag
+ -f - filepath
+ --model - model name
+ --md_version - model version
+ --framework - framework name
+ --fw_version - framework version
+
+**Add Solvers (2)**
+
+The solver table contains MIOpen solvers and solver characteristics.
+This should be updated when an MIOpen version modifies solvers.
+
+.. code-block::
+
+ ./go_fish.py miopen --update_solvers
+
+**Add Tuning Session (3)**
+
+The session tracks the architecture and SKU, as well as the MIOpen version and
+ROCm version for the tuning session.
+
+This command needs to be run from inside the tuning environment, e.g. the MITuna docker,
+and will populate the table with the version and architecture information.
+
+.. code-block::
+ :caption: [Use backend=HIPNOGPU docker]
+
+ ./go_fish.py miopen --init_session -l reason
+ --init_session - create a session entry
+ -l - reference text description
+
+**Add Applicability (4)**
+
+Each network configuration has a set of applicable solvers. This step will update the
+solver_applicability table with applicable solvers for each configuration for the session.
+
+.. code-block::
+ :caption: [Use backend=HIPNOGPU docker]
+
+ ./go_fish.py miopen --update_applicability --session_id 1
+ --session_id - tuning session id
+
+**Load Jobs (5)**
+
+Next, create the jobs for the tuning session. Specify the session id, the configs that
+should be tuned, and the fin_step to be executed. Configs can be selected by using the tag from
+the config_tags table. Jobs should have a compile and an eval MIFin step pair.
+
+Fin steps include: miopen_perf_compile, miopen_perf_eval, miopen_find_compile, and miopen_find_eval.
+
+.. code-block::
+
+ ./load_job.py --session_id 1 -t resnet50 --fin_steps miopen_perf_compile,miopen_perf_eval -o -l reason
+ --session_id - tuning session id
+ -t - config tag
+ --fin_steps - operations to be performed by MIFin (tuning handle into miopen)
+ -o - only_applicable, will create a job for each applicable solver
+ -l - reference text description
+
+**Compile Step (6)**
+
+Once the prerequisites are set, tuning can begin. To compile the jobs,
+supply the session id along with the compile fin_step matching the one in the job table.
+This step is launched in two different terminals: the job-enqueue terminal and the job-execution
+terminal.
+
+To enqueue the jobs, run the following on any node:
+
+.. code-block::
+ :caption: [Use backend=HIPNOGPU docker]
+
+ ./go_fish.py miopen --session_id 1 --fin_steps miopen_perf_compile --enqueue_only
+ --session_id - tuning session id
+ --fin_steps - execute this operation
+ --enqueue_only - enqueue the jobs to the redis queue
+
+To launch the jobs through Celery workers, run the following on the compile node:
+
+.. code-block::
+
+ ./go_fish.py miopen --session_id 1 --fin_steps miopen_perf_compile
+ --session_id - tuning session id
+ --fin_steps - execute this operation
+
+**Evaluation Step (7)**
+
+Once compilation has started, evaluation can also be launched.
+This command is similar to the previous one. It also consists of two steps: the job enqueue process
+and the job execution process, which launches Celery workers on the evaluation node (a worker
+launch sketch follows the enqueue command below).
+
+To enqueue the jobs, run the following on any node:
+
+.. code-block::
+ :caption: [Use backend=HIP docker]
+
+ ./go_fish.py miopen --session_id 1 --fin_steps miopen_perf_eval --enqueue_only
+ --session_id - tuning session id
+ --fin_steps - execute this operation
+ --enqueue_only - enqueue the jobs to the redis queue
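+
+To launch the evaluation jobs through Celery workers, run a command analogous to the compile step
+on the evaluation node (a sketch by analogy with the compile step above; the flags are assumed to
+mirror it):
+
+.. code-block::
+
+ ./go_fish.py miopen --session_id 1 --fin_steps miopen_perf_eval
+ --session_id - tuning session id
+ --fin_steps - execute this operation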
+
+**Database Export (8)**
+
+To export the results, run the export_db.py script with options
+for selecting the session as well as the database type.
+
+The outputs of this script are database files in the format that MIOpen keeps and manages.
+For example, for MI100, -p will produce a gfx90878.db file, -f will produce gfx90878.HIP.fdb.txt, and -k will produce gfx90878.kdb.
+
+.. code-block::
+
+ ./export_db.py --session_id 1 -p
+ --session_id - tuning session id
+ -p - export performance db
+ -f - export find db
+ -k - export kernel db
+
+
+.. note::
+  A Celery worker can also be launched manually. It requires a few extra environment variables.
+  Launch the enqueue step in a terminal, then launch the Celery worker manually in a separate
+  terminal, for example:
+
+.. code-block::
+
+ export CELERY_BROKER_URL=redis://:6379/14
+ export CELERY_RESULT_BACKEND=redis://:6379/15
+ cd MITuna
+ celery -A tuna.celery_app.celery_app worker -l info --logfile= -n -Q
+
+Sample manual launch for the eval step:
+
+.. code-block::
+
+ export CELERY_BROKER_URL=amqp://@:6379/14
+ export CELERY_RESULT_BACKEND=redis://:6379/15
+ cd MITuna
+
+ celery -A tuna.celery_app.celery_app worker -l info --logfile=_gpu_id_ -n _gpu_id_ -Q -c 1
+
+Launching the worker manually can help with debugging the parts of the code used by the Celery
+worker, such as the decorated Celery task (@app.task).
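+
+As a rough illustration only (not the actual Tuna task), a Celery task is a plain function
+decorated with the application's task decorator; the worker executes this function for each
+enqueued job:
+
+.. code-block::
+
+ from celery import Celery
+
+ # broker URL is a placeholder for illustration
+ app = Celery('example', broker='redis://localhost:6379/14')
+
+ @app.task
+ def example_task(context):
+     # perform the compile/eval work described by 'context' and return a result
+     return {'result': context}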
+
+
+MIOpen Golden Database
+----------------------
+
+Tuna's MySQL database tracks versioned data for MIOpen.
+These versions are kept in the golden table. A golden MIOpen version holds
+the complete tuning history at each step.
+
+Adding to the Golden Table
+--------------------------
+
+Once a tuning session is approved, the results in the generated find db
+may be used to populate the golden table.
+
+.. code-block::
+
+ ./update_golden.py --session_id \ --golden_v \ --base_golden_v \
+ --golden_v - create this golden version
+ --base_golden_v - initialize the new golden version with this previous golden data
+ --session_id - id of the tuning session to populate the golden version
+ --overwrite - may be used to force writing to existing golden_v
+
+If there is no previous golden version, `--base_golden_v` need not be specified.
+Otherwise, writing a new golden version requires `--base_golden_v`.
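+
+For example, to create a hypothetical golden version 2 from the results of tuning session 1,
+seeded with the previous golden version 1 (all three numbers here are illustrative):
+
+.. code-block::
+
+ ./update_golden.py --session_id 1 --golden_v 2 --base_golden_v 1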
+
diff --git a/tuna/miopen/driver/base.py b/tuna/miopen/driver/base.py
index c1ecf50f0..2bed7e6ef 100755
--- a/tuna/miopen/driver/base.py
+++ b/tuna/miopen/driver/base.py
@@ -32,74 +32,70 @@
from sqlalchemy.orm import Session
from sqlalchemy.orm import Query
from sqlalchemy.inspection import inspect
-from tuna.dbBase.sql_alchemy import DbSession
from tuna.utils.logger import setup_logger
+from tuna.dbBase.sql_alchemy import DbSession
from tuna.utils.db_utility import build_dict_val_key, get_session_val_map
-from tuna.miopen.db.miopen_tables import TensorTable
-from tuna.miopen.db.miopen_tables import ConvolutionConfig
+from tuna.miopen.db.tensortable import TensorTable
+from tuna.miopen.db.convolutionjob_tables import ConvolutionConfig
from tuna.miopen.utils.metadata import TENSOR_PRECISION
from tuna.miopen.utils.parsing import parse_line
+from tuna.driver import DriverBase
-LOGGER = setup_logger('driver_base')
+LOGGER = setup_logger('MIOpenDriver_driver_base')
-class DriverBase():
+# pylint: disable=too-many-instance-attributes
+# pylint: disable=too-many-public-methods
+class MIOpenDriver(DriverBase):
"""Represents db tables based on ConfigType"""
tensor_attr: List[str] = [column.name for column in inspect(TensorTable).c]
tensor_id_map: Dict[str, int] = {}
- def __init__(self,
- line: str = str(),
- db_obj: ConvolutionConfig = None) -> None:
- if line:
- if not self.__construct_driver(line):
- raise ValueError(f"Error creating Driver from line: '{line}'")
- elif db_obj:
- if not self.__construct_driver_from_db(db_obj):
- raise ValueError(
- f"Error creating Driver from db obj: '{db_obj.to_dict()}'")
- else:
- raise ValueError(
- "Error creating Driver. MIOpen Driver cmd line or db_obj required")
+ def __init__(self, line: str = str(), db_obj: ConvolutionConfig = None):
+ super().__init__(line, db_obj)
- def parse_fdb_key(self, line: str):
- """Overloaded method.Defined in conv&bn driver child class"""
+ @abstractmethod
+ def config_set_defaults(self) -> None:
+ """Setting config DB defaults to avoid duplicates through SELECT"""
raise NotImplementedError("Not implemented")
- def parse_row(self, db_obj: ConvolutionConfig):
- """Overloaded method.Defined in conv&bn driver child class"""
+ @abstractmethod
+ def set_cmd(self, data_type: str) -> None:
+ """Set cmd based on tensor data type"""
raise NotImplementedError("Not implemented")
- def set_cmd(self, data_type: str):
- """Overloaded method.Defined in conv&bn driver child class"""
+ @abstractmethod
+ def compose_weight_t(self) -> dict:
+ """Build weight_tensor"""
raise NotImplementedError("Not implemented")
- def config_set_defaults(self):
- """Overloaded method.Defined in conv&bn driver child class"""
+ @abstractmethod
+ def parse_row(self, db_obj: ConvolutionConfig):
+ """Abstract/Inference for Overwritting base class function for batch_norm"""
raise NotImplementedError("Not implemented")
+ @abstractmethod
def get_layouts(self):
"""Return operation layouts"""
raise NotImplementedError("Not implemented")
- @abstractmethod
- def compose_weight_t(self):
- """Overloaded method.Defined in conv&br driver child class"""
+ def parse_fdb_key(self, line: str) -> None:
+ """Import config attributes from fdb key line"""
raise NotImplementedError("Not implemented")
@staticmethod
def test_skip_arg(tok1: str):
- """Overloaded method.Defined in conv&br driver child class"""
+ """Check if token is skipable"""
raise NotImplementedError("Not implemented")
@staticmethod
def get_params(tok1: str):
- """Overloaded method.Defined in conv&br driver child class"""
+ """Get full arg name"""
raise NotImplementedError("Not implemented")
@staticmethod
def get_check_valid(tok1: str, tok2: Union[str, int]):
- """Overloaded method.Defined in conv&br driver child class"""
+ """Check if valid conv arg"""
raise NotImplementedError("Not implemented")
@staticmethod
@@ -107,8 +103,8 @@ def get_common_cols() -> List[str]:
"""Returns common MIOpenDriver command line args"""
return ['wall', 'time', 'iter', 'verify']
- def __construct_driver(self, line: str) -> bool:
- """Takes a MIOpenDriver cmd or PDB key"""
+ def construct_driver(self, line: str) -> bool:
+ """Takes MIOpen line description of a configuration"""
LOGGER.info('Processing line: %s', line)
if line.find('=') != -1:
@@ -122,7 +118,7 @@ def __construct_driver(self, line: str) -> bool:
self.config_set_defaults()
return True
- def __construct_driver_from_db(self, db_obj: Any) -> bool:
+ def construct_driver_from_db(self, db_obj: Any) -> bool:
"""Takes a <>_config row and returns a driver cmd"""
LOGGER.info('Processing db_row: %s', db_obj.to_dict())
#common tensor among convolution and batch norm
@@ -171,10 +167,10 @@ def __insert_tensor(self, tensor_dict: dict) -> int:
tid.valid = 1
key = build_dict_val_key(tid)
#cache the tensor table to avoid queries
- if not DriverBase.tensor_id_map:
- DriverBase.tensor_id_map = get_session_val_map(
- session, TensorTable, DriverBase.tensor_attr)
- id_map = DriverBase.tensor_id_map
+ if not MIOpenDriver.tensor_id_map:
+ MIOpenDriver.tensor_id_map = get_session_val_map(
+ session, TensorTable, MIOpenDriver.tensor_attr)
+ id_map = MIOpenDriver.tensor_id_map
if key in id_map:
ret_id = id_map[key]
LOGGER.info("Get Tensor: %s", ret_id)
@@ -188,8 +184,8 @@ def __insert_tensor(self, tensor_dict: dict) -> int:
LOGGER.warning(err)
session.rollback()
#update tensor table cache
- DriverBase.tensor_id_map = get_session_val_map(session, TensorTable,
- DriverBase.tensor_attr)
+ MIOpenDriver.tensor_id_map = get_session_val_map(
+ session, TensorTable, MIOpenDriver.tensor_attr)
ret_id = self.get_tensor_id(session, tensor_dict)
LOGGER.info("Get Tensor: %s", ret_id)
return ret_id
@@ -218,7 +214,7 @@ def __compose_input_t(self) -> Dict[str, int]:
i_dict['dim3'] = self.in_h
i_dict['dim4'] = self.in_w
i_dict['layout'] = self.in_layout
- elif self.in_layout == 'NHWC':
+ elif self.in_layout in ('NHWC', 'NDHWC'):
i_dict['dim1'] = self.in_d
i_dict['dim2'] = self.in_h
i_dict['dim3'] = self.in_w
@@ -235,12 +231,12 @@ def __decompose_input_t(self, db_obj: Any) -> bool:
self.num_dims = db_obj.input_t.num_dims
self.in_layout = db_obj.input_t.layout
- if self.in_layout == 'NCHW':
+ if self.in_layout in ('NCHW', 'NCDHW'):
self.in_channels = db_obj.input_t.dim1
self.in_d = db_obj.input_t.dim2
self.in_h = db_obj.input_t.dim3
self.in_w = db_obj.input_t.dim4
- elif self.in_layout == 'NHWC':
+ elif self.in_layout in ('NHWC', 'NDHWC'):
self.in_d = db_obj.input_t.dim1
self.in_h = db_obj.input_t.dim2
self.in_w = db_obj.input_t.dim3
diff --git a/tuna/miopen/driver/batchnorm.py b/tuna/miopen/driver/batchnorm.py
index 40a304d75..4accb68e7 100755
--- a/tuna/miopen/driver/batchnorm.py
+++ b/tuna/miopen/driver/batchnorm.py
@@ -27,11 +27,11 @@
"""Module that encapsulates the DB representation of a batch_normDriver cmd"""
from tuna.utils.logger import setup_logger
-from tuna.miopen.driver.base import DriverBase
+from tuna.miopen.driver.base import MIOpenDriver
from tuna.miopen.utils.metadata import BN_CONFIG_COLS, IN_TENSOR_COLS, PREC_TO_CMD
from tuna.miopen.utils.metadata import SUPPORTED_BN_CMDS, TABLE_COLS_BN_MAP, BN_DEFAULTS
from tuna.miopen.utils.metadata import DIRECTION, DIR_MAP, BN_SKIP_ARGS
-from tuna.miopen.db.miopen_tables import BNConfig
+from tuna.miopen.db.batch_norm_tables import BNConfig
from tuna.miopen.utils.parsing import get_fd_name, arg_valid
from tuna.miopen.utils.helper import get_db_id
from tuna.miopen.utils.config_type import ConfigType
@@ -40,7 +40,7 @@
#pylint: disable=too-many-instance-attributes
-class DriverBatchNorm(DriverBase):
+class DriverBatchNorm(MIOpenDriver):
"""Represents db tables based on ConfigType"""
def __init__(self,
@@ -87,10 +87,6 @@ def parse_driver_line(self, line: str) -> None:
super().parse_driver_line(line)
self.compute_direction()
- def parse_fdb_key(self, line):
- """ Overidden Method"""
- raise NotImplementedError("Not implemented")
-
def compose_weight_t(self):
""" Overridden Method """
raise NotImplementedError("Not implemented")
diff --git a/tuna/miopen/driver/convolution.py b/tuna/miopen/driver/convolution.py
index 88f2327c4..1ac429173 100755
--- a/tuna/miopen/driver/convolution.py
+++ b/tuna/miopen/driver/convolution.py
@@ -29,14 +29,15 @@
from typing import Dict, Set, Optional, Any
from re import search
from tuna.utils.logger import setup_logger
-from tuna.miopen.driver.base import DriverBase
+from tuna.miopen.driver.base import MIOpenDriver
from tuna.miopen.utils.metadata import CONV_CONFIG_COLS
from tuna.miopen.utils.helper import get_db_id
-from tuna.miopen.db.miopen_tables import ConvolutionConfig
+from tuna.miopen.db.convolutionjob_tables import ConvolutionConfig
from tuna.miopen.utils.metadata import CONV_2D_DEFAULTS, SUPPORTED_CONV_CMDS, PREC_TO_CMD
from tuna.miopen.utils.metadata import CONV_3D_DEFAULTS, TENSOR_COLS
-from tuna.miopen.utils.metadata import TABLE_COLS_CONV_MAP, TENSOR_PRECISION
-from tuna.miopen.utils.metadata import DIRECTION, DIR_MAP, CONV_SKIP_ARGS, INVERS_DIR_MAP
+from tuna.miopen.utils.metadata import TABLE_COLS_CONV_MAP, TENSOR_PRECISION, DIR_MAP
+from tuna.miopen.utils.metadata import DIRECTION, CONV_SKIP_ARGS, INVERS_DIR_MAP
+from tuna.miopen.utils.metadata import SUPPORTED_LAYOUTS
from tuna.miopen.utils.parsing import get_fd_name, conv_arg_valid, get_fds_from_cmd
from tuna.miopen.utils.config_type import ConfigType
@@ -44,7 +45,7 @@
#pylint: disable=too-many-instance-attributes
-class DriverConvolution(DriverBase):
+class DriverConvolution(MIOpenDriver):
"""Represents an MIOpenDriver convolution command"""
def __init__(self,
@@ -79,9 +80,9 @@ def __init__(self,
self.trans_output_pad_h: int = 0
self.trans_output_pad_w: int = 0
self.trans_output_pad_d: int = 0
- self.out_layout: str = 'NCHW'
- self.in_layout: str = 'NCHW'
- self.fil_layout: str = 'NCHW'
+ self.out_layout: str = None # type: ignore #use config_set_defaults to pull from 2D/3D defaults
+ self.in_layout: str = None # type: ignore
+ self.fil_layout: str = None # type: ignore
self.in_d: int = 1
self.in_h: int = 32
self.in_w: int = 32
@@ -109,6 +110,11 @@ def __init__(self,
if self.in_layout != self.out_layout != self.fil_layout:
raise ValueError(
'Layouts do not match: in_layout/out_layout/fil_layout must match.')
+ for layout in [self.in_layout, self.out_layout, self.fil_layout]:
+ if not layout in SUPPORTED_LAYOUTS:
+ raise ValueError(
+ f'Layout {layout} is not a supported layout: ({SUPPORTED_LAYOUTS}).'
+ )
@property
def cmd(self) -> str:
@@ -124,24 +130,27 @@ def cmd(self, value: str) -> None:
)
self._cmd = value
- def get_layouts(self):
- """Get convolution layouts"""
- return ["in_layout", "out_layout", 'fil_layout']
-
def parse_fdb_key(self, line: str) -> None:
- """import config attributes from fdb key line"""
- fds: str
+ """Import config attributes from fdb key line"""
+ fds: dict
direction: str
fds, _, direction = get_fds_from_cmd(line)
- setattr(self, 'direction', DIR_MAP[direction])
+ setattr(self, 'direction',
+ DIR_MAP.get(direction,
+ '')) # Use .get() to safely access the dictionary
+
for key in self.to_dict():
if key in fds:
setattr(self, key, fds[key])
- pattern_3d = '[0-9]x[0-9]x[0-9]'
+ pattern_3d = '[0-9]+x[0-9]+x[0-9]+'
if search(pattern_3d, line):
setattr(self, 'spatial_dim', 3)
+ def get_layouts(self):
+ """Get convolution layouts"""
+ return ["in_layout", "out_layout", 'fil_layout']
+
def parse_driver_line(self, line: str) -> None:
"""Parse MIOpenDriver line"""
super().parse_driver_line(line)
diff --git a/tuna/miopen/metadata.py b/tuna/miopen/metadata.py
index ee60f65bc..6f06c9215 100644
--- a/tuna/miopen/metadata.py
+++ b/tuna/miopen/metadata.py
@@ -31,6 +31,11 @@
'list_solvers', 'fin_steps', 'import_db', 'check_status', 'execute_cmd'
]
+MIOPEN_CELERY_STEPS = [
+ "miopen_find_compile", "miopen_find_eval", "miopen_perf_compile",
+ "miopen_perf_eval"
+]
+
MIOPEN_SUBCOMMANDS = [
'import_configs', 'load_job', 'export_db', 'update_golden'
]
diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py
index 8aba2cbb3..db5c17bc7 100644
--- a/tuna/miopen/miopen_lib.py
+++ b/tuna/miopen/miopen_lib.py
@@ -27,38 +27,61 @@
"""MIOpen class that holds MIOpen specifig tuning functionality"""
import sys
-
+import copy
+from typing import List, Tuple, Any
+from functools import lru_cache
+from collections.abc import Iterable
+
+from kombu.utils.uuid import uuid
+from sqlalchemy.inspection import inspect
+from sqlalchemy.exc import OperationalError, DataError, IntegrityError
from tuna.mituna_interface import MITunaInterface
from tuna.miopen.utils.helper import print_solvers
from tuna.parse_args import TunaArgs, setup_arg_parser, args_check
-from tuna.miopen.db.miopen_tables import FinStep, get_miopen_tables
+
+from tuna.dbBase.sql_alchemy import DbSession
+from tuna.tables_interface import DBTablesInterface
+from tuna.utils.utility import SimpleDict, serialize_chunk
+from tuna.utils.machine_utility import load_machines
+from tuna.utils.db_utility import gen_select_objs, has_attr_set, get_class_by_tablename
+from tuna.miopen.db.get_db_tables import get_miopen_tables
+from tuna.miopen.db.mixin_tables import FinStep
from tuna.miopen.utils.metadata import MIOPEN_ALG_LIST
+from tuna.miopen.metadata import MIOPEN_CELERY_STEPS
from tuna.miopen.worker.fin_class import FinClass
-from tuna.miopen.worker.fin_builder import FinBuilder
-from tuna.miopen.worker.fin_eval import FinEvaluator
-from tuna.worker_interface import WorkerInterface
from tuna.miopen.db.session import Session
-from tuna.utils.miopen_utility import load_machines
-from tuna.libraries import Library
from tuna.miopen.subcmd.import_configs import run_import_configs
from tuna.miopen.subcmd.load_job import run_load_job
from tuna.miopen.subcmd.export_db import run_export_db
from tuna.miopen.subcmd.update_golden import run_update_golden
-from tuna.miopen.parse_miopen_args import get_import_cfg_parser
-from tuna.miopen.parse_miopen_args import get_load_job_parser
-from tuna.miopen.parse_miopen_args import get_export_db_parser
-from tuna.miopen.parse_miopen_args import get_update_golden_parser
+from tuna.miopen.parse_miopen_args import get_import_cfg_parser, get_load_job_parser
+from tuna.miopen.parse_miopen_args import get_export_db_parser, get_update_golden_parser
from tuna.miopen.db.build_schema import create_tables, recreate_triggers
from tuna.miopen.db.triggers import drop_miopen_triggers, get_miopen_triggers
from tuna.miopen.utils.config_type import ConfigType
+from tuna.miopen.db.tables import MIOpenDBTables
+#from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue
+from tuna.miopen.utils.json_to_sql import process_fdb_w_kernels, process_pdb_compile
+from tuna.miopen.utils.json_to_sql import clean_cache_table
+from tuna.miopen.utils.helper import set_job_state
+from tuna.miopen.worker.fin_utils import get_fin_result
+from tuna.miopen.db.solver import get_solver_ids
+from tuna.libraries import Library, Operation
+from tuna.custom_errors import CustomError
+
+MAX_ERRORED_JOB_RETRIES = 3
+Q_NAME = None
class MIOpen(MITunaInterface):
"""Class to support MIOpen specific tuning functionality"""
+ # pylint: disable=too-many-public-methods
+
def __init__(self):
super().__init__(library=Library.MIOPEN)
self.args = None
+ self.set_state = None
def parse_args(self):
# pylint: disable=too-many-statements
@@ -68,7 +91,8 @@ def parse_args(self):
TunaArgs.ARCH, TunaArgs.NUM_CU, TunaArgs.VERSION,
TunaArgs.CONFIG_TYPE, TunaArgs.SESSION_ID, TunaArgs.MACHINES,
TunaArgs.REMOTE_MACHINE, TunaArgs.LABEL, TunaArgs.RESTART_MACHINE,
- TunaArgs.DOCKER_NAME
+ TunaArgs.DOCKER_NAME, TunaArgs.SHUTDOWN_WORKERS,
+ TunaArgs.ENQUEUE_ONLY
])
parser.add_argument(
'--find_mode',
@@ -175,6 +199,9 @@ def parse_args(self):
self.args = parser.parse_args()
+ if self.args.config_type is None:
+ self.args.config_type = ConfigType.convolution
+
#overwritte common lib args with subcommand args value
if self.args.subcommand is not None:
self.overwrite_common_args()
@@ -185,10 +212,11 @@ def parse_args(self):
if self.args.list_solvers:
print_solvers()
- raise ValueError('Printing solvers...')
+ raise CustomError('Printing solvers...')
if self.args.fin_steps and self.args.subcommand != 'load_job':
self.check_fin_args(parser)
+ self.set_prefix()
if self.args.find_mode is None and not (self.args.check_status or
self.args.restart_machine or
@@ -211,6 +239,22 @@ def parse_args(self):
if (self.args.update_applicability or has_fin) and not self.args.session_id:
parser.error("session_id must be specified with this operation")
+ self.dbt = MIOpenDBTables(session_id=self.args.session_id,
+ config_type=self.args.config_type)
+ self.update_operation()
+
+ def set_prefix(self):
+ """Set redis key prefix"""
+ if isinstance(self.args.fin_steps, Iterable):
+ steps_str = ('-').join(x for x in self.args.fin_steps)
+ self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_"\
+ f"{steps_str}"
+ else:
+ steps_str = self.args.fin_steps[0]
+ self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}"
+
+ self.logger.info('redis prefix: %s', self.prefix)
+
def overwrite_common_args(self):
"""Overwrite common MIOpen_lib args with subcommand args"""
if self.args.subcommand is not None:
@@ -236,6 +280,7 @@ def check_fin_args(self, parser):
def check_blacklist(self, parser):
"""! Helper function
@param parser The command line argument parser
+ @return ret Boolean value
"""
self.args.blacklist = self.args.blacklist.split(',')
for sol in self.args.blacklist:
@@ -261,25 +306,11 @@ def launch_worker(self, gpu_idx, f_vals, worker_lst):
@param gpu_idx Unique ID of the GPU
@param f_vals Dict containing runtime information
@param worker_lst List containing worker instances
- @retturn ret Boolean value
+ @return ret Boolean value
"""
# pylint: disable=too-many-branches
worker = None
kwargs = self.get_kwargs(gpu_idx, f_vals)
-
- if self.args.fin_steps:
- if 'miopen_find_compile' in self.args.fin_steps \
- or 'miopen_perf_compile' in self.args.fin_steps:
- kwargs['fetch_state'] = ['new']
- worker = FinBuilder(**kwargs)
- elif 'miopen_find_eval' in self.args.fin_steps or 'miopen_perf_eval' in self.args.fin_steps:
- kwargs['fetch_state'] = ['compiled']
- worker = FinEvaluator(**kwargs)
- else:
- raise ValueError('Unsupported fin step')
- worker.start()
- worker_lst.append(worker)
- return True
if self.args.update_applicability:
kwargs['fin_steps'] = ['applicability']
worker = FinClass(**kwargs)
@@ -287,7 +318,7 @@ def launch_worker(self, gpu_idx, f_vals, worker_lst):
worker_lst.append(worker)
return True
- worker = WorkerInterface(**kwargs)
+ worker = FinClass(**kwargs)
ret = False
if self.args.check_status:
if not super().check_status(worker, f_vals["b_first"], gpu_idx,
@@ -305,8 +336,7 @@ def launch_worker(self, gpu_idx, f_vals, worker_lst):
def compose_worker_list(self, machines):
# pylint: disable=too-many-branches
"""! Helper function to compose worker_list
- @param res DB query return item containg available machines
- @param args The command line arguments
+ @param machines List of machines to execute on
"""
worker_lst = []
fin_work_done = False
@@ -324,6 +354,20 @@ def compose_worker_list(self, machines):
else:
worker_ids = super().get_num_procs(machine)
+ if self.args.update_applicability:
+ f_vals = super().get_f_vals(machine, [1])
+ kwargs = self.get_kwargs(0, f_vals)
+ kwargs['fin_steps'] = ['applicability']
+ worker = FinClass(**kwargs)
+ query = worker.query_cfgs(self.args.label)
+ cfg_rows = query.all()
+ len_rows = len(cfg_rows)
+ proc_lim = (len_rows + 99) / 100
+ if 32 < proc_lim:
+ proc_lim = 32
+ while len(worker_ids) > proc_lim:
+ worker_ids.pop()
+
if len(worker_ids) == 0:
return None
@@ -342,6 +386,9 @@ def compose_worker_list(self, machines):
return worker_lst
def add_tables(self):
+ """! Function to create new DB tables
+ @return Bool
+ """
ret_t = create_tables(get_miopen_tables())
self.logger.info('DB creation successful: %s', ret_t)
recreate_triggers(drop_miopen_triggers(), get_miopen_triggers())
@@ -349,9 +396,11 @@ def add_tables(self):
def run(self):
# pylint: disable=duplicate-code
- """Main function to launch library"""
+ """! Main function to launch library"""
res = None
- self.parse_args()
+ if self.args is None:
+ self.parse_args()
+
if self.args.add_tables:
self.add_tables()
return None
@@ -394,19 +443,380 @@ def get_envmt(self):
return envmt
- def get_kwargs(self, gpu_idx, f_vals):
+ def get_kwargs(self, gpu_idx, f_vals, tuning=False):
"""! Helper function to set up kwargs for worker instances
@param gpu_idx Unique ID of the GPU
@param f_vals Dict containing runtime information
- @param args The command line arguments
+ @param tuning Boolean that indicates if kwargs are for a tuning step
+ @return kwargs Dictionary
"""
- if self.args.config_type is None:
- self.args.config_type = ConfigType.convolution
-
- kwargs = super().get_kwargs(gpu_idx, f_vals)
+ kwargs = super().get_kwargs(gpu_idx, f_vals, tuning)
kwargs['fin_steps'] = self.args.fin_steps
kwargs['dynamic_solvers_only'] = self.args.dynamic_solvers_only
kwargs['config_type'] = self.args.config_type
kwargs['reset_interval'] = self.args.reset_interval
return kwargs
+
+ def get_job_list(self, session, find_state, claim_num):
+ """! Get list of jobs
+ @param session DB session
+ @param find_state DB job state
+ @param claim_num Number of DB jobs to pick up
+ @return List of DB jobs
+
+ """
+ job_list = self.get_job_objs(session, find_state, self.args.label, self.dbt,
+ self.get_job_attr(), claim_num,
+ self.args.fin_steps)
+
+ return job_list
+
+ def get_job_objs(self,
+ session: DbSession,
+ find_state: list,
+ label: str,
+ dbt: DBTablesInterface,
+ job_attr: List[str],
+ claim_num: int = None,
+ fin_steps: List[str] = None) -> List[SimpleDict]:
+ """! Get list of job objects
+ @param session DB session
+ @param find_state DB job state
+ @param label DB job reason
+ @param dbt Class representing all DB tables associated with this class
+ @param job_attr List of DB job columns
+ @param claim_num Number of DB jobs to pick up
+ @param fin_steps List of MIFin steps
+ @return List of DB jobs
+ """
+ entries: List[Tuple[SimpleDict, ...]]
+ conds: List[str] = [f"session={dbt.session.id}", "valid=1"]
+
+ if label:
+ conds.append(f"reason='{label}'")
+
+ conds.append(f"retries<{self.max_job_retries}")
+ conds.append("state in (" + str(find_state).strip('{').strip('}') + ")")
+
+ entries = self.compose_work_objs(session, conds, dbt, job_attr, claim_num,
+ fin_steps)
+ return entries
+
+ def compose_work_objs(self,
+ session: DbSession,
+ conds: List[str],
+ dbt: DBTablesInterface,
+ job_attr: List[str],
+ claim_num: int = None,
+ fin_steps: List[str] = None) -> List[SimpleDict]:
+ """! Query a job list for update
+ @param session DB session
+ @param conds List of conditions for DB job WHERE clause
+ @param dbt Class representing all DB tables associated with this class
+ @param job_attr List of DB job columns
+ @param fin_steps List of MIFin steps
+ @return List of MIFin work objects
+ """
+ job_entries = []
+ if fin_steps:
+ conds.append(f"fin_step like '%{fin_steps[0]}%'")
+ else:
+ conds.append("fin_step='not_fin'")
+
+ cond_str = ' AND '.join(conds)
+ if cond_str:
+ cond_str = f"WHERE {cond_str}"
+ if claim_num:
+ cond_str += f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED"
+ else:
+ cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED"
+
+ job_entries = gen_select_objs(session, job_attr,
+ dbt.job_table.__tablename__, cond_str)
+
+ return job_entries
+
+ def compose_work_objs_fin(self, session, job_entries,
+ dbt) -> List[Tuple[SimpleDict, SimpleDict]]:
+ """! Return jobs for fin work
+ @param session DB session
+ @param job_entries List of DB jobs
+ @param dbt Class representing all DB tables associated with this class
+ @return ret Job tuple
+ """
+ ret = []
+
+ cfg_rel = {
+ key: {
+ 'key': list(val.local_columns)[0].name,
+ 'ftble': str(list(val.remote_side)[0]).split('.', maxsplit=1)[0],
+ 'fkey': str(list(val.remote_side)[0]).split('.')[1]
+ } for key, val in inspect(dbt.config_table).relationships.items()
+ }
+
+ if job_entries:
+ id_str = ','.join({str(job.config) for job in job_entries})
+ cfg_cond_str = f"where valid=1 and id in ({id_str})"
+ cfg_attr = [column.name for column in inspect(dbt.config_table).c]
+ cfg_entries = gen_select_objs(session, cfg_attr,
+ dbt.config_table.__tablename__,
+ cfg_cond_str)
+
+ cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries)
+
+ cfg_map = {cfg.id: cfg for cfg in cfg_entries}
+
+ for job in job_entries:
+ ret.append((job, cfg_map[job.config]))
+
+ return ret
+
+ def attach_tensors(self, session, cfg_rel, cfg_entries):
+ """! Attach tensor relationship information to config entries
+ @param session DB session
+ @param cfg_rel DB Config col value
+ @param cfg_entries List of DB Config entries
+ @return cfg_entries List of DB Config entries with attached tensors (foreign keys)
+
+ """
+ for key, val in cfg_rel.items():
+ rel_attr = [
+ column.name
+ for column in inspect(get_class_by_tablename(val['ftble'])).c
+ ]
+ val['fattr'] = rel_attr
+
+ for cfg in cfg_entries:
+ for key, val in cfg_rel.items():
+ rel_val = getattr(cfg, val['key'])
+ rel_cond_str = f"where {val['fkey']}={rel_val}"
+ setattr(
+ cfg, key,
+ gen_select_objs(session, val['fattr'], val['ftble'],
+ rel_cond_str)[0])
+ return cfg_entries
+
+ #deprecated
+ def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]],
+ job_attr: List[str]) -> List[SimpleDict]:
+ """Find job tables in query results"""
+ if has_attr_set(job_rows[0], job_attr):
+ job_tables: List[SimpleDict] = job_rows
+ else:
+ job_i: int = 0
+ tble: SimpleDict
+ for i, tble in enumerate(job_rows[0]):
+ if has_attr_set(tble, job_attr):
+ job_i = i
+ break
+ job_tables = [row[job_i] for row in job_rows]
+
+ return job_tables
+
+ def update_operation(self):
+ """! Update the workers type that this library needs"""
+ if self.args.fin_steps:
+ if 'miopen_find_compile' in self.args.fin_steps \
+ or 'miopen_perf_compile' in self.args.fin_steps:
+ self.fetch_state.add('new')
+ self.set_state = 'compile_start'
+ self.operation = Operation.COMPILE
+ elif 'miopen_find_eval' in self.args.fin_steps or 'miopen_perf_eval' in self.args.fin_steps:
+ self.fetch_state.add('new')
+ self.fetch_state.add('compiled')
+ self.set_state = 'eval_start'
+ self.operation = Operation.EVAL
+
+ if self.args.update_applicability:
+ self.fetch_state.add("new")
+
+ def has_tunable_operation(self):
+ """! Check if its a tuning loop operation
+ @return Bool value that represents if operation is tuning
+ """
+ if self.args is None:
+ self.parse_args()
+ if self.args.subcommand and "load_job" in self.args.subcommand:
+ return False
+ if self.args.shutdown_workers:
+ return True
+
+ return self.args.fin_steps and any(
+ s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS)
+
+ @lru_cache(1)
+ def get_fdb_attr(self):
+ """! Get find_db table attrs
+ @return fdb_attr find_db table attributes without timestamps
+ """
+ fdb_attr = None
+ fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c]
+ fdb_attr.remove("insert_ts")
+ fdb_attr.remove("update_ts")
+ return fdb_attr
+
+ def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]):
+ """! Return list of serialize jobs
+ @param session DB session
+ @param batch_jobs List of DB jobs
+ @return DB jobs, serialized
+ """
+ entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt)
+ return serialize_chunk(entries)
+
+ def build_context(
+ self, serialized_jobs: Tuple[SimpleDict, SimpleDict]) -> List[dict]:
+ """Build context list for enqueue job"""
+ context_list = []
+ kwargs = self.get_context_items()
+ fdb_attr = self.get_fdb_attr()
+ for job, config in serialized_jobs:
+ context = {
+ 'job': job,
+ 'config': config,
+ 'operation': self.operation,
+ 'arch': self.dbt.session.arch,
+ 'num_cu': self.dbt.session.num_cu,
+ 'kwargs': kwargs,
+ 'fdb_attr': fdb_attr
+ }
+ context_list.append(context)
+
+ return context_list
+
+ def celery_enqueue_call(self, context: dict, q_name: str, task_id=False):
+ """! Enqueue job (context) for queue:q_name
+ @param context Context for Celery job
+ @param q_name Custom Celery queue name
+ @param task_id Custom Redis Key
+ """
+
+ #hacky way to get the Q_NAME to the task decorator for interpreter to decorate the
+ #function with correct q_name arg
+ #if import is moved to top it will result in circular imports
+ Q_NAME = q_name #pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name
+ from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue #pylint: disable=import-outside-toplevel
+
+ return celery_enqueue.apply_async((context,),
+ task_id=('-').join([self.prefix,
+ uuid()]),
+ queue=q_name,
+ reply_to=q_name)
+
+ def process_compile_results(self, session, fin_json, context):
+ """! Process result from fin_build worker
+ @param session DB session
+ @param fin_json MIFin results for job
+ @param context Context for Celery job
+ @return Boolean value
+ """
+ job = SimpleDict(**context['job'])
+ pending = []
+ solver_id_map = get_solver_ids()
+
+ failed_job = False
+ result_str = ''
+ status = None
+ try:
+ if fin_json:
+ if 'success' in fin_json and fin_json["success"] is False:
+ status = [fin_json]
+ else:
+ if 'miopen_find_compile_result' in fin_json:
+ status = process_fdb_w_kernels(session, fin_json,
+ copy.deepcopy(context), self.dbt,
+ context['fdb_attr'], pending)
+
+ elif 'miopen_perf_compile_result' in fin_json:
+ status = process_pdb_compile(session, fin_json, job, self.dbt,
+ solver_id_map)
+
+ success, result_str = get_fin_result(status)
+ failed_job = not success
+
+ except (OperationalError, IntegrityError) as err:
+ self.logger.warning('FinBuild: Unable to update Database %s', err)
+ session.rollback()
+ failed_job = True
+ except DataError as err:
+ self.logger.warning(
+ 'FinBuild: Invalid data, likely large workspace. DB Error: %s', err)
+ session.rollback()
+ failed_job = True
+
+ if failed_job:
+ set_job_state(session, job, self.dbt, 'errored', False, result=result_str)
+ else:
+ set_job_state(session,
+ job,
+ self.dbt,
+ 'compiled',
+ False,
+ result=result_str)
+
+ return True
+
+ def process_eval_results(self, session, fin_json, context):
+ """! Process fin_json result
+ @param session DB session
+ @param fin_json MIFin results for job
+ @param context Context for Celery job
+ @return Boolean value
+ """
+ job = SimpleDict(**context['job'])
+ failed_job = True
+ result_str = ''
+ pending = []
+ orig_state = 'compiled'
+
+ try:
+ if fin_json:
+ if 'success' in fin_json and fin_json["success"] is False:
+ status = [fin_json]
+ else:
+ if 'miopen_find_eval_result' in fin_json:
+ status = process_fdb_w_kernels(session,
+ fin_json,
+ copy.deepcopy(context),
+ self.dbt,
+ context['fdb_attr'],
+ pending,
+ result_str='miopen_find_eval_result',
+ check_str='evaluated')
+ elif 'miopen_perf_eval_result' in fin_json:
+ status = process_fdb_w_kernels(session,
+ fin_json,
+ copy.deepcopy(context),
+ self.dbt,
+ context['fdb_attr'],
+ pending,
+ result_str='miopen_perf_eval_result',
+ check_str='evaluated')
+
+ success, result_str = get_fin_result(status)
+ failed_job = not success
+
+ if failed_job:
+ if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): #pylint: disable=no-member
+ self.logger.warning('max job retries exhausted, setting to errored')
+ set_job_state(session, job, self.dbt, 'errored', result=result_str)
+ else:
+ self.logger.warning('resetting job state to %s, incrementing retries',
+ orig_state)
+ set_job_state(session,
+ job,
+ self.dbt,
+ orig_state,
+ increment_retries=True,
+ result=result_str)
+ else:
+ self.logger.info("\n\n Setting job state to evaluated")
+ set_job_state(session, job, self.dbt, 'evaluated', result=result_str)
+ clean_cache_table(self.dbt, job)
+ except (OperationalError, IntegrityError) as err:
+ self.logger.warning('FinBuild: Unable to update Database %s', err)
+ session.rollback()
+ set_job_state(session, job, self.dbt, 'errored', result=result_str)
+
+ return True
diff --git a/tuna/miopen/scripts/query_db.py b/tuna/miopen/scripts/query_db.py
index 925e4033e..a5fde60e5 100755
--- a/tuna/miopen/scripts/query_db.py
+++ b/tuna/miopen/scripts/query_db.py
@@ -32,7 +32,7 @@
from tuna.dbBase.sql_alchemy import DbSession
from tuna.parse_args import TunaArgs, setup_arg_parser
-from tuna.miopen.db.miopen_tables import ConvolutionJob, ConvolutionConfig
+from tuna.miopen.db.convolutionjob_tables import ConvolutionJob, ConvolutionConfig
from tuna.miopen.utils.parsing import get_pdb_key
from tuna.miopen.utils.parsing import get_fds_from_cmd
diff --git a/tuna/miopen/scripts/report.py b/tuna/miopen/scripts/report.py
index eb6be798d..1c6bc0421 100755
--- a/tuna/miopen/scripts/report.py
+++ b/tuna/miopen/scripts/report.py
@@ -32,7 +32,7 @@
from tuna.utils.logger import setup_logger
from tuna.miopen.db.tables import MIOpenDBTables
from tuna.dbBase.sql_alchemy import DbSession
-from tuna.utils.db_utility import get_id_solvers
+from tuna.miopen.db.solver import get_id_solvers
LOGGER = setup_logger('report')
diff --git a/tuna/miopen/subcmd/export_db.py b/tuna/miopen/subcmd/export_db.py
index 610fdfe3c..3d8a9d96d 100755
--- a/tuna/miopen/subcmd/export_db.py
+++ b/tuna/miopen/subcmd/export_db.py
@@ -27,6 +27,8 @@
"""Module to export find_db to txt file"""
import sqlite3
import os
+import tempfile
+import json
from collections import OrderedDict
from typing import Dict, Any, Optional, Union, List, Tuple
import base64
@@ -35,14 +37,18 @@
from tuna.dbBase.sql_alchemy import DbSession
from tuna.miopen.db.find_db import FindDBMixin
-from tuna.miopen.db.miopen_tables import GoldenMixin
+from tuna.miopen.db.mixin_tables import GoldenMixin
from tuna.miopen.db.tables import MIOpenDBTables
from tuna.miopen.utils.metadata import SQLITE_PERF_DB_COLS
-from tuna.utils.db_utility import get_id_solvers, DB_Type
+from tuna.utils.db_utility import DB_Type
+from tuna.miopen.db.solver import get_id_solvers
from tuna.utils.logger import setup_logger
from tuna.miopen.utils.analyze_parse_db import get_config_sqlite, insert_solver_sqlite
from tuna.miopen.utils.analyze_parse_db import get_sqlite_cfg_dict
+from tuna.miopen.utils.config_type import ConfigType
+from tuna.miopen.utils.metadata import INVERS_DIR_MAP
from tuna.miopen.parse_miopen_args import get_export_db_parser
+from tuna.miopen.worker.fin_utils import compose_config_obj
DIR_NAME: dict = {'F': 'Fwd', 'B': 'BwdData', 'W': 'BwdWeights'}
@@ -88,13 +94,67 @@ def get_filename(arch: str,
elif db_type == DB_Type.KERN_DB:
extension = '.kdb'
else:
- extension = ".db"
+ extension = ".db.txt"
final_name = f"{final_name}{extension}"
return final_name
+def fin_net_cfg_job(cfg_lst,
+ logger,
+ config_type: ConfigType = ConfigType.convolution):
+ """Construct a fin network_config job from a config
+ """
+ #arch and num_cu are required by fin, but unused for this command
+ job_list = []
+ if config_type == ConfigType.convolution:
+ for config in cfg_lst:
+ job = {
+ "steps": ["network_config"],
+ "arch": 'gfx908',
+ "num_cu": 120,
+ "config_tuna_id": config.id,
+ "direction": int(INVERS_DIR_MAP[config.direction]),
+ "config": compose_config_obj(config)
+ }
+ job_list.append(job)
+ else:
+ logger.error(f"Config type not implemented: {config_type}")
+
+ return job_list
+
+
+def fin_db_key(config_lst, logger):
+ """rerieve network_config from fin for config """
+ _, fin_ifile = tempfile.mkstemp(suffix='.json')
+ _, fin_ofile = tempfile.mkstemp(suffix='.json')
+
+ with open(fin_ifile, 'w') as in_file: # pylint: disable=unspecified-encoding
+ in_file.write(json.dumps(fin_net_cfg_job(config_lst, logger), indent=2))
+
+ fin_cmd = f"/opt/rocm/bin/fin -i {fin_ifile} -o {fin_ofile}"
+ logger.info('Executing fin cmd: %s', fin_cmd)
+
+ os.system(fin_cmd)
+
+ result = None
+ with open(fin_ofile, 'r') as out_file: # pylint: disable=unspecified-encoding
+ try:
+ result = json.load(out_file)
+ except Exception as err: # pylint: disable=broad-except
+ logger.error('Unable to load fin json file %s', err)
+ for line in out_file:
+ logger.error(line)
+
+ db_key_dict = {}
+ for elem in result:
+ if "db_key" in elem.keys():
+ db_key_dict[elem['config_tuna_id']] = elem['db_key']
+
+ return db_key_dict
+
+
def get_base_query(dbt: MIOpenDBTables, args: argparse.Namespace,
logger: logging.Logger):
""" general query for fdb/pdb results """
@@ -355,6 +415,7 @@ def export_kdb(dbt: MIOpenDBTables,
return write_kdb(args.arch, args.num_cu, kern_db, logger, args.filename)
+#deprecated
def create_sqlite_tables(arch, num_cu, filename=None):
"""create sqlite3 tables"""
local_path = get_filename(arch, num_cu, filename, False, DB_Type.PERF_DB)
@@ -390,6 +451,7 @@ def create_sqlite_tables(arch, num_cu, filename=None):
return cnx, local_path
+#deprecated
def insert_perf_db_sqlite(cnx, perf_db_entry, ins_cfg_id):
"""insert perf_db entry into sqlite"""
perf_db_dict = perf_db_entry.to_dict()
@@ -405,6 +467,29 @@ def insert_perf_db_sqlite(cnx, perf_db_entry, ins_cfg_id):
return perf_db_dict
+#deprecated
+def populate_sqlite(cfg_map, num_perf, cnx, perf_db_entry, cfg_entry,
+ total_entries, logger: logging.Logger):
+ """Analyze perf_db entry"""
+ if cfg_entry.id in cfg_map:
+ ins_cfg_id = cfg_map[cfg_entry.id]
+ else:
+ cfg_dict = get_sqlite_cfg_dict(perf_db_entry.fdb_key)
+
+ #filters cfg_dict by SQLITE_CONFIG_COLS, inserts cfg if missing
+ ins_cfg_id = get_config_sqlite(cnx, cfg_dict)
+ cfg_map[cfg_entry.id] = ins_cfg_id
+
+ pdb_dict = insert_perf_db_sqlite(cnx, perf_db_entry, ins_cfg_id)
+ num_perf += 1
+
+ if num_perf % (total_entries // 10) == 0:
+ cnx.commit()
+ logger.info("PDB count: %s, mysql cfg: %s, pdb: %s", num_perf, cfg_entry.id,
+ pdb_dict)
+
+
+#deprecated
def export_pdb(dbt: MIOpenDBTables, args: argparse.Namespace,
logger: logging.Logger):
""" export perf db from mysql to sqlite """
@@ -427,25 +512,65 @@ def export_pdb(dbt: MIOpenDBTables, args: argparse.Namespace,
return local_path
-def populate_sqlite(cfg_map, num_perf, cnx, perf_db_entry, cfg_entry,
- total_entries, logger: logging.Logger):
- """Analyze perf_dv entry"""
- if cfg_entry.id in cfg_map:
- ins_cfg_id = cfg_map[cfg_entry.id]
- else:
- cfg_dict = get_sqlite_cfg_dict(perf_db_entry.fdb_key)
+def build_miopen_pdb(query, logger: logging.Logger) -> OrderedDict:
+ """return dict with key: fdb_key, val: list of fdb entries"""
+ perf_db: OrderedDict = OrderedDict()
+ solvers: Dict[str, Dict[str, Any]] = {}
+ db_key_map: Dict[str, str] = {}
+ db_entries = query.all()
+ total_entries = len(db_entries)
+ logger.info("pdb query returned: %s", total_entries)
- #filters cfg_dict by SQLITE_CONFIG_COLS, inserts cfg if missing
- ins_cfg_id = get_config_sqlite(cnx, cfg_dict)
- cfg_map[cfg_entry.id] = ins_cfg_id
+ cfg_lst = []
+ for _, config in db_entries:
+ if config not in cfg_lst:
+ cfg_lst.append(config)
- pdb_dict = insert_perf_db_sqlite(cnx, perf_db_entry, ins_cfg_id)
- num_perf += 1
+ db_key_map = fin_db_key(cfg_lst, logger)
- if num_perf % (total_entries // 10) == 0:
- cnx.commit()
- logger.info("PDB count: %s, mysql cfg: %s, pdb: %s", num_perf, cfg_entry.id,
- pdb_dict)
+ for pdb_entry, config in db_entries:
+ if add_entry_to_solvers(pdb_entry, solvers, logger):
+ db_key = db_key_map[config.id]
+ lst = perf_db.get(db_key)
+ if not lst:
+ perf_db[db_key] = [pdb_entry]
+ else:
+ lst.append(pdb_entry)
+
+ return perf_db
+
+
+def write_pdb(arch, num_cu, ocl, perf_db, filename=None):
+ """
+ Serialize perf_db map to plain text file in MIOpen format
+ """
+ file_name = get_filename(arch, num_cu, filename, ocl, DB_Type.PERF_DB)
+
+ require_id_solvers()
+ with open(file_name, 'w') as out: # pylint: disable=unspecified-encoding
+ for key, solvers in sorted(perf_db.items(), key=lambda kv: kv[0]):
+ solvers.sort(
+ #key=lambda x: (float(x.kernel_time), ID_SOLVER_MAP[x.solver]))
+ key=lambda x: (ID_SOLVER_MAP[x.solver]))
+ lst = []
+ # for alg_lib, solver_id, kernel_time, workspace_sz in solvers:
+ for rec in solvers:
+ # pylint: disable-next=consider-using-f-string ; more readable
+ lst.append('{slv}:{params}'.format(slv=ID_SOLVER_MAP[rec.solver],
+ params=rec.params))
+ out.write(f"{key}={';'.join(lst)}\n")
+ return file_name
+
+
+def export_pdb_txt(dbt: MIOpenDBTables, args: argparse.Namespace,
+ logger: logging.Logger):
+ """ export perf db from mysql to txt file """
+ query = get_pdb_query(dbt, args, logger)
+ miopen_pdb = build_miopen_pdb(query, logger)
+
+ logger.info("write pdb to file.")
+ return write_pdb(args.arch, args.num_cu, args.opencl, miopen_pdb,
+ args.filename)
def run_export_db(args: argparse.Namespace, logger: logging.Logger):
@@ -469,7 +594,7 @@ def run_export_db(args: argparse.Namespace, logger: logging.Logger):
elif args.kern_db:
result_file = export_kdb(dbt, args, logger)
elif args.perf_db:
- result_file = export_pdb(dbt, args, logger)
+ result_file = export_pdb_txt(dbt, args, logger)
print(result_file)
diff --git a/tuna/miopen/subcmd/import_configs.py b/tuna/miopen/subcmd/import_configs.py
index 289477496..5e23d2ac8 100755
--- a/tuna/miopen/subcmd/import_configs.py
+++ b/tuna/miopen/subcmd/import_configs.py
@@ -350,25 +350,9 @@ def run_import_configs(args: argparse.Namespace,
if args.add_benchmark:
add_benchmark(args, dbt, logger)
return True
- if not (args.tag and args.framework and args.fw_version and args.model and
- args.md_version):
- logger.error(
- """Tag, framework & version, model & version arguments is required to \
- import configurations""")
- return False
-
- mid, fid = get_database_id(args.framework, args.fw_version, args.model,
- args.md_version, dbt, logger)
- if mid is None or fid is None:
- logger.error(
- 'Please use --add_model and --add_framework to add new model and framework'
- )
- return False
set_import_cfg_batches(args)
counts = import_cfgs(args, dbt, logger)
- #tagging imported configs with benchmark
- add_benchmark(args, dbt, logger)
logger.info('New configs added: %u', counts['cnt_configs'])
if args.tag or args.tag_only:
diff --git a/tuna/miopen/subcmd/import_db.py b/tuna/miopen/subcmd/import_db.py
index 8762e0dbf..f6640dd7c 100755
--- a/tuna/miopen/subcmd/import_db.py
+++ b/tuna/miopen/subcmd/import_db.py
@@ -40,7 +40,7 @@
from tuna.miopen.driver.convolution import DriverConvolution
from tuna.miopen.subcmd.import_configs import insert_config
from tuna.miopen.utils.metadata import PREC_TO_CMD
-from tuna.utils.db_utility import get_solver_ids
+from tuna.miopen.db.solver import get_solver_ids
from tuna.miopen.utils.parsing import parse_fdb_line
LOGGER = setup_logger('import_db')
@@ -48,6 +48,7 @@
COMMIT_FREQ = 1000
+#pylint: disable=too-few-public-methods
def parse_args():
"""command line parsing"""
parser = setup_arg_parser('Import Performance DBs once tunning is finished',
diff --git a/tuna/miopen/subcmd/load_job.py b/tuna/miopen/subcmd/load_job.py
index 34b495585..1cc9fe954 100755
--- a/tuna/miopen/subcmd/load_job.py
+++ b/tuna/miopen/subcmd/load_job.py
@@ -36,15 +36,17 @@
from sqlalchemy.sql.expression import true
from tuna.miopen.utils.metadata import ALG_SLV_MAP, TENSOR_PRECISION
-from tuna.utils.db_utility import get_solver_ids
+from tuna.miopen.db.solver import get_solver_ids
from tuna.utils.logger import setup_logger
from tuna.utils.db_utility import connect_db
-from tuna.miopen.db.miopen_tables import Solver
+from tuna.miopen.db.solver import Solver
from tuna.dbBase.sql_alchemy import DbSession
from tuna.miopen.utils.config_type import ConfigType
from tuna.miopen.db.tables import MIOpenDBTables
from tuna.miopen.parse_miopen_args import get_load_job_parser
+#pylint: disable=R0914
+
def arg_fin_steps(args: argparse.Namespace):
"""fin steps for load jobs"""
@@ -153,8 +155,10 @@ def add_jobs(args: argparse.Namespace, dbt: MIOpenDBTables,
fin_step_str = 'not_fin'
if args.fin_steps:
- fin_step_str = ','.join(args.fin_steps)
- query = f"select config, solver from {dbt.job_table.__tablename__} where session={args.session_id} and fin_step='{fin_step_str}'"
+ fin_step_str = ','.join(sorted(args.fin_steps))
+ query = f"select config, solver from {dbt.job_table.__tablename__} \
+ where session={args.session_id} and fin_step='{fin_step_str}'"
+
logger.info(query)
ret = session.execute(query)
pre_ex: Dict[str, Dict[str, bool]] = {}
@@ -183,9 +187,11 @@ def add_jobs(args: argparse.Namespace, dbt: MIOpenDBTables,
continue
session.add(job)
+ counts += 1
if do_commit:
session.commit()
- counts += 1
+ elif counts % 1000 == 0:
+ session.commit()
except IntegrityError as err:
session.rollback()
logger.warning('Integrity Error: %s', err)
diff --git a/tuna/miopen/subcmd/update_golden.py b/tuna/miopen/subcmd/update_golden.py
index 59c8ee8db..7ec485cff 100755
--- a/tuna/miopen/subcmd/update_golden.py
+++ b/tuna/miopen/subcmd/update_golden.py
@@ -114,15 +114,15 @@ def get_perf_str(args: argparse.Namespace, table_name):
"""Create perf table SQL query and return"""
new_table = f"""
create table {table_name} as select a.config, a.num_cu, a.arch, b.k1 as k1, c.k1 as k2,
- d.k1 as k3, c.k1-b.k1 as gv4_5, d.k1-c.k1 as gv5_6 from conv_golden a
+ d.k1 as k3, c.k1-b.k1 as gv4_5, d.k1-c.k1 as gv5_6 from conv_golden as a
inner join(select config, min(kernel_time) as k1, arch, num_cu from conv_golden
- where golden_miopen_v={args.golden_v-2} and kernel_time!=-1 group by config, arch, num_cu)
+ where golden_miopen_v={args.golden_v-2} and kernel_time>0 group by config, arch, num_cu)
as b on a.config=b.config and a.arch=b.arch and a.num_cu=b.num_cu
inner join(select config, min(kernel_time) as k1, arch, num_cu from conv_golden
- where golden_miopen_v={args.golden_v-1} and kernel_time!=-1 group by config, arch, num_cu)
+ where golden_miopen_v={args.golden_v-1} and kernel_time>0 group by config, arch, num_cu)
as c on a.config=c.config and a.arch=c.arch and a.num_cu=c.num_cu
inner join(select config, min(kernel_time) as k1, arch, num_cu from conv_golden
- where golden_miopen_v={args.golden_v} and kernel_time!=-1 group by config, arch, num_cu)
+ where golden_miopen_v={args.golden_v} and kernel_time>0 group by config, arch, num_cu)
as d on a.config=d.config and a.arch=d.arch and a.num_cu=d.num_cu
where a.golden_miopen_v={args.golden_v} group by a.config, a.arch, a.num_cu, b.k1, c.k1, d.k1;
"""
@@ -134,11 +134,11 @@ def create_perf_table(args: argparse.Namespace, logger: logging.Logger):
if args.golden_v == 0:
table_name = "conv_gv0"
elif args.golden_v == 1:
- table_name = "conv_gv10"
+ table_name = "conv_gv1_0"
else:
vm1 = str(args.golden_v - 1)
vm2 = str(args.golden_v - 2)
- table_name = f"conv_gv{vm2}{vm1}{args.golden_v}"
+ table_name = f"conv_gv{vm2}_{vm1}_{args.golden_v}"
print(table_name)
with ENGINE.connect() as conn:
try:
@@ -167,7 +167,7 @@ def gold_base_update(session: DbSession,
" set cg.valid=ps.valid, cg.params=ps.params, cg.workspace_sz=ps.workspace_sz"\
", cg.kernel_time=ps.kernel_time, cg.kernel_group=ps.kernel_group, cg.session=ps.session"\
f" where cg.golden_miopen_v={gold_v} and ps.golden_miopen_v={base_gold_v} and ps.valid=1"\
- " and ps.kernel_time>=0;"
+ " and ps.kernel_time>0;"
logger.info(update_q)
session.execute(update_q)
@@ -176,7 +176,7 @@ def gold_base_update(session: DbSession,
", fdb_key, params, kernel_time, workspace_sz, alg_lib, opencl, kernel_group, session, solver)"\
f" select valid, {gold_v}, arch, num_cu, config, fdb_key, params, kernel_time"\
", workspace_sz, alg_lib, opencl, kernel_group, session, solver"\
- f" from conv_golden where golden_miopen_v={base_gold_v} and valid=1 and kernel_time>=0;"
+ f" from conv_golden where golden_miopen_v={base_gold_v} and valid=1 and kernel_time>0;"
logger.info(insert_q)
session.execute(insert_q)
session.commit()
@@ -199,7 +199,7 @@ def gold_session_update(session: DbSession,
" set cg.valid=ps.valid, cg.params=ps.params, cg.workspace_sz=ps.workspace_sz"\
", cg.kernel_time=ps.kernel_time, cg.kernel_group=ps.kernel_group, cg.session=ps.session"\
f" where cg.golden_miopen_v={gold_v} and ps.session={tune_s} and ps.valid=1"\
- " and ps.kernel_time>=0;"
+ " and ps.kernel_time>0;"
session.execute(update_q)
logger.info("Gold %s Insert session %s.", gold_v, tune_s)
@@ -208,7 +208,7 @@ def gold_session_update(session: DbSession,
f" select cfd.valid, {gold_v}, arch, num_cu, config, fdb_key, params, kernel_time"\
", workspace_sz, alg_lib, opencl, kernel_group, session, solver"\
" from conv_find_db as cfd inner join session as s on cfd.session=s.id"\
- f" where session={tune_s} and cfd.valid=1 and kernel_time>=0;"
+ f" where session={tune_s} and cfd.valid=1 and kernel_time>0;"
session.execute(insert_q)
session.commit()
diff --git a/tuna/miopen/utils/config_type.py b/tuna/miopen/utils/config_type.py
index a4e8004d8..2efa900a4 100644
--- a/tuna/miopen/utils/config_type.py
+++ b/tuna/miopen/utils/config_type.py
@@ -37,3 +37,6 @@ class ConfigType(Enum):
def __str__(self):
return self.value
+
+ def __json__(self):
+ return self.value
diff --git a/tuna/miopen/utils/helper.py b/tuna/miopen/utils/helper.py
index a946af342..fea4ffade 100644
--- a/tuna/miopen/utils/helper.py
+++ b/tuna/miopen/utils/helper.py
@@ -27,6 +27,7 @@
"""Utility module for helper functions"""
import random
+import string
from time import sleep
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.orm import Query
@@ -34,12 +35,12 @@
from tuna.utils.logger import setup_logger
from tuna.dbBase.sql_alchemy import DbSession
from tuna.machine import Machine
-from tuna.utils.db_utility import get_solver_ids
+from tuna.miopen.db.solver import get_solver_ids
from tuna.utils.utility import check_qts
-from tuna.miopen.utils.metadata import MYSQL_LOCK_WAIT_TIMEOUT
-from tuna.miopen.utils.metadata import BN_DEFAULTS
+from tuna.miopen.utils.metadata import MYSQL_LOCK_WAIT_TIMEOUT, BN_DEFAULTS
from tuna.miopen.utils.metadata import FUSION_DEFAULTS, CONV_2D_DEFAULTS, CONV_3D_DEFAULTS
from tuna.utils.metadata import NUM_SQL_RETRIES
+from tuna.utils.db_utility import gen_update_query, session_retry
LOGGER = setup_logger('helper')
@@ -200,3 +201,37 @@ def get_db_id(db_elems, config_table):
if res:
cid = res[0][0]
return cid
+
+
+def set_job_state(session, job, dbt, state, increment_retries=False, result=""):
+ """Update job state for builder/evaluator job_set_attr: List[str]"""
+
+ LOGGER.info('Setting job id %s state to %s', job.id, state)
+ job_set_attr = ['state', 'gpu_id']
+ job.state = state
+ if result:
+ job_set_attr.append('result')
+ job.result = result
+ if increment_retries:
+ job_set_attr.append('retries')
+ job.retries += 1
+
+ #pylint: disable=duplicate-code
+ if '_start' in state:
+ job_set_attr.append('cache_loc')
+ cache: str = '~/.cache/miopen_'
+ blurr: str = ''.join(
+ random.choice(string.ascii_lowercase) for i in range(10))
+ cache_loc: str = cache + blurr
+ job.cache_loc = cache_loc
+ #pylint: enable=duplicate-code
+
+ query: str = gen_update_query(job, job_set_attr, dbt.job_table.__tablename__)
+
+ def callback() -> bool:
+ session.execute(query)
+ session.commit()
+ return True
+
+ assert session_retry(session, callback, lambda x: x(), LOGGER)
+ return True
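A rough usage sketch for the new helper, assuming a configured tuna database and an existing job row; the session id and filter values below are made up for illustration:

    from tuna.dbBase.sql_alchemy import DbSession
    from tuna.miopen.db.tables import MIOpenDBTables
    from tuna.miopen.utils.helper import set_job_state

    dbt = MIOpenDBTables(session_id=42, config_type=None)  # 42 is an illustrative session id
    with DbSession() as session:
      job = session.query(dbt.job_table)\
          .filter(dbt.job_table.session == 42)\
          .filter(dbt.job_table.state == 'new').first()
      if job:
        # '_start' states also assign a fresh per-job cache_loc under ~/.cache
        set_job_state(session, job, dbt, 'compile_start')
        # on failure the state can be updated while bumping the retry counter
        set_job_state(session, job, dbt, 'errored', increment_retries=True,
                      result='compile failed')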
diff --git a/tuna/miopen/utils/json_to_sql.py b/tuna/miopen/utils/json_to_sql.py
new file mode 100644
index 000000000..8f2ff88bf
--- /dev/null
+++ b/tuna/miopen/utils/json_to_sql.py
@@ -0,0 +1,312 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Utility module for parsing fin json results"""
+import functools
+from sqlalchemy.exc import OperationalError
+
+from tuna.utils.logger import setup_logger
+from tuna.dbBase.sql_alchemy import DbSession
+from tuna.utils.utility import SimpleDict
+from tuna.utils.db_utility import session_retry, gen_select_objs
+from tuna.utils.db_utility import gen_update_query, gen_insert_query
+from tuna.miopen.worker.fin_utils import get_fin_slv_status
+from tuna.miopen.utils.parsing import parse_pdb_key
+from tuna.miopen.db.solver import get_solver_ids
+
+LOGGER = setup_logger('parse_results')
+
+
+def __update_fdb_w_kernels( #pylint: disable=too-many-arguments,too-many-locals
+ session: DbSession,
+ fin_json,
+ config,
+ session_id,
+ dbt,
+ job,
+ fdb_attr,
+ pending,
+ result_str: str = 'miopen_find_compile_result',
+ check_str: str = 'find_compiled') -> list:
+ """update find db + kernels from json results"""
+ status = []
+ solver_id_map = get_solver_ids()
+ if result_str in fin_json.keys():
+ for fdb_obj in fin_json.get(result_str):
+ slv_stat = get_fin_slv_status(fdb_obj, check_str)
+ status.append(slv_stat)
+
+ if fdb_obj[check_str]:
+ #returned entry is added to the table
+ fdb_entry = __compose_fdb_entry(session, fin_json, fdb_obj, session_id,
+ dbt, config, job, fdb_attr,
+ solver_id_map, pending)
+ __check_layout_mismatch(fdb_entry, slv_stat, config)
+ if not pending:
+ query = gen_update_query(fdb_entry, fdb_attr,
+ dbt.find_db_table.__tablename__)
+ session.execute(query)
+ else:
+ assert len(pending) == 1
+ pending.pop()
+ query = gen_insert_query(fdb_entry, fdb_attr,
+ dbt.find_db_table.__tablename__)
+ session.execute(query)
+
+ fdb_entry = __update_fdb_entry(session,
+ solver_id_map[fdb_obj['solver_name']],
+ session_id, dbt, config, job, fdb_attr,
+ pending)
+ fdb_entry.kernel_group = fdb_entry.id
+ query = gen_update_query(fdb_entry, ['kernel_group'],
+ dbt.find_db_table.__tablename__)
+ session.execute(query)
+
+ if fdb_obj['reason'] == 'Success':
+ __compose_kernel_entry(session, fdb_obj, fdb_entry, dbt)
+ LOGGER.info('Updating find Db(Build) for job_id=%s', job.id)
+ else:
+ # JD: add info about reason to the logs table
+ fdb_entry.valid = False
+ else:
+ LOGGER.warning("Failed find_db compile, cfg_id: %s, obj: %s",
+ fin_json['config_tuna_id'], fdb_obj)
+ else:
+ status = [{
+ 'solver': 'all',
+ 'success': False,
+ 'result': 'Find Compile: No results'
+ }]
+
+ session.commit()
+
+ return status
+
+
+def process_pdb_compile(session, fin_json, job, dbt, solver_id_map):
+ """retrieve perf db compile json results"""
+ status = []
+ if fin_json['miopen_perf_compile_result']:
+
+ def actuator(func, pdb_obj, dbt, job, solver_id_map):
+ return func(session, pdb_obj, dbt, job, solver_id_map)
+
+ for pdb_obj in fin_json['miopen_perf_compile_result']:
+ slv_stat = get_fin_slv_status(pdb_obj, 'perf_compiled')
+ status.append(slv_stat)
+ if pdb_obj['perf_compiled']:
+ session_retry(
+ session, compose_job_cache_entrys,
+ functools.partial(actuator,
+ pdb_obj=pdb_obj,
+ dbt=dbt,
+ job=job,
+ solver_id_map=solver_id_map), LOGGER)
+ LOGGER.info('Updating pdb job_cache for job_id=%s', job.id)
+ else:
+ status = [{
+ 'solver': 'all',
+ 'success': False,
+ 'result': 'Perf Compile: No results'
+ }]
+
+ return status
+
+
+def compose_job_cache_entrys(session, pdb_obj, dbt, job, solver_id_map):
+ """Compose new pdb kernel cache entry from fin input"""
+ for kern_obj in pdb_obj['kernel_objects']:
+ kernel_obj = dbt.fin_cache_table()
+ populate_kernels(kern_obj, kernel_obj)
+ kernel_obj.solver_id = solver_id_map[pdb_obj['solver_name']]
+ kernel_obj.job_id = job.id
+
+ session.add(kernel_obj)
+ session.commit()
+
+ return True
+
+
+def populate_kernels(kern_obj, kernel_obj):
+ """populate kernel object"""
+ kernel_obj.kernel_name = kern_obj['kernel_file']
+ kernel_obj.kernel_args = kern_obj['comp_options']
+ kernel_obj.kernel_blob = bytes(kern_obj['blob'], 'utf-8')
+ kernel_obj.kernel_hash = kern_obj['md5_sum']
+ kernel_obj.uncompressed_size = kern_obj['uncompressed_size']
+ return kernel_obj
+
+
+def __check_layout_mismatch(fdb_entry: SimpleDict, status: dict,
+ config) -> bool:
+ """Check that the fdb key returned by fin matches the config being tuned,
+ states to error if not"""
+ fdb_key = fdb_entry.fdb_key
+ fds, vals, _, _ = parse_pdb_key(fdb_key)
+ key_layout = vals[fds.index('out_layout')]
+ cfg_layout = config.out_layout
+
+ if cfg_layout != key_layout:
+ status['success'] = False
+ status['result'] = f"fdb_key layout mismatch with config"\
+ f" {key_layout} != {cfg_layout}"
+ fdb_entry.valid = False
+ return False
+
+ return True
+
+
+def __compose_kernel_entry(session, fdb_obj, fdb_entry, dbt):
+ """Compose a new Kernel Cache entry from fin input"""
+ # Now we have the ID, lets add the binary cache objects
+ for kern_obj in fdb_obj['kernel_objects']:
+ kernel_obj = dbt.kernel_cache()
+ populate_kernels(kern_obj, kernel_obj)
+ kernel_obj.kernel_group = fdb_entry.kernel_group
+ session.add(kernel_obj)
+ return True
+
+
+def __update_fdb_entry(session, solver, session_id, dbt, config, job, fdb_attr,
+ pending):
+ """ Add a new entry to fdb if there isnt one already """
+ obj, fdb_entry = get_fdb_entry(session, solver, session_id, dbt, config,
+ fdb_attr)
+ if obj: # existing entry in db
+ # This can be removed if we implement the delete orphan cascade
+ fdb_entry = obj
+    if fdb_entry.kernel_group is not None:
+ LOGGER.info('Invalidate kernel_group %s', fdb_entry.kernel_group)
+ session.query(dbt.kernel_cache)\
+ .filter(dbt.kernel_cache.valid == 1)\
+ .filter(dbt.kernel_cache.kernel_group ==
+ fdb_entry.kernel_group)\
+ .update({'valid': 0})
+ else:
+ # Bundle Insert for later
+ pending.append((job, fdb_entry))
+ return fdb_entry
+
+
+def get_fdb_entry(session, solver, session_id, dbt, config, fdb_attr):
+ """ Get FindDb entry from db """
+ obj = None
+ fdb_entry = None
+
+ conds = [
+ f"session={session_id}", f"config={config.id}", f"solver={solver}",
+ "opencl=0"
+ ]
+ cond_str = f"where {' AND '.join(conds)}"
+ entries = gen_select_objs(session, fdb_attr, dbt.find_db_table.__tablename__,
+ cond_str)
+
+ if entries:
+ assert len(entries) == 1
+ obj = entries[0]
+ else:
+ fdb_entry = SimpleDict()
+ for attr in fdb_attr:
+ setattr(fdb_entry, attr, None)
+ setattr(fdb_entry, 'session', session_id)
+ setattr(fdb_entry, 'config', config.id)
+ setattr(fdb_entry, 'solver', solver)
+ setattr(fdb_entry, 'opencl', False)
+ setattr(fdb_entry, 'logger', LOGGER)
+
+ return obj, fdb_entry
+
+
+def __compose_fdb_entry( #pylint: disable=too-many-arguments
+ session, fin_json, fdb_obj, session_id, dbt, config, job, fdb_attr,
+ solver_id_map, pending):
+ """Compose a FindDB table entry from fin_output"""
+ solver = solver_id_map[fdb_obj['solver_name']]
+ fdb_entry = __update_fdb_entry(session, solver, session_id, dbt, config, job,
+ fdb_attr, pending)
+ fdb_entry.fdb_key = fin_json['db_key']
+ fdb_entry.alg_lib = fdb_obj['algorithm']
+ fdb_entry.params = fdb_obj['params']
+ fdb_entry.workspace_sz = fdb_obj['workspace']
+ fdb_entry.valid = True
+
+ fdb_entry.kernel_time = -1
+ if 'time' in fdb_obj:
+ fdb_entry.kernel_time = fdb_obj['time']
+
+ fdb_entry.kernel_group = fdb_entry.id
+
+ return fdb_entry
+
+
+def process_fdb_w_kernels(session,
+ fin_json,
+ context,
+ dbt,
+ fdb_attr,
+ pending,
+ result_str='miopen_find_compile_result',
+ check_str='find_compiled'):
+ """initiate find db update"""
+ job = SimpleDict(**context['job'])
+ #get_db_obj_by_id(context['job']['id'], dbt.job_table)
+ config = SimpleDict(**context['config'])
+ #get_db_obj_by_id(context['config']['id'], dbt.config_table)
+
+ callback = __update_fdb_w_kernels
+ status = session_retry(
+ session, callback,
+ lambda x: x(session, fin_json, config, context['kwargs']['session_id'],
+ dbt, job, fdb_attr, pending, result_str, check_str), LOGGER)
+
+ if not status:
+ LOGGER.warning('Fin: Unable to update Database')
+ status = [{
+ 'solver': 'all',
+ 'success': False,
+ 'result': 'Fin: Unable to update Database'
+ }]
+
+ return status
+
+
+def clean_cache_table(dbt, job):
+ """Remove the fin cache kernel entries for this job"""
+ with DbSession() as session:
+ try:
+ LOGGER.info('Delete kernel cache entries job(%s)', job.id)
+ job_cache = session.query(dbt.fin_cache_table)\
+ .filter(dbt.fin_cache_table.job_id == job.id)
+ job_cache.delete()
+ invalid_fdb_cache = session.query(dbt.kernel_cache)\
+ .filter(dbt.kernel_cache.valid == 0)
+ invalid_fdb_cache.delete()
+ session.commit()
+ except OperationalError as err:
+ session.rollback()
+ LOGGER.warning('FinEval: Unable to clean %s / %s: %s',
+ dbt.fin_cache_table.__tablename__,
+ dbt.kernel_cache.__tablename__, err)
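For orientation, a minimal sketch of how a result-processing step might call process_fdb_w_kernels. The shape of the context dict is inferred from the function body, and the session/job/config ids below are assumptions; an empty fin_json simply exercises the "no results" path:

    from sqlalchemy.inspection import inspect
    from tuna.dbBase.sql_alchemy import DbSession
    from tuna.miopen.db.tables import MIOpenDBTables
    from tuna.miopen.utils.json_to_sql import process_fdb_w_kernels

    dbt = MIOpenDBTables(session_id=1, config_type=None)
    fdb_attr = [col.name for col in inspect(dbt.find_db_table).c]

    context = {
        'job': {'id': 10, 'state': 'compiled', 'retries': 0},
        'config': {'id': 7, 'out_layout': 'NCHW'},
        'kwargs': {'session_id': 1},
    }
    fin_json = {}  # parsed fin output would go here; empty exercises the no-result path
    pending = []

    with DbSession() as session:
      status = process_fdb_w_kernels(session, fin_json, context, dbt, fdb_attr,
                                     pending)
    # status == [{'solver': 'all', 'success': False,
    #             'result': 'Find Compile: No results'}]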
diff --git a/tuna/miopen/utils/lib_helper.py b/tuna/miopen/utils/lib_helper.py
new file mode 100644
index 000000000..a421c0519
--- /dev/null
+++ b/tuna/miopen/utils/lib_helper.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+###############################################################################
+#
+# MIT License
+#
+# Copyright (c) 2023 Advanced Micro Devices, Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+###############################################################################
+"""Utility module for miopen library"""
+
+from tuna.miopen.worker.fin_builder import FinBuilder
+from tuna.miopen.worker.fin_eval import FinEvaluator
+from tuna.worker_interface import WorkerInterface
+from tuna.libraries import Operation
+
+
+def get_worker(kwargs, operation):
+ """Return worker based on operation type"""
+
+ worker = WorkerInterface(**kwargs)
+ if operation == Operation.COMPILE:
+ worker = FinBuilder(**kwargs)
+ elif operation == Operation.EVAL:
+ worker = FinEvaluator(**kwargs)
+
+ return worker
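This helper is meant to be driven from the celery task layer; a sketch under the assumption that kwargs has already been assembled by MITunaInterface.get_kwargs(..., tuning=True), whose exact keys are not enumerated here (the fragment is illustrative, not runnable on its own):

    from tuna.libraries import Operation
    from tuna.miopen.utils.lib_helper import get_worker

    # kwargs is assumed to come from MITunaInterface.get_kwargs(gpu_idx, f_vals, tuning=True)
    worker = get_worker(kwargs, Operation.COMPILE)   # FinBuilder
    # worker = get_worker(kwargs, Operation.EVAL)    # FinEvaluator
    fin_json = worker.step()  # step() now only runs fin and returns its parsed JSON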
diff --git a/tuna/miopen/utils/metadata.py b/tuna/miopen/utils/metadata.py
index 4e1d99ecd..b0f422f9b 100644
--- a/tuna/miopen/utils/metadata.py
+++ b/tuna/miopen/utils/metadata.py
@@ -145,6 +145,8 @@
's': ('save', 0)
}
+SUPPORTED_LAYOUTS = ["NCHW", "NHWC", "NCDHW", "NDHWC"]
+
#NOTE: dim0 for input_tensor is 1
#3D layouts
NCDHW_LAYOUT = {
diff --git a/tuna/miopen/worker/fin_builder.py b/tuna/miopen/worker/fin_builder.py
index 1b9fad353..e3c60bcc1 100644
--- a/tuna/miopen/worker/fin_builder.py
+++ b/tuna/miopen/worker/fin_builder.py
@@ -26,19 +26,12 @@
###############################################################################
"""Builder class implements the worker interface. The purpose of this class is to run fin
jobs in compile mode"""
-from time import sleep
-import random
-import functools
import json
-from sqlalchemy.exc import OperationalError, DataError, IntegrityError
from sqlalchemy.inspection import inspect
from tuna.miopen.worker.fin_class import FinClass
-from tuna.dbBase.sql_alchemy import DbSession
from tuna.miopen.worker.fin_utils import fin_job
-from tuna.miopen.worker.fin_utils import get_fin_slv_status, get_fin_result
-from tuna.utils.db_utility import session_retry
class FinBuilder(FinClass):
@@ -67,97 +60,11 @@ def get_fin_input(self):
is_temp=True)
return fin_input
- def compose_job_cache_entrys(self, session, pdb_obj):
- """Compose new pdb kernel cache entry from fin input"""
- for kern_obj in pdb_obj['kernel_objects']:
- kernel_obj = self.dbt.fin_cache_table()
- self.populate_kernels(kern_obj, kernel_obj)
- kernel_obj.solver_id = self.solver_id_map[pdb_obj['solver_name']]
- kernel_obj.job_id = self.job.id
-
- session.add(kernel_obj)
- session.commit()
-
- return True
-
- def process_pdb_compile(self, session, fin_json):
- """retrieve perf db compile json results"""
- status = []
- if fin_json['miopen_perf_compile_result']:
-
- def actuator(func, pdb_obj):
- return func(session, pdb_obj)
-
- for pdb_obj in fin_json['miopen_perf_compile_result']:
- slv_stat = get_fin_slv_status(pdb_obj, 'perf_compiled')
- status.append(slv_stat)
- if pdb_obj['perf_compiled']:
- session_retry(session, self.compose_job_cache_entrys,
- functools.partial(actuator, pdb_obj=pdb_obj),
- self.logger)
- self.logger.info('Updating pdb job_cache for job_id=%s', self.job.id)
- else:
- status = [{
- 'solver': 'all',
- 'success': False,
- 'result': 'Perf Compile: No results'
- }]
-
- return status
-
- def close_job(self):
- """mark a job complete"""
- self.set_job_state('compiled')
-
def step(self):
"""Main functionality of the builder class. It picks up jobs in new state and compiles them"""
- self.pending = []
- self.result_queue_drain()
if not self.init_check_env():
return False
- if not self.get_job("new", "compile_start", True):
- while not self.result_queue_drain():
- sleep(random.randint(1, 10))
- return False
-
- # JD: while fin can exec multiple jobs at a time, that makes error detection difficult
- self.logger.info('Acquired new job: job_id=%s', self.job.id)
- self.set_job_state('compiling')
fin_json = self.run_fin_cmd()
-
- failed_job = True
- result_str = ''
- if fin_json:
- failed_job = False
- with DbSession() as session:
- try:
- if 'miopen_find_compile_result' in fin_json:
- status = self.process_fdb_w_kernels(session, fin_json)
-
- elif 'miopen_perf_compile_result' in fin_json:
- status = self.process_pdb_compile(session, fin_json)
-
- success, result_str = get_fin_result(status)
- failed_job = not success
-
- except (OperationalError, IntegrityError) as err:
- self.logger.warning('FinBuild: Unable to update Database %s', err)
- session.rollback()
- failed_job = True
- except DataError as err:
- self.logger.warning(
- 'FinBuild: Invalid data, likely large workspace. DB Error: %s',
- err)
- session.rollback()
- failed_job = True
-
- if failed_job:
- self.set_job_state('errored', result=result_str)
- elif self.pending:
- self.set_job_state('compiled_pend', result=result_str)
- self.result_queue.put(self.pending)
- else:
- self.set_job_state('compiled', result=result_str)
- return True
+ return fin_json
diff --git a/tuna/miopen/worker/fin_class.py b/tuna/miopen/worker/fin_class.py
index e7db91984..fd123cee8 100644
--- a/tuna/miopen/worker/fin_class.py
+++ b/tuna/miopen/worker/fin_class.py
@@ -30,12 +30,15 @@
import os
import tempfile
import functools
+from time import sleep
from typing import List, Dict, Tuple
+import random
import paramiko
try:
import queue
except ImportError:
import Queue as queue #type: ignore
+
from sqlalchemy import func as sqlalchemy_func
from sqlalchemy.exc import IntegrityError, InvalidRequestError #pylint: disable=wrong-import-order
from sqlalchemy.inspection import inspect
@@ -44,15 +47,11 @@
from tuna.dbBase.sql_alchemy import DbSession
from tuna.miopen.utils.metadata import FIN_CACHE
from tuna.miopen.utils.metadata import INVERS_DIR_MAP
-from tuna.miopen.db.tables import MIOpenDBTables
from tuna.miopen.worker.fin_utils import compose_config_obj
-from tuna.miopen.worker.fin_utils import get_fin_slv_status
from tuna.miopen.utils.config_type import ConfigType
-from tuna.miopen.utils.parsing import parse_pdb_key
from tuna.utils.db_utility import session_retry
-from tuna.utils.db_utility import get_solver_ids, get_id_solvers
-from tuna.utils.db_utility import gen_select_objs, gen_insert_query, gen_update_query
-from tuna.utils.db_utility import get_class_by_tablename, has_attr_set
+from tuna.miopen.db.solver import get_solver_ids, get_id_solvers
+from tuna.utils.db_utility import gen_select_objs, get_class_by_tablename
from tuna.utils.utility import split_packets
from tuna.utils.utility import SimpleDict
@@ -62,6 +61,7 @@ class FinClass(WorkerInterface):
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-public-methods
+ # pylint: disable=no-member
def __init__(self, **kwargs):
"""Constructor"""
@@ -78,20 +78,20 @@ def __init__(self, **kwargs):
_, self.local_output = tempfile.mkstemp()
self.fin_outfile = self.local_output.split("/tmp/", 1)[1] + ".json"
- self.solver_id_map = get_solver_ids()
_, self.id_solver_map = get_id_solvers(
) #hyphenated names used by miopen::solver.ToString()
self.all_configs = []
self.fin_list = []
self.multiproc = False
- self.pending = []
self.first_pass = True
self.dynamic_solvers_only = False
+ self.solver_id_map = get_solver_ids()
self.__dict__.update(
(key, value) for key, value in kwargs.items() if key in allowed_keys)
- self.config_type = ConfigType.convolution if self.config_type is None else self.config_type
+ self.config_type = ConfigType.convolution if self.config_type is None else ConfigType(
+ self.config_type)
super().__init__(**kwargs)
@@ -102,7 +102,6 @@ def __init__(self, **kwargs):
self.envmt.append(
f"MIOPEN_CUSTOM_CACHE_DIR=/tmp/miopenpdb/thread-{self.gpu_id}/cache")
- self.config = SimpleDict()
self.cfg_attr = [column.name for column in inspect(self.dbt.config_table).c]
# dict of relationship_column : dict{local_key, foreign_table_name, foreign_key, [remote_attr]}
@@ -140,6 +139,19 @@ def get_miopen_v(self) -> str:
return commit_hash
+ def check_env(self) -> bool:
+ """Interface function to check the miopen env version vs presumed miopen version"""
+ if super().check_env():
+ env_miopen_v: str = self.get_miopen_v()
+ if self.dbt.session.miopen_v != env_miopen_v:
+        raise ValueError(
+            f'session miopen_v {self.dbt.session.miopen_v} does not match '
+            f'env miopen_v {env_miopen_v}')
+ else:
+ return False
+
+ return True
+
def chk_abort_file(self):
"""Checking presence of abort file to terminate processes immediately"""
abort_reason = []
@@ -157,6 +169,8 @@ def chk_abort_file(self):
def set_db_tables(self):
"""Initialize tables"""
+ # pylint: disable=import-outside-toplevel
+ from tuna.miopen.db.tables import MIOpenDBTables
self.dbt = MIOpenDBTables(session_id=self.session_id,
config_type=self.config_type)
@@ -196,27 +210,6 @@ def compose_work_objs(
return ret
- #pylint: disable=R0801
- def check_jobs_found(self, job_rows, find_state, imply_end):
- """check for end of jobs"""
- if not job_rows:
- # we are done
- self.logger.warning('No %s jobs found, fin_step: %s, session %s',
- find_state, self.fin_steps, self.session_id)
- if imply_end:
- self.logger.warning("set end")
- self.end_jobs.value = 1
- return False
- return True
-
- #pylint: enable=R0801
-
- def job_queue_pop(self):
- """load job & config from top of job queue"""
- self.job, self.config = self.job_queue.get(True, 1)
- self.logger.info("Got job %s %s %s", self.job.id, self.job.state,
- self.job.reason)
-
def __compose_fincmd(self):
"""Helper function to compose fin docker cmd"""
if self.machine.local_machine:
@@ -264,15 +257,19 @@ def __get_fin_results(self):
if self.__prep_fin_input(self.local_file, to_file=True):
fin_cmd = self.__compose_fincmd()
- ret_code, out, err = self.exec_docker_cmd(fin_cmd)
- if ret_code > 0:
- self.logger.warning('Err executing cmd: %s', fin_cmd)
- self.logger.warning(out)
+ for i in range(3):
+ ret_code, out, err = self.exec_docker_cmd(fin_cmd)
+ if ret_code != 0:
+ self.logger.warning('Error executing cmd(%u): %s', i, fin_cmd)
+ self.logger.warning(out)
+ sleep(random.randint(1, 10))
+ else:
+ result = self.__parse_out()
+ break
+ if ret_code != 0:
raise ValueError(
f'Failed to execute fin cmd: {fin_cmd} err: {err.read()}')
- result = self.__parse_out()
-
return result
def __parse_out(self):
@@ -315,69 +312,54 @@ def applicability(self):
return True
- def __rm_old_app(self, session: DbSession, cfg_rows: list) -> None:
- """remove old applicability"""
- rm_old = ''
- if self.label and cfg_rows:
- cfg_ids = [str(row.id) for row in cfg_rows]
- cfg_str = ','.join(cfg_ids)
- rm_old = f"update {self.dbt.solver_app.__tablename__} set applicable=0"\
- f" where session={self.session_id} and config in ({cfg_str});"
- else:
- rm_old = f"update {self.dbt.solver_app.__tablename__} set applicable=0"\
- f" where session={self.session_id};"
+ def query_cfgs(self, label=None):
+ """query all configs from table, optionally limit by label"""
+ with DbSession() as session:
+ query = session.query(self.dbt.config_table)\
+ .filter(self.dbt.config_table.valid == 1)
- self.logger.info("Start applic zeroing")
- session.execute(rm_old)
- session.commit()
- self.logger.info("Finished applic zeroing")
+ if label:
+ query = query.filter(self.dbt.config_table.id == self.dbt.config_tags_table.config)\
+ .filter(self.dbt.config_tags_table.tag == label)
+
+ #order by id for splitting configs into blocks
+ query = query.order_by(self.dbt.config_table.id)
+ return query
def __set_all_configs(self, idx: int = 0, num_blk: int = 1) -> bool:
"""Gathering all configs from Tuna DB to set up fin input file"""
if idx == 0:
- with DbSession() as session:
- query = session.query(self.dbt.config_table)\
- .filter(self.dbt.config_table.valid == 1)
-
- if self.label:
- query = query.filter(self.dbt.config_table.id == self.dbt.config_tags_table.config)\
- .filter(self.dbt.config_tags_table.tag == self.label)
-
- #order by id for splitting configs into blocks
- query = query.order_by(self.dbt.config_table.id)
- rows = query.all()
-
- len_rows = len(rows)
- master_cfg_list = []
- for row in rows:
- r_dict = compose_config_obj(row, self.config_type)
- if self.config_type == ConfigType.batch_norm:
- r_dict['direction'] = row.get_direction()
- master_cfg_list.append(r_dict)
-
- self.__rm_old_app(session, rows)
-
- block_size = len_rows // num_blk #size of the config block
- extra = len_rows % num_blk #leftover configs, don't divide evenly
- self.logger.info(
- "cfg workdiv: num_blocks: %s, block_size: %s, extra: %s", num_blk,
- block_size, extra)
- for i in range(num_blk):
- start = i * block_size #start of a process block
- end = (i + 1) * block_size
- #distributing leftover configs to processes
- if i < extra:
- start += i
- end += 1 + i
- else:
- start += extra
- end += extra
-
- if start >= len_rows:
- self.job_queue.put([])
- else:
- self.logger.info("cfg workdiv: start %s, end %s", start, end)
- self.job_queue.put(master_cfg_list[start:end])
+ query = self.query_cfgs(self.label)
+ rows = query.all()
+
+ len_rows = len(rows)
+ master_cfg_list = []
+ for row in rows:
+ r_dict = compose_config_obj(row, self.config_type)
+ if self.config_type == ConfigType.batch_norm:
+ r_dict['direction'] = row.get_direction()
+ master_cfg_list.append(r_dict)
+
+ block_size = len_rows // num_blk #size of the config block
+ extra = len_rows % num_blk #leftover configs, don't divide evenly
+ self.logger.info("cfg workdiv: num_blocks: %s, block_size: %s, extra: %s",
+ num_blk, block_size, extra)
+ for i in range(num_blk):
+ start = i * block_size #start of a process block
+ end = (i + 1) * block_size
+ #distributing leftover configs to processes
+ if i < extra:
+ start += i
+ end += 1 + i
+ else:
+ start += extra
+ end += extra
+
+ if start >= len_rows:
+ self.job_queue.put([])
+ else:
+ self.logger.info("cfg workdiv: start %s, end %s", start, end)
+ self.job_queue.put(master_cfg_list[start:end])
try:
self.all_configs = self.job_queue.get(True, 180)
except queue.Empty:
@@ -470,49 +452,41 @@ def __insert_applicability(self, session: DbSession,
json_in: List[Dict]) -> bool:
"""write applicability to sql"""
inserts = []
+ app_values = []
+ app_cfgs = []
for elem in json_in:
if "applicable_solvers" in elem.keys():
cfg_id = elem["input"]["config_tuna_id"]
- # pylint: disable=comparison-with-callable
- app_query = session.query(self.dbt.solver_app)\
- .filter(self.dbt.solver_app.session == self.session_id)\
- .filter(self.dbt.solver_app.config == cfg_id)
- # pylint: enable=comparison-with-callable
if not elem["applicable_solvers"]:
self.logger.warning("No applicable solvers for %s", cfg_id)
- app_slv_ids = []
+ app_cfgs.append(f"{cfg_id}")
+
for solver in elem["applicable_solvers"]:
try:
solver_id = self.solver_id_map[solver]
- app_slv_ids.append(solver_id)
+ vals = f"({self.session_id}, {cfg_id}, {solver_id}, 1)"
+ app_values.append(vals)
except KeyError:
self.logger.warning('Solver %s not found in solver table', solver)
self.logger.info("Please run 'go_fish.py --update_solver' first")
- return False
-
- for solver_id in app_slv_ids:
- obj = app_query.filter(
- self.dbt.solver_app.solver == solver_id).first() # pylint: disable=W0143
- if obj:
- obj.applicable = 1
- else:
- inserts.append((cfg_id, solver_id))
- #commit updates
- session.commit()
+ cleanup = f"delete from {self.dbt.solver_app.__tablename__} where session={self.session_id}"\
+ " and config in (" + ", ".join(app_cfgs) + ");"
+ ins_str = f"insert ignore into {self.dbt.solver_app.__tablename__}"\
+ " (session, config, solver, applicable)"\
+ " values " + ", ".join(app_values) + ";"
+ inserts.append(cleanup)
+ inserts.append(ins_str)
- #bulk inserts
with self.job_queue_lock:
- self.logger.info('Commit bulk inserts, please wait')
- for cfg_id, solver_id in inserts:
- new_entry = self.dbt.solver_app(solver=solver_id,
- config=cfg_id,
- session=self.session_id,
- applicable=1)
- session.add(new_entry)
+ self.logger.info('Commit bulk configs (%s), entries (%s), please wait',
+ len(app_cfgs), len(app_values))
+ for sql_str in inserts:
+ session.execute(sql_str)
session.commit()
+ self.logger.info('End bulk inserts')
return True
@@ -534,11 +508,16 @@ def actuator(func, pack):
session_retry(session, self.__insert_applicability,
functools.partial(actuator, pack=pack), self.logger)
- query = session.query(sqlalchemy_func.count(self.dbt.solver_app.id))
+ query = session.query(sqlalchemy_func.count(self.dbt.solver_app.id),
+ self.dbt.solver_app.applicable)
query = query.filter(self.dbt.solver_app.session == self.session_id) # pylint: disable=W0143
- sapp_count = query.one()[0]
+ if self.label:
+ query = query.filter(self.dbt.solver_app.config == self.dbt.config_tags_table.config)\
+ .filter(self.dbt.config_tags_table.tag == self.label)
+ query = query.group_by(self.dbt.solver_app.applicable)
+ sapp_count = query.all()
self.logger.warning(
- "Finished parsing solver applicability, new session size: %d entries",
+ "Finished parsing solver applicability, label(%s): %s", self.label,
sapp_count)
return True
@@ -629,76 +608,6 @@ def __parse_solvers(self, solvers):
return True
- def get_fdb_entry(self, session, solver):
- """ Get FindDb entry from db """
- obj = None
- fdb_entry = None
-
- conds = [
- f"session={self.dbt.session.id}", f"config={self.config.id}",
- f"solver={solver}", "opencl=0"
- ]
- cond_str = f"where {' AND '.join(conds)}"
- entries = gen_select_objs(session, self.fdb_attr,
- self.dbt.find_db_table.__tablename__, cond_str)
-
- if entries:
- assert len(entries) == 1
- obj = entries[0]
- else:
- fdb_entry = SimpleDict()
- for attr in self.fdb_attr:
- setattr(fdb_entry, attr, None)
- setattr(fdb_entry, 'session', self.dbt.session.id)
- setattr(fdb_entry, 'config', self.config.id)
- setattr(fdb_entry, 'solver', solver)
- setattr(fdb_entry, 'opencl', False)
- setattr(fdb_entry, 'logger', self.logger)
-
- return obj, fdb_entry
-
- def __update_fdb_entry(self, session, solver):
- """ Add a new entry to fdb if there isnt one already """
- obj, fdb_entry = self.get_fdb_entry(session, solver)
- if obj: # existing entry in db
- # This can be removed if we implement the delete orphan cascade
- fdb_entry = obj
- session.query(
- self.dbt.kernel_cache).filter(self.dbt.kernel_cache.kernel_group ==
- fdb_entry.kernel_group).delete()
- else:
- # Bundle Insert for later
- self.pending.append((self.job, fdb_entry))
- return fdb_entry
-
- def __compose_fdb_entry(self, session, fin_json, fdb_obj):
- """Compose a FindDB table entry from fin_output"""
- solver = self.solver_id_map[fdb_obj['solver_name']]
- fdb_entry = self.__update_fdb_entry(session, solver)
- fdb_entry.fdb_key = fin_json['db_key']
- fdb_entry.alg_lib = fdb_obj['algorithm']
- fdb_entry.params = fdb_obj['params']
- fdb_entry.workspace_sz = fdb_obj['workspace']
- fdb_entry.valid = True
-
- fdb_entry.kernel_time = -1
- if 'time' in fdb_obj:
- fdb_entry.kernel_time = fdb_obj['time']
-
- fdb_entry.kernel_group = fdb_entry.id
-
- return fdb_entry
-
- def __compose_kernel_entry(self, session, fdb_obj, fdb_entry):
- """Compose a new Kernel Cache entry from fin input"""
- # Now we have the ID, lets add the binary cache objects
- for kern_obj in fdb_obj['kernel_objects']:
- kernel_obj = self.dbt.kernel_cache()
- self.populate_kernels(kern_obj, kernel_obj)
- kernel_obj.kernel_group = fdb_entry.kernel_group
- session.add(kernel_obj)
- return True
-
@staticmethod
def populate_kernels(kern_obj, kernel_obj):
"""populate kernel object"""
@@ -709,160 +618,6 @@ def populate_kernels(kern_obj, kernel_obj):
kernel_obj.uncompressed_size = kern_obj['uncompressed_size']
return kernel_obj
- def __check_layout_mismatch(self, fdb_entry: SimpleDict,
- status: dict) -> bool:
- """Check that the fdb key returned by fin matches the config being tuned,
- states to error if not"""
- fdb_key = fdb_entry.fdb_key
- fds, vals, _, _ = parse_pdb_key(fdb_key)
- key_layout = vals[fds.index('out_layout')]
- cfg_layout = self.config.out_layout
-
- if cfg_layout != key_layout:
- status['success'] = False
- status['result'] = f"fdb_key layout mismatch with config"\
- f" {key_layout} != {cfg_layout}"
- fdb_entry.valid = False
- return False
-
- return True
-
- def __update_fdb_w_kernels(self,
- session: DbSession,
- fin_json: dict,
- result_str: str = 'miopen_find_compile_result',
- check_str: str = 'find_compiled') -> list:
- """update find db + kernels from json results"""
- status = []
- if fin_json[result_str]:
- for fdb_obj in fin_json[result_str]:
- slv_stat = get_fin_slv_status(fdb_obj, check_str)
- status.append(slv_stat)
-
- if fdb_obj[check_str]:
- #returned entry is added to the table
- fdb_entry = self.__compose_fdb_entry(session, fin_json, fdb_obj)
- self.__check_layout_mismatch(fdb_entry, slv_stat)
- if not self.pending:
- query = gen_update_query(fdb_entry, self.fdb_attr,
- self.dbt.find_db_table.__tablename__)
- session.execute(query)
- else:
- assert len(self.pending) == 1
- self.pending.pop()
- query = gen_insert_query(fdb_entry, self.fdb_attr,
- self.dbt.find_db_table.__tablename__)
- session.execute(query)
-
- fdb_entry = self.__update_fdb_entry(
- session, self.solver_id_map[fdb_obj['solver_name']])
- fdb_entry.kernel_group = fdb_entry.id
- query = gen_update_query(fdb_entry, ['kernel_group'],
- self.dbt.find_db_table.__tablename__)
- session.execute(query)
-
- if fdb_obj['reason'] == 'Success':
- self.__compose_kernel_entry(session, fdb_obj, fdb_entry)
- self.logger.info('Updating find Db(Build) for job_id=%s',
- self.job.id)
- else:
- # JD: add info about reason to the logs table
- fdb_entry.valid = False
- else:
- self.logger.warning("Failed find_db compile, cfg_id: %s, obj: %s",
- fin_json['config_tuna_id'], fdb_obj)
- else:
- status = [{
- 'solver': 'all',
- 'success': False,
- 'result': 'Find Compile: No results'
- }]
-
- session.commit()
-
- return status
-
- def process_fdb_w_kernels(self,
- session,
- fin_json,
- result_str='miopen_find_compile_result',
- check_str='find_compiled'):
- """initiate find db update"""
-
- callback = self.__update_fdb_w_kernels
- status = session_retry(
- session, callback,
- lambda x: x(session, fin_json, result_str, check_str), self.logger)
-
- if not status:
- self.logger.warning('Fin: Unable to update Database')
- status = [{
- 'solver': 'all',
- 'success': False,
- 'result': 'Fin: Unable to update Database'
- }]
-
- return status
-
- def __add_sql_objs(self, session, obj_list):
- """add sql objects to the table"""
- for obj in obj_list:
- if isinstance(obj, SimpleDict):
- if has_attr_set(obj, self.fdb_attr):
- query = gen_insert_query(obj, self.fdb_attr,
- self.dbt.find_db_table.__tablename__)
- session.execute(query)
- else:
- return False
- else:
- session.add(obj)
- session.commit()
- return True
-
- def __result_queue_commit(self, session, close_job):
- """commit the result queue and set mark job complete"""
- while not self.result_queue.empty():
- obj_list = []
- res_list = self.result_queue.get(True, 1)
- res_job = res_list[0][0]
- for _, obj in res_list:
- obj_list.append(obj)
-
- self.logger.info("commit pending job %s, #objects: %s", res_job.id,
- len(obj_list))
- status = session_retry(session, self.__add_sql_objs,
- lambda x: x(session, obj_list), self.logger)
- if not status:
- self.logger.error("Failed commit pending job %s", res_job.id)
- return False
-
- this_job = self.job
-
- #set job states after successful commit
- self.job = res_job
- close_job()
-
- self.job = this_job
-
- return True
-
- def close_job(self):
- """mark a job complete"""
-
- def result_queue_drain(self):
- """check for lock and commit the result queue"""
- if self.result_queue_lock.acquire(block=False):
- with DbSession() as session:
- self.__result_queue_commit(session, self.close_job)
- self.result_queue_lock.release()
- return True
- return False
-
- def reset_job_state(self):
- """finish committing result queue"""
- super().reset_job_state()
- self.result_queue_drain()
-
def init_check_env(self):
"""check environment on the first run"""
if self.first_pass:
@@ -885,20 +640,30 @@ def run_fin_cmd(self):
['/opt/rocm/bin/fin', '-i',
self.get_fin_input(), '-o', fin_output]) # pylint: disable=no-member
- ret_code, _ = super().run_command(cmd)
+ ret_code, out_str = super().run_command(cmd)
if ret_code != 0:
- return None
+ result = {
+ 'solver':
+ 'all',
+ 'success':
+ False,
+ 'result':
+ out_str[-128:].replace('\n', ';').replace('\'', '"').replace(
+ '%', 'x').replace(':', ': ') # correct string for sql
+ }
+ return result
# load the output json file and strip the env
fin_json = json.loads(self.machine.read_file(fin_output))[1:]
assert len(fin_json) == 1
- # JD: if we implement multiple jobs per fin launch, this would be a loop
fin_json = fin_json[0]
return fin_json
def step(self):
"""Inner loop for Process run defined in worker_interface"""
+ _, self.id_solver_map = get_id_solvers(
+ ) #hyphenated names used by miopen::solver.ToString()
self.multiproc = True
if "applicability" in self.fin_steps:
self.applicability()
diff --git a/tuna/miopen/worker/fin_eval.py b/tuna/miopen/worker/fin_eval.py
index 6c755726f..58369d8c5 100644
--- a/tuna/miopen/worker/fin_eval.py
+++ b/tuna/miopen/worker/fin_eval.py
@@ -26,21 +26,14 @@
###############################################################################
"""Fin Evaluator class implements the worker interface. The purpose of this class
is to run fin commands in benchmarking mode"""
-from time import sleep
-import random
-import functools
import json
from typing import List, Dict
-from sqlalchemy.exc import OperationalError
from tuna.miopen.worker.fin_class import FinClass
from tuna.miopen.worker.fin_utils import fin_job
-from tuna.miopen.worker.fin_utils import get_fin_slv_status, get_fin_result
from tuna.dbBase.sql_alchemy import DbSession
-from tuna.utils.db_utility import session_retry, gen_update_query
-
-MAX_ERRORED_JOB_RETRIES = 3
+from tuna.utils.db_utility import session_retry
class FinEvaluator(FinClass):
@@ -49,16 +42,8 @@ class FinEvaluator(FinClass):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.envmt.append(f"HIP_VISIBLE_DEVICES={self.gpu_id}")
-
- def get_job(self, find_state, set_state, imply_end):
- """Polling to see if job available"""
- self.logger.info('find job: %s', find_state)
- if not super().get_job(find_state, set_state, imply_end):
- with self.bar_lock:
- self.num_procs.value -= 1
- return False
- return True
+ if self.gpu_id != -1:
+ self.envmt.append(f"HIP_VISIBLE_DEVICES={self.gpu_id}")
def check_gpu(self):
"""Function to check gpu heartbeat"""
@@ -118,13 +103,34 @@ def fin_pdb_input(self, _fjob):
assert perf_compile_res
fjob['miopen_perf_compile_result'] = perf_compile_res
- fjob = [fjob]
- return fjob
+ return [fjob]
def fin_fdb_input(self, _fjob: Dict) -> List[Dict]:
"""prepare find db command input for fin"""
fjob = _fjob.copy()
with DbSession() as session:
+ find_compile_res = []
+
+ # pylint: disable=comparison-with-callable
+ query = session.query(self.dbt.solver_app).filter(
+ self.dbt.solver_app.session == self.dbt.session.id,
+ self.dbt.solver_app.config == self.job.config,
+ self.dbt.solver_app.applicable == 1)
+ # pylint: enable=comparison-with-callable
+
+ res = session_retry(session, query.all, lambda x: x(), self.logger)
+ for slv_entry in res:
+ slv_name = self.id_solver_map[slv_entry.solver]
+ if not self.job.solver or slv_name == self.job.solver:
+ compile_entry = {
+ 'solver_name': slv_name,
+ 'find_compiled': False,
+ 'kernel_objects': []
+ }
+ find_compile_res.append(compile_entry)
+
+ solvers = [x['solver_name'] for x in find_compile_res]
+
fdb_entry = self.dbt.find_db_table()
fdb_entry.num_cu = self.dbt.session.num_cu
fdb_entry.config = self.config.id
@@ -139,32 +145,26 @@ def fin_fdb_input(self, _fjob: Dict) -> List[Dict]:
fdb_query = fdb_query.filter(self.dbt.find_db_table.workspace_sz != -1,
self.dbt.find_db_table.valid == 1)
- find_compile_res = []
- # Enumerate all solvers for this config
res = session_retry(session, fdb_query.all, lambda x: x(), self.logger)
+
for fdb_rec in res:
slv_name = self.id_solver_map[fdb_rec.solver]
if not self.job.solver or slv_name == self.job.solver:
- compile_entry = {
- 'algorithm': fdb_rec.alg_lib,
- 'find_compiled': True,
- 'solver_name': slv_name,
- 'workspace': fdb_rec.workspace_sz
- }
- kernel_objects = []
+ compile_entry = find_compile_res[solvers.index(slv_name)]
+ compile_entry['find_compiled'] = True
+
blobs = session.query(self.dbt.kernel_cache).filter(
+ self.dbt.kernel_cache.valid == 1,
self.dbt.kernel_cache.kernel_group == fdb_rec.kernel_group)
res = session_retry(session, blobs.all, lambda x: x(), self.logger)
for obj in res:
- kernel_objects.append({
+ compile_entry['kernel_objects'].append({
'blob': obj.kernel_blob.decode('utf-8'),
'comp_options': obj.kernel_args,
'kernel_file': obj.kernel_name,
'md5_sum': obj.kernel_hash,
'uncompressed_size': obj.uncompressed_size
})
- compile_entry['kernel_objects'] = kernel_objects
- find_compile_res.append(compile_entry)
assert find_compile_res
fjob['miopen_find_compile_result'] = find_compile_res
@@ -182,7 +182,7 @@ def get_fin_input(self):
fjob = self.fin_pdb_input(fjob)
elif self.fin_steps[0] == 'miopen_find_eval':
fjob = self.fin_fdb_input(fjob)
- except AssertionError as err:
+ except (AssertionError, ValueError) as err:
self.logger.error('Unable to get compiled objects for job %s : %s',
self.job.id, err)
raise AssertionError from err
@@ -190,137 +190,40 @@ def get_fin_input(self):
return self.machine.write_file(json.dumps(fjob, indent=2).encode(),
is_temp=True)
- def update_fdb_eval_entry(self, session, fdb_obj):
- """update fdb with individual fin json entry"""
- if fdb_obj['evaluated']:
- obj, _ = self.get_fdb_entry(session,
- self.solver_id_map[fdb_obj['solver_name']])
- if not obj:
- self.logger.info(
- 'Unable to find fdb entry for config: %s, solver: %s, '\
- 'arch: %s, num_cu: %s, direction: %s',
- self.config.id, self.solver_id_map[fdb_obj['solver_name']],
- self.dbt.session.arch, self.dbt.session.num_cu, self.config.direction)
- return False
+ def get_job(self, find_state, set_state, imply_end):
+ """Polling to see if job available"""
+ self.logger.info('find job: %s', find_state)
- fdb_entry = obj
- fdb_entry.alg_lib = fdb_obj['algorithm']
- fdb_entry.kernel_time = fdb_obj['time']
- fdb_entry.workspace_sz = fdb_obj['workspace']
- fdb_entry.session = self.dbt.session.id
- fdb_entry.params = fdb_obj['params']
+ if not super().get_job(find_state, set_state, imply_end):
+ return False
+ return True
- self.logger.info('Updating find db(Eval) for job_id=%s', self.job.id)
- query = gen_update_query(fdb_entry, self.fdb_attr,
- self.dbt.find_db_table.__tablename__)
- session.execute(query)
- session.commit()
+ def check_env(self) -> bool:
+ """Check the GPU on the machine matches the GPU specified in session table"""
+ if super().check_env():
+ if self.dbt.session.arch != self.machine.arch or \
+ self.dbt.session.num_cu != self.machine.num_cu:
+ self.logger.error(
+ 'Session arch/num_cu (%s/%s) does not match env arch/num_cu (%s/%s)',
+ self.dbt.session.arch, self.dbt.session.num_cu, self.machine.arch,
+ self.machine.num_cu)
+ return False
else:
- self.logger.warning("Not evaluated: job(%s), solver(%s), %s", self.job.id,
- fdb_obj['solver_name'], fdb_obj['reason'])
return False
return True
- def process_fdb_eval(
- self,
- fin_json: Dict,
- result_str: str = 'miopen_find_eval_result') -> List[Dict]:
- """process find db eval json results"""
- status = []
- fdb_obj = None
- with DbSession() as session:
-
- def actuator(func, fdb_obj):
- return func(session, fdb_obj)
-
- for fdb_obj in fin_json[result_str]:
- self.logger.info('Processing object: %s', fdb_obj)
- slv_stat = get_fin_slv_status(fdb_obj, 'evaluated')
- #retry returns false on failure, callback return on success
- ret = session_retry(session, self.update_fdb_eval_entry,
- functools.partial(actuator, fdb_obj=fdb_obj),
- self.logger)
- if not ret:
- self.logger.warning('FinEval: Unable to update Database')
- slv_stat['success'] = False
- slv_stat['result'] = fdb_obj['reason']
-
- status.append(slv_stat)
-
- return status
-
- def clean_cache_table(self):
- """Remove the fin cache kernel entries for this job"""
- with DbSession() as session:
- try:
- old_cache = session.query(self.dbt.fin_cache_table)\
- .filter(self.dbt.fin_cache_table.job_id == self.job.id)
- old_cache.delete()
- session.commit()
- except OperationalError as err:
- session.rollback()
- self.logger.warning('FinEval: Unable to clean %s: %s',
- self.dbt.fin_cache_table.__tablename__, err)
-
- def close_job(self):
- """mark a job complete"""
- self.set_job_state('evaluated')
- self.clean_cache_table()
-
def step(self):
"""Function that defined the evaluator specific functionality which implies picking up jobs
to benchmark and updating DB with evaluator specific state"""
- self.pending = []
- self.result_queue_drain()
if not self.init_check_env():
return False
- if not self.get_job("compiled", "eval_start", True):
- while not self.result_queue_drain():
- sleep(random.randint(1, 10))
- return False
-
- orig_state = 'compiled'
- self.logger.info('Acquired new job: job_id=%s', self.job.id)
- self.set_job_state('evaluating')
- fin_json = self.run_fin_cmd()
-
- failed_job = True
- result_str = ''
- if fin_json:
- if 'miopen_find_eval_result' in fin_json:
- status = self.process_fdb_eval(fin_json)
-
- elif 'miopen_perf_eval_result' in fin_json:
- with DbSession() as session:
- status = self.process_fdb_w_kernels(
- session,
- fin_json,
- result_str='miopen_perf_eval_result',
- check_str='evaluated')
-
- success, result_str = get_fin_result(status)
- failed_job = not success
-
- if failed_job:
- if not self.check_gpu():
- return False
- if self.job.retries >= (MAX_ERRORED_JOB_RETRIES - 1):
- self.logger.warning('max job retries exhausted, setting to errored')
- self.set_job_state('errored', result=result_str)
- else:
- self.logger.warning('resetting job state to %s, incrementing retries',
- orig_state)
- self.set_job_state(orig_state,
- increment_retries=True,
- result=result_str)
- elif self.pending:
- self.set_job_state('evaluated_pend', result=result_str)
- self.result_queue.put(self.pending)
- else:
- self.set_job_state('evaluated', result=result_str)
- self.clean_cache_table()
+ fin_json = None
+ try:
+ fin_json = self.run_fin_cmd()
+ except AssertionError:
+ self.logger.error('Error building Fin input, job(%s)', self.job.id)
- return True
+ return fin_json
diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py
index 847c5c070..140d1542b 100644
--- a/tuna/mituna_interface.py
+++ b/tuna/mituna_interface.py
@@ -25,57 +25,90 @@
#
###############################################################################
"""Interface class to set up and launch tuning functionality"""
-from multiprocessing import Value, Lock, Queue as mpQueue
+import os
+from multiprocessing import Value, Lock, Queue as mpQueue, Process
from typing import Optional, Dict, Any, List
from io import StringIO
+from functools import lru_cache
+import json
import logging
import argparse
+import subprocess
+import time
+import threading
+import asyncio
+from datetime import timedelta
+from sqlalchemy.exc import NoInspectionAvailable
+from sqlalchemy.inspection import inspect
+import aioredis
+import kombu
from paramiko.channel import ChannelFile
+
from tuna.worker_interface import WorkerInterface
from tuna.machine import Machine
from tuna.libraries import Library
from tuna.utils.logger import setup_logger
-from tuna.utils.utility import get_env_vars
+from tuna.utils.utility import get_env_vars, SimpleDict
+from tuna.dbBase.sql_alchemy import DbSession
+from tuna.celery_app.celery_app import stop_active_workers, stop_named_worker
+from tuna.celery_app.celery_app import get_backend_env, purge_queue
+from tuna.celery_app.utility import get_q_name
+from tuna.celery_app.celery_workers import launch_celery_worker
+from tuna.libraries import Operation
+from tuna.custom_errors import CustomError
+from tuna.utils.db_utility import gen_update_query, session_retry
+
+job_counter_lock = threading.Lock()
-class MITunaInterface():
+class MITunaInterface(): #pylint:disable=too-many-instance-attributes,too-many-public-methods
""" Interface class extended by libraries. The purpose of this class is to define
common functionalities. """
def __init__(self, library=Library.MIOPEN) -> None:
- self.library: Library = library
+    self.library: Library = library
- self.logger: logging.Logger = setup_logger(logger_name=self.library.value,
+ self.logger: logging.Logger = setup_logger(logger_name=library.value,
add_streamhandler=True)
self.args: argparse.Namespace
+ self.fetch_state: set = set()
+ self.max_job_retries = 10
+ self.dbt = None
+ self.operation = None
+ self.db_name = os.environ['TUNA_DB_NAME']
+ self.prefix = None
+
def check_docker(self,
worker: WorkerInterface,
- dockername="miopentuna") -> None:
+ dockername="miopentuna") -> bool:
"""! Checking for docker
@param worker The worker interface instance
@param dockername The name of the docker
"""
out2: ChannelFile
- self.logger.warning("docker not installed or requires sudo .... ")
_, out2, _ = worker.exec_command("sudo docker info")
while not out2.channel.exit_status_ready():
self.logger.warning(out2.readline())
if out2.channel.exit_status > 0:
self.logger.warning(
"docker not installed or failed to run with sudo .... ")
- else:
- out: StringIO = StringIO()
- line: Optional[str] = None
- _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}")
- for line in out.readlines():
- if line is not None:
- if line.find(dockername) != -1:
- self.logger.warning('%s docker image exists', dockername)
- break
- if line is None:
- self.logger.warning('%s docker image does not exist', dockername)
+ return False
+
+ out: StringIO = StringIO()
+ line: Optional[str] = None
+ _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}")
+ for line in out.readlines():
+ if line is not None:
+ if line.find(dockername) != -1:
+ self.logger.warning('%s docker image exists', dockername)
+ return True
+ if line is None:
+ self.logger.warning('%s docker image does not exist', dockername)
+ return False
+
+ return False
def check_status(self,
worker: WorkerInterface,
@@ -121,7 +154,7 @@ def check_status(self,
return True
def add_tables(self) -> bool:
- """Add library specific tables"""
+ """Add self specific tables"""
return self.add_tables()
def get_num_procs(self, machine: Machine) -> List:
@@ -144,38 +177,48 @@ def get_num_procs(self, machine: Machine) -> List:
return worker_ids
- def get_f_vals(self, machine: Machine, worker_ids: range) -> Dict[str, Any]:
+ def get_f_vals(self,
+ machine: Machine,
+ worker_ids: range,
+ tuning=False) -> Dict[str, Any]:
+ #pylint:disable=unused-argument
"""Determine kwargs for worker_interface"""
f_vals: Dict[str, Any]
f_vals = self.compose_f_vals(machine)
- f_vals["num_procs"] = Value('i', len(worker_ids))
f_vals['envmt'] = self.get_envmt()
+
+ if not tuning:
+ f_vals["num_procs"] = Value('i', len(worker_ids))
+
return f_vals
def get_envmt(self):
"""Get runtime envmt"""
raise NotImplementedError("Not implemented")
- def compose_f_vals(self, machine: Machine) -> Dict[str, Any]:
+ def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]:
"""! Compose dict for WorkerInterface constructor
@param args The command line arguments
@param machine Machine instance
"""
f_vals: Dict[str, Any] = {}
- f_vals["barred"] = Value('i', 0)
- f_vals["bar_lock"] = Lock()
- #multiprocess queue for jobs, shared on machine
- f_vals["job_queue"] = mpQueue()
- f_vals["job_queue_lock"] = Lock()
- f_vals["result_queue"] = mpQueue()
- f_vals["result_queue_lock"] = Lock()
- f_vals["machine"] = machine
f_vals["b_first"] = True
- f_vals["end_jobs"] = Value('i', 0)
+
+ #adding non-serializable obj when not running through celery
+ if not tuning:
+ f_vals["machine"] = machine
+ f_vals["bar_lock"] = Lock()
+ #multiprocess queue for jobs, shared on machine
+ f_vals["job_queue"] = mpQueue()
+ f_vals["job_queue_lock"] = Lock()
+ f_vals["end_jobs"] = Value('i', 0)
return f_vals
- def get_kwargs(self, gpu_idx: int, f_vals: Dict[str, Any]) -> Dict[str, Any]:
+ def get_kwargs(self,
+ gpu_idx: int,
+ f_vals: Dict[str, Any],
+ tuning=False) -> Dict[str, Any]:
"""! Helper function to set up kwargs for worker instances
@param gpu_idx Unique ID of the GPU
@param f_vals Dict containing runtime information
@@ -184,20 +227,434 @@ def get_kwargs(self, gpu_idx: int, f_vals: Dict[str, Any]) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {}
kwargs = {
- 'machine': f_vals["machine"],
'gpu_id': gpu_idx,
- 'num_procs': f_vals["num_procs"],
- 'barred': f_vals["barred"],
- 'bar_lock': f_vals["bar_lock"],
'envmt': envmt,
- 'job_queue': f_vals["job_queue"],
- 'job_queue_lock': f_vals["job_queue_lock"],
- 'result_queue': f_vals["result_queue"],
- 'result_queue_lock': f_vals["result_queue_lock"],
'label': self.args.label,
'docker_name': self.args.docker_name,
- 'end_jobs': f_vals['end_jobs'],
'session_id': self.args.session_id
}
+ #adding non-serializable obj when not running through celery
+ if not tuning:
+ kwargs["machine"] = f_vals["machine"]
+ kwargs["job_queue"] = f_vals["job_queue"]
+ kwargs["job_queue_lock"] = f_vals["job_queue_lock"]
+ kwargs["num_procs"] = f_vals["num_procs"]
+ kwargs["bar_lock"] = f_vals["bar_lock"]
+ kwargs["end_jobs"] = f_vals["end_jobs"]
+ kwargs["job_queue"] = f_vals["job_queue"]
+ kwargs["job_queue_lock"] = f_vals["job_queue_lock"]
+
return kwargs
+
+ def get_job_list(self, session, find_state, claim_num):
+ """Get list of jobs"""
+ raise NotImplementedError("Not implemented")
+
+ def get_jobs(self,
+ session: DbSession,
+ find_state: List[str],
+ set_state: str,
+ session_id: int,
+ claim_num: int = None,
+ no_update=False):
+ """Interface function to get jobs based on session and find_state"""
+ #job_rows: List[SimpleDict]
+ ids: list
+ row: SimpleDict
+
+ self.logger.info('Fetching DB rows...')
+ job_list = self.get_job_list(session, find_state, claim_num)
+
+ if not self.check_jobs_found(job_list, find_state, session_id):
+ return []
+
+ if no_update:
+ return job_list
+
+ ids = [row.id for row in job_list]
+ self.logger.info("%s jobs %s", find_state, ids)
+ self.logger.info('Updating job state to %s', set_state)
+ for job in job_list:
+ job.state = set_state
+ if self.dbt is not None:
+ query: str = gen_update_query(job, ['state'],
+ self.dbt.job_table.__tablename__)
+ else:
+ raise CustomError('DBTable must be set')
+ session.execute(query)
+
+ session.commit()
+
+ return job_list
+
+ def shutdown_workers(self):
+ """Shutdown all active celery workers regardless of queue"""
+ return stop_active_workers()
+
+ def cancel_consumer(self, queue):
+ """Cancel consumers for queue"""
+ try:
+ cmd = f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}"
+ subp = subprocess.Popen( #pylint: disable=consider-using-with
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ shell=True,
+ universal_newlines=True)
+
+ #filter the workers by session id
+ sess_str = "sess_" + queue.split('_')[-1]
+ stdout, _ = subp.stdout, subp.stderr
+ while True:
+ line = stdout.readline()
+ if not line:
+ break
+ #stop workers that were feeding from this queue
+ if "->" in line and sess_str in line:
+ hostname = line.split('->')[1].split()[0].split(':')[0]
+ stop_named_worker(hostname)
+
+ except Exception as exp: #pylint: disable=broad-exception-caught
+ self.logger.warning(
+ 'Error occurred trying to cancel consumer for queue: %s ', queue)
+ self.logger.warning(exp)
+ return False
+
+    self.logger.info('Successfully cancelled consumer for queue: %s', queue)
+
+ return True
+
+ def celery_enqueue_call(self, context, q_name, task_id=False):
+ """Wrapper function for celery enqueue func"""
+ raise NotImplementedError('Not implemented')
+
+ def enqueue_jobs(self, job_counter, job_batch_size, q_name):
+ """Enqueue celery jobs"""
+ self.logger.info('Starting enqueue')
+ with DbSession() as session:
+ while True:
+ job_list = []
+ #get all the jobs from mySQL
+ job_list = self.get_jobs(
+ session,
+ self.fetch_state,
+ self.set_state, #pylint: disable=no-member
+ self.args.session_id, #pylint: disable=no-member
+ job_batch_size)
+
+ with job_counter_lock:
+ job_counter.value = job_counter.value + len(job_list)
+
+ for i in range(0, len(job_list), job_batch_size):
+ batch_jobs = job_list[i:min(i + job_batch_size, len(job_list))]
+ context_list = self.get_context_list(session, batch_jobs)
+ for context in context_list:
+ #calling celery task, enqueuing to celery queue
+ self.celery_enqueue_call(context, q_name=q_name)
+
+ self.logger.info('Job counter: %s', job_counter.value)
+ if not job_list:
+ self.logger.info('All tasks added to queue')
+ break
+
+ async def cleanup_redis_results(self, prefix):
+ """Remove stale redis results by key"""
+ backend_port, backend_host = get_backend_env()
+ redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15")
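+    #redis logical db 15 is assumed to hold the celery result backend keys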
+
+    keys = []
+    cursor = "0"
+    #scan until the cursor wraps around to 0, collecting all matching keys
+    while cursor != 0:
+      if prefix:
+        #a prefix is necessary when we need to differentiate results in redis based on operation
+        #without a prefix the redis key defaults to: "celery-task-meta-"
+        #with a prefix the key will look like: "celery-task-meta-<prefix>-"
+        #the prefix is applied when filtering the redis keys as below
+        cursor, results = await redis.scan(cursor, match=f"*{prefix}*")
+      else:
+        #no prefix, match any key
+        cursor, results = await redis.scan(cursor, match="*")
+      keys.extend(results)
+      self.logger.info('Found %s old results', len(results))
+ for key in keys:
+ try:
+ await redis.delete(key)
+ except aioredis.exceptions.ResponseError as red_err:
+ self.logger.error(red_err)
+ self.logger.info(key.decode('utf-8'))
+ continue
+
+ self.logger.info('Done removing old redis results for prefix: %s', prefix)
+
+ return True
+
+ async def consume(self, job_counter, prefix):
+ """Retrieve celery results from redis db"""
+
+ backend_port, backend_host = get_backend_env()
+ redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15")
+
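+    #poll redis for celery result keys until all enqueued jobs have been accounted for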
+ while job_counter.value > 0:
+ cursor = "0"
+ keys = []
+ while cursor != 0:
+ if prefix:
+          #a prefix is necessary when we need to differentiate results in redis based on operation
+          #without a prefix the redis key defaults to: "celery-task-meta-"
+          #with a prefix the key will look like: "celery-task-meta-<prefix>-"
+          #the prefix is applied when filtering the redis keys as below
+ cursor, results = await redis.scan(cursor, match=f"*{prefix}*")
+ else:
+ #no prefix, match any key
+ cursor, results = await redis.scan(cursor, match="*")
+ keys.extend(results)
+ self.logger.info('Found %s results', len(results))
+ for key in keys:
+ try:
+ data = await redis.get(key)
+ if data:
+ _ = await self.parse_result(data.decode('utf-8'))
+ await redis.delete(key)
+ with job_counter_lock:
+ job_counter.value = job_counter.value - 1
+ except aioredis.exceptions.ResponseError as red_err:
+ self.logger.error(red_err)
+ self.logger.info(key.decode('utf-8'))
+
+ await asyncio.sleep(1)
+ self.logger.info('Job counter reached 0')
+ await redis.close()
+
+ return True
+
+ def prep_tuning(self):
+ """Prep env for tuning start"""
+ cmd = None
+ subp_list = []
+ q_name = None
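+    #HOSTNAME and GPUID in the worker command are placeholders, expected to be filled in at launch time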
+ if self.operation == Operation.COMPILE:
+ q_name = get_q_name(self, op_compile=True)
+ cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" #pylint: disable=line-too-long
+ else:
+ q_name = get_q_name(self, op_eval=True)
+ cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" #pylint: disable=line-too-long
+
+ self.logger.info('celery Q name: %s', q_name)
+ if not self.args.enqueue_only:
+ try:
+ self.logger.info('Launching celery workers for queue %s', q_name)
+ subp_list = launch_celery_worker(self.operation, cmd, self.args, True)
+ self.logger.info('Done launching celery workers')
+ if not subp_list:
+ raise CustomError('Could not launch celery worker')
+ except kombu.exceptions.OperationalError as k_err:
+        self.logger.error('Redis error occurred: %s', k_err)
+        raise CustomError('Could not connect to celery broker') from k_err
+ else:
+ purge_queue([q_name])
+
+ return q_name, subp_list
+
+ #pylint: disable=too-many-locals
+ def tune(self, job_batch_size=1000):
+ """tuning loop to spin out celery tasks"""
+
+ if self.args.shutdown_workers:
+ self.logger.info('Shutting down all celery workers')
+ stop_active_workers()
+ return True
+
+ try:
+ q_name, subp_list = self.prep_tuning()
+ except CustomError as verr:
+ self.logger.error(verr)
+ return False
+
+ try:
+ #if enqueue_only is False, we launch the celery workers
+ if not self.args.enqueue_only:
+ for subp in subp_list:
+ subp.wait()
+ return True
+ except KeyboardInterrupt:
+ for subp in subp_list:
+ subp.kill()
+ return False
+
+ start = time.time()
+
+ #set job count to 1 until first job fetch is finished
+ job_counter = Value('i', 1)
+ try:
+ enqueue_proc = Process(target=self.enqueue_jobs,
+ args=[job_counter, job_batch_size, q_name])
+ #Start enqueue proc
+ enqueue_proc.start()
+
+ #cleanup old results
+ cleanup_proc = Process(target=self.async_wrap,
+ args=(self.cleanup_redis_results, self.prefix))
+ cleanup_proc.start()
+ cleanup_proc.join()
+
+ #start async consume thread, blocking
+ consume_proc = Process(target=self.async_wrap,
+ args=(self.consume, job_counter, self.prefix))
+ self.logger.info('Starting consume thread')
+ consume_proc.start()
+
+ enqueue_proc.join()
+ #enqueue finished first fetch, remove hold on job_counter
+ with job_counter_lock:
+ job_counter.value = job_counter.value - 1
+
+ #check for new jobs
+ while consume_proc.is_alive():
+ enqueue_proc = Process(target=self.enqueue_jobs,
+ args=[job_counter, job_batch_size, q_name])
+ enqueue_proc.start()
+ enqueue_proc.join()
+ time.sleep(10)
+
+ consume_proc.join()
+
+ except (KeyboardInterrupt, Exception) as exp: #pylint: disable=broad-exception-caught
+      self.logger.error('Error occurred %s', exp)
+ purge_queue([q_name])
+ self.cancel_consumer(q_name)
+ self.reset_job_state_on_ctrl_c()
+ with job_counter_lock:
+ job_counter.value = 0
+
+ self.cancel_consumer(q_name)
+ end = time.time()
+ self.logger.info("Took {:0>8} to tune".format( #pylint: disable=consider-using-f-string
+ str(timedelta(seconds=end - start))))
+
+ return True
+
+ async def async_callback(self, async_func, *args):
+ """Wrapper function to await on async function"""
+ await async_func(*args)
+
+ def async_wrap(self, async_func, *args):
+ """Run async function"""
+ try:
+ asyncio.run(self.async_callback(async_func, *args))
+ except KeyboardInterrupt:
+ self.logger.warning('Keyboard interrupt caught, terminating')
+
+ def reset_job_state_on_ctrl_c(self):
+ """Reset job state for jobs in flight"""
+ temp_obj = SimpleDict()
+ temp_obj.session_id = self.args.session_id #pylint: disable=invalid-name
+ attribs = ['state']
+ temp_obj.state = 1
+
+ self.logger.info('Resetting job state in DB for in flight jobs')
+
+    if self.operation == Operation.COMPILE:
+      state = 16
+    elif self.operation == Operation.EVAL:
+      state = 12
+    else:
+      raise CustomError('Unsupported tuning operation')
+
+ query = gen_update_query(temp_obj, attribs,
+ self.dbt.job_table.__tablename__,
+ [('session', self.args.session_id),
+ ('state', state)])
+ with DbSession() as session:
+
+ #pylint: disable=duplicate-code
+ def callback() -> bool:
+ session.execute(query)
+ session.commit()
+ return True
+
+ #pylint: enable=duplicate-code
+
+ assert session_retry(session, callback, lambda x: x(), self.logger)
+      self.logger.info('Successfully reset job state')
+ return True
+
+ return False
+
+ def has_tunable_operation(self):
+ """Check if current operation is a tuning operation"""
+ raise NotImplementedError("Not implemented")
+
+ def get_job_attr(self):
+ """Get job attr for row selection"""
+ job_attr: List[str] = None
+ try:
+ job_attr = [column.name for column in inspect(self.dbt.job_table).c]
+ job_attr.remove("insert_ts")
+ job_attr.remove("update_ts")
+ except NoInspectionAvailable as error:
+ self.logger.warning("Ignoring error for init_session: %s", error)
+ return job_attr
+
+ def check_jobs_found(self, job_rows: List[SimpleDict], find_state: List[Any],
+ session_id: int) -> bool:
+ """check for end of jobs"""
+ if not job_rows:
+ # we are done
+ self.logger.warning('No %s jobs found, session %s', find_state,
+ session_id)
+ return False
+ return True
+
+ @lru_cache(1)
+ def get_context_items(self):
+ """Helper function to get items for celery job context"""
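+    #cached with lru_cache: the context items are identical for every task in this session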
+ kwargs = None
+ f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True)
+ kwargs = self.get_kwargs(0, f_vals, tuning=True)
+ return kwargs
+
+ def serialize_jobs(self, session, batch_jobs):
+ """Return list of serialize jobs"""
+ raise NotImplementedError("Not implemented")
+
+ def build_context(self, serialized_jobs):
+ """Build context list for enqueue job"""
+ raise NotImplementedError("Not implemented")
+
+ def get_context_list(self, session, batch_jobs):
+ """Return list of jobs (context) for celery queue"""
+
+ context_list: List[dict] = None
+ serialized_jobs = self.serialize_jobs(session, batch_jobs)
+ #build context for each celery task
+ context_list = self.build_context(serialized_jobs)
+
+ return context_list
+
+ async def parse_result(self, data):
+ """Function callback for celery async jobs to store results"""
+ data = json.loads(data)
+
+ with DbSession() as session:
+ try:
+ fin_json = data['result']['ret']
+ context = data['result']['context']
+ except KeyError as kerr:
+ self.logger.error(kerr)
+ return False
+
+ self.logger.info('Parsing: %s', fin_json)
+ if self.operation == Operation.COMPILE:
+ self.process_compile_results(session, fin_json, context)
+ elif self.operation == Operation.EVAL:
+ self.process_eval_results(session, fin_json, context)
+ else:
+ raise CustomError('Unsupported tuning operation')
+
+ return True
+
+ def process_compile_results(self, session, fin_json, context):
+ """Process result from fin_build worker"""
+ raise NotImplementedError("Not implemented")
+
+ def process_eval_results(self, session, fin_json, context):
+ """Process fin_json result"""
+ raise NotImplementedError("Not implemented")
diff --git a/tuna/parse_args.py b/tuna/parse_args.py
index d6e99b347..e3011e781 100644
--- a/tuna/parse_args.py
+++ b/tuna/parse_args.py
@@ -47,8 +47,11 @@ class TunaArgs(Enum):
LABEL: str = 'label'
RESTART_MACHINE: str = 'restart_machine'
DOCKER_NAME: str = 'docker_name'
+ SHUTDOWN_WORKERS: str = 'shutdown_workers'
+ ENQUEUE_ONLY: str = 'enqueue_only'
+# pylint: disable=too-many-branches
def setup_arg_parser(desc: str,
arg_list: List[TunaArgs],
parser: argparse.Namespace = None,
@@ -61,15 +64,17 @@ def setup_arg_parser(desc: str,
parser.add_argument('--yaml', action=jsonargparse.ActionConfigFile)
if TunaArgs.ARCH in arg_list:
- parser.add_argument(
- '-a',
- '--arch',
- type=str,
- dest='arch',
- default=None,
- required=False,
- help='Architecture of machines',
- choices=['gfx900', 'gfx906', 'gfx908', 'gfx1030', 'gfx90a', 'gfx940'])
+ parser.add_argument('-a',
+ '--arch',
+ type=str,
+ dest='arch',
+ default=None,
+ required=False,
+ help='Architecture of machines',
+ choices=[
+ 'gfx900', 'gfx906', 'gfx908', 'gfx1030', 'gfx90a',
+ 'gfx940', 'gfx942'
+ ])
if TunaArgs.NUM_CU in arg_list:
parser.add_argument(
'-n',
@@ -79,7 +84,7 @@ def setup_arg_parser(desc: str,
default=None,
required=False,
help='Number of CUs on GPU',
- choices=['36', '56', '60', '64', '104', '110', '120', '228'])
+ choices=['36', '56', '60', '64', '104', '110', '120', '228', '304'])
if TunaArgs.DIRECTION in arg_list:
parser.add_argument(
'-d',
@@ -139,8 +144,19 @@ def setup_arg_parser(desc: str,
'--docker_name',
dest='docker_name',
type=str,
- default='miopentuna',
+ default='',
help='Select a docker to run on. (default miopentuna)')
+ if TunaArgs.SHUTDOWN_WORKERS in arg_list:
+ parser.add_argument('--shutdown_workers',
+ dest='shutdown_workers',
+ action='store_true',
+ help='Shutdown all active celery workers')
+
+ if TunaArgs.ENQUEUE_ONLY in arg_list:
+ parser.add_argument('--enqueue_only',
+ action='store_true',
+ dest='enqueue_only',
+ help='Enqueue jobs to celery queue')
return parser
diff --git a/tuna/rocmlir/config_type.py b/tuna/rocmlir/config_type.py
index f55fe4ac4..7eeedc054 100644
--- a/tuna/rocmlir/config_type.py
+++ b/tuna/rocmlir/config_type.py
@@ -27,12 +27,4 @@
"""Module that encapsulates different configuration types supported by Tuna"""
from enum import Enum
-#ConfigType = Enum('ConfigType', ['convolution', 'gemm', 'attention'])
-
-class ConfigType(Enum):
- convolution: str = 'convolution'
- gemm: str = 'gemm'
- attention: str = 'attention'
-
- def __str__(self) -> str:
- return self.value
+ConfigType = Enum('ConfigType', ['convolution', 'gemm', 'attention'])
diff --git a/tuna/rocmlir/rocmlir_lib.py b/tuna/rocmlir/rocmlir_lib.py
index 5c7275b8e..0eb00e2ab 100644
--- a/tuna/rocmlir/rocmlir_lib.py
+++ b/tuna/rocmlir/rocmlir_lib.py
@@ -33,7 +33,7 @@
from typing import Dict, Any, List, Optional
from tuna.mituna_interface import MITunaInterface
from tuna.parse_args import TunaArgs, setup_arg_parser, args_check
-from tuna.utils.miopen_utility import load_machines
+from tuna.utils.machine_utility import load_machines
from tuna.machine import Machine
from tuna.libraries import Library
@@ -43,6 +43,7 @@
from tuna.miopen.db.build_schema import recreate_triggers
from tuna.rocmlir.triggers import get_timestamp_trigger
from tuna.rocmlir.config_type import ConfigType
+from tuna.dbBase.sql_alchemy import DbSession
class RocMLIR(MITunaInterface):
@@ -184,8 +185,8 @@ def run(self) -> Optional[List[RocMLIRWorker]]:
SessionRocMLIR().add_new_session(self.args, worker)
return None
- assert self.args.execute, \
- "one of --add_tables, --init_session, or --execute must be present"
+ # Must be --execute, because the mutually-exclusive-group argument is
+ # required, and we just checked for --add_tables and --init_session.
res = self.compose_worker_list(machines)
return res
@@ -196,7 +197,10 @@ def get_envmt(self) -> List[str]:
envmt: List[str] = []
return envmt
- def get_kwargs(self, gpu_idx: int, f_vals: Dict[str, Any]) -> Dict[str, Any]:
+ def get_kwargs(self,
+ gpu_idx: int,
+ f_vals: Dict[str, Any],
+ tuning=False) -> Dict[str, Any]:
# pylint: disable=duplicate-code
"""! Helper function to set up kwargs for worker instances
@param gpu_idx Unique ID of the GPU
@@ -206,3 +210,23 @@ def get_kwargs(self, gpu_idx: int, f_vals: Dict[str, Any]) -> Dict[str, Any]:
kwargs['config_type'] = self.args.config_type
return kwargs
+
+ def get_jobs(self,
+ session: DbSession,
+ find_state: List[str],
+ set_state: str,
+ session_id: int,
+ claim_num: int = None,
+ no_update: bool = False):
+ """Get jobs based on find_state"""
+ self.logger.info('Placeholder')
+
+ return True
+
+ def get_context_list(self, session, batch_jobs):
+ """Get a list of context items to be used for celery task"""
+ raise NotImplementedError("Not implemented in rocmlir")
+
+ def celery_enqueue_call(self, context, q_name, task_id=False):
+ """Wrapper function for celery enqueue func"""
+ raise NotImplementedError('Not implemented')
diff --git a/tuna/rocmlir/rocmlir_tables.py b/tuna/rocmlir/rocmlir_tables.py
index 1ccd7f359..9eaf9a027 100644
--- a/tuna/rocmlir/rocmlir_tables.py
+++ b/tuna/rocmlir/rocmlir_tables.py
@@ -164,18 +164,6 @@ class ConvolutionJob(BASE, JobMixin):
index=True)
-class SimpleCSVMixin():
- """Just a method to write whole table as CSV."""
- def export_as_csv(self, filename):
- with open(filename, 'w', encoding='utf8') as f:
- outcsv = csv.writer(f)
-# with DbCursor() as cur:
-# # (may need list(cur....))
-# outcsv.writerows(cur.execute(f"select * from {self.__tablename__};"))
- with DbSession() as session:
- outcsv.writerows(session.execute(sql_select(self.__table__.columns)))
-
-
def make_option_if_not_in_line(option, value, line):
"""If option is not already in line, make an option-value string."""
if f"{option} " in line:
@@ -184,7 +172,7 @@ def make_option_if_not_in_line(option, value, line):
return f"{option} {value} "
-class ConvolutionConfig(BASE, SimpleCSVMixin):
+class ConvolutionConfig(BASE):
"""Represents convolution config table"""
__tablename__ = "rocmlir_conv_config"
@@ -347,7 +335,6 @@ def get_configurations(self, filename):
# Add options if they aren't already supplied.
# We need trailing spaces here to account for the string concat.
- one_config = ""
# For datatype, check for the presence of a positional arg.
if line[0][0] == "-":
one_config = f"{datatype} "
@@ -359,6 +346,13 @@ def get_configurations(self, filename):
one_config += make_option_if_not_in_line("-O", layout, line)
one_config += line
one_config = one_config.strip()
+ if "-F" not in line:
+ one_config += f"{direction} " # -F included in direction.
+ one_config += make_option_if_not_in_line("-f", layout, line)
+ one_config += make_option_if_not_in_line("-I", layout, line)
+ one_config += make_option_if_not_in_line("-O", layout, line)
+ one_config += line
+ one_config = one_config.strip()
if one_config not in configs:
configs.append(one_config)
diff --git a/tuna/rocmlir/rocmlir_worker.py b/tuna/rocmlir/rocmlir_worker.py
index 8c9ab951f..bf8d55ff1 100644
--- a/tuna/rocmlir/rocmlir_worker.py
+++ b/tuna/rocmlir/rocmlir_worker.py
@@ -33,10 +33,11 @@
import functools
import logging
import traceback
-from tenacity import Retrying, stop_after_attempt, before_sleep_log, wait_random, TryAgain
from sqlalchemy.inspection import inspect
+from tenacity import Retrying, stop_after_attempt, before_sleep_log, wait_random
+
from tuna.dbBase.sql_alchemy import DbSession
from tuna.worker_interface import WorkerInterface
from tuna.rocmlir.rocmlir_tables import RocMLIRDBTables
diff --git a/tuna/rocmlir/tuning_space.py b/tuna/rocmlir/tuning_space.py
index 5d1a0d56b..525e4301f 100644
--- a/tuna/rocmlir/tuning_space.py
+++ b/tuna/rocmlir/tuning_space.py
@@ -27,12 +27,4 @@
"""Module that encapsulates different tuning spaces used by tuning driver."""
from enum import Enum
-#TuningSpace = Enum('TuningSpace', ["quick", "full", "exhaustive"])
-
-class TuningSpace(Enum):
- quick: str = 'quick'
- full: str = 'full'
- exhaustive: str = 'exhaustive'
-
- def __str__(self) -> str:
- return self.value
+TuningSpace = Enum('TuningSpace', ["quick", "full", "exhaustive"])
diff --git a/flaskapp/app.py b/tuna/utils/celery_utils.py
similarity index 69%
rename from flaskapp/app.py
rename to tuna/utils/celery_utils.py
index cd08a8206..4bb9ce00f 100644
--- a/flaskapp/app.py
+++ b/tuna/utils/celery_utils.py
@@ -3,7 +3,7 @@
#
# MIT License
#
-# Copyright (c) 2022 Advanced Micro Devices, Inc.
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,24 +24,22 @@
# SOFTWARE.
#
###############################################################################
-"""Tuna Flask App main function"""
-import json
+"""Utility module for celery helper functions"""
+from tuna.utils.utility import SimpleDict
-from flask import Flask
-from flaskapp.views.fdb_key import fdb_key
-from flaskapp.views.grafana import grafana
-app = Flask(__name__)
-app.register_blueprint(fdb_key)
-app.register_blueprint(grafana)
+def prep_default_kwargs(kwargs, job, machine):
+ """Populate kwargs with serialized job and machine"""
+ kwargs["job"] = SimpleDict(**job)
+ kwargs["machine"] = machine
+ return kwargs
-@app.route('/')
-def main():
- """Main entry point"""
- return json.dumps({'success': True}), 200, {'ContentType': 'application/json'}
+def get_cached_worker(context, cached_worker):
+ """Get worker from cache"""
+ worker = cached_worker[context['operation']]
+ worker.job = SimpleDict(**context['job'])
+ worker.gpu_id = context['kwargs']['gpu_id']
-if __name__ == '__main__':
- app.config['SECRET_KEY'] = 't\xbe-\xdc\xe7A\r\x1f\xb7\t\xa6\xa1\x8c\xd14\xf3'
- app.run(debug=True)
+ return worker
diff --git a/tuna/utils/db_utility.py b/tuna/utils/db_utility.py
index 64d74e822..400559c14 100644
--- a/tuna/utils/db_utility.py
+++ b/tuna/utils/db_utility.py
@@ -32,14 +32,13 @@
import logging
from time import sleep
from datetime import datetime
+from typing import Callable, Any, List, Dict
import pymysql
from sqlalchemy.exc import OperationalError, IntegrityError, ProgrammingError
from sqlalchemy import create_engine
-from typing import Callable, Any, List, Dict
from tuna.dbBase.sql_alchemy import DbSession
from tuna.dbBase.base_class import BASE
-from tuna.miopen.db.miopen_tables import Solver
from tuna.utils.metadata import NUM_SQL_RETRIES
from tuna.utils.logger import setup_logger
from tuna.utils.utility import get_env_vars
@@ -108,36 +107,6 @@ def create_indices(all_indices):
continue
-def get_solver_ids():
- """DB solver name to id map"""
- # TODO: Get this info from the SQLAlchemy class # pylint: disable=fixme
- solver_id_map = {}
- with DbSession() as session:
- query = session.query(Solver.solver, Solver.id).filter(Solver.valid == 1)
- res = session_retry(session, query.all, lambda x: x(), LOGGER)
- for slv, sid in res:
- solver_id_map[slv] = sid
- solver_id_map[slv.replace(', ', '-')] = sid
-
- return solver_id_map
-
-
-def get_id_solvers():
- """DB solver id to name map"""
- solver_id_map_c = {}
- solver_id_map_h = {}
- with DbSession() as session:
- query = session.query(Solver.solver, Solver.id).filter(Solver.valid == 1)
- res = session_retry(session, query.all, lambda x: x(), LOGGER)
- for slv, sid in res:
- solver_id_map_c[slv] = sid
- solver_id_map_h[slv.replace(', ', '-')] = sid
- id_solver_map_c = {val: key for key, val in solver_id_map_c.items()}
- id_solver_map_h = {val: key for key, val in solver_id_map_h.items()}
-
- return id_solver_map_c, id_solver_map_h
-
-
def session_retry(session: DbSession,
callback: Callable,
actuator: Callable,
@@ -170,7 +139,7 @@ def get_attr_vals(obj, attr_list):
val = getattr(obj, attr)
if val is None:
val = 'NULL'
- elif isinstance(val, str) or isinstance(val, datetime):
+ elif isinstance(val, (datetime, str)):
val = f"'{val}'"
elif isinstance(val, bytes):
val = val.decode('utf-8')
@@ -181,7 +150,7 @@ def get_attr_vals(obj, attr_list):
return attr_vals
-def gen_update_query(obj, attribs, tablename):
+def gen_update_query(obj, attribs, tablename, where_clause_tuples_lst=None):
"""Create an update query string to table with tablename for an object (obj)
for the attributes in attribs"""
set_arr = []
@@ -190,15 +159,20 @@ def gen_update_query(obj, attribs, tablename):
set_arr.append(f"{attr}={attr_vals[attr]}")
set_str = ','.join(set_arr)
- query = f"UPDATE {tablename} SET {set_str}"\
- f" WHERE id={obj.id};"
+ if where_clause_tuples_lst:
+ where_clause = ' AND '.join(f"{x}={y}" for x, y in where_clause_tuples_lst)
+ query = f"UPDATE {tablename} SET {set_str}"\
+ f" WHERE {where_clause};"
+ else:
+ query = f"UPDATE {tablename} SET {set_str}"\
+ f" WHERE id={obj.id};"
LOGGER.info('Query Update: %s', query)
return query
def gen_insert_query(obj, attribs, tablename):
"""create a select query and generate name space objects for the results"""
- attr_list = [attr for attr in attribs]
+ attr_list = list(attribs)
attr_list.remove('id')
attr_str = ','.join(attr_list)
@@ -214,11 +188,42 @@ def gen_insert_query(obj, attribs, tablename):
def gen_select_objs(session, attribs, tablename, cond_str):
"""create a select query and generate name space objects for the results"""
- attr_str = ','.join(attribs)
- query = f"SELECT {attr_str} FROM {tablename}"\
- f" {cond_str};"
+ ret = get_job_rows(session, attribs, tablename, cond_str)
+ entries = None
+
+ if ret:
+ entries = db_rows_to_obj(ret, attribs)
+
+ return entries
+
+
+def get_job_rows(session, attribs, tablename, cond_str):
+ """Get db rows"""
+ ret = None
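+  #an empty or missing column list falls back to selecting all columns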
+  if attribs:
+ attr_str = ','.join(attribs)
+ else:
+ attr_str = '*'
+
+ if cond_str:
+ query = f"SELECT {attr_str} FROM {tablename}"\
+ f" {cond_str};"
+ else:
+ query = f"SELECT {attr_str} FROM {tablename};"
+
LOGGER.info('Query Select: %s', query)
- ret = session.execute(query)
+ try:
+ ret = session.execute(query)
+ except (Exception, KeyboardInterrupt) as ex: #pylint: disable=broad-except
+ LOGGER.warning(ex)
+ ret = None
+ session.rollback()
+
+ return ret
+
+
+def db_rows_to_obj(ret, attribs):
+ """Compose SimpleDict list of db jobs"""
entries = []
for row in ret:
#LOGGER.info('select_row: %s', row)
@@ -239,13 +244,17 @@ def has_attr_set(obj, attribs):
def get_class_by_tablename(tablename):
"""use tablename to find class"""
- for c in BASE._decl_class_registry.values():
- if hasattr(c, '__tablename__') and c.__tablename__ == tablename:
- return c
+ # pylint: disable=protected-access
+ for class_name in BASE._decl_class_registry.values():
+ if hasattr(class_name,
+ '__tablename__') and class_name.__tablename__ == tablename:
+ return class_name
+ return None
-def build_dict_val_key(obj: SimpleDict, exclude: List[str] = ['id']):
- """take object with to_dict function and create a key using values from the object's sorted keys"""
+def build_dict_val_key(obj: SimpleDict, exclude: List[str] = ['id']): # pylint: disable=W0102
+ """take object with to_dict function and create a key using values from the object's \
+ sorted keys"""
obj_dict = obj.to_dict()
for val in exclude:
obj_dict.pop(val, False)
diff --git a/tuna/utils/logger.py b/tuna/utils/logger.py
index 0b3cc98bb..1cf7a4024 100644
--- a/tuna/utils/logger.py
+++ b/tuna/utils/logger.py
@@ -27,11 +27,33 @@
"""logger file"""
import logging
import os
-
from typing import Union
+from logstash_async.handler import AsynchronousLogstashHandler
+from logstash_async.handler import LogstashFormatter
from tuna.utils.metadata import TUNA_LOG_DIR
+def get_logstash_config():
+ """retrieve env vars for logstash"""
+ logstash_status = os.getenv('TUNA_LOGSTASH_STATUS', 'false').lower() == 'true'
+ logstash_host = os.getenv('TUNA_LOGSTASH_HOST', 'localhost')
+ logstash_port = int(os.getenv('TUNA_LOGSTASH_PORT', "5000"))
+ logstash_path = os.getenv('TUNA_LOGSTASH_PATH', None)
+ return logstash_status, logstash_host, logstash_port, logstash_path
+
+
+def add_logstash_handler(logger: logging.Logger, host: str, port: int,
+ path: str):
+  """add logstash handler for log streams"""
+ logstash_handler = AsynchronousLogstashHandler(host=host,
+ port=port,
+ database_path=path)
+ logstash_formatter = LogstashFormatter()
+ logstash_handler.setFormatter(logstash_formatter)
+ logstash_handler.setLevel(logging.INFO)
+ logger.addHandler(logstash_handler)
+
+
def setup_logger(logger_name: str = 'Tuna',
add_streamhandler: bool = True,
add_filehandler: bool = False) -> logging.Logger:
@@ -51,11 +73,21 @@ def setup_logger(logger_name: str = 'Tuna',
file_handler.setFormatter(formatter)
file_handler.setLevel(log_level.upper() if log_level else logging.INFO)
logger.addHandler(file_handler)
+
if add_streamhandler:
stream_handler: logging.StreamHandler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
stream_handler.setLevel(logging.INFO)
- logger.addHandler(stream_handler)
+ if not logger.hasHandlers():
+ logger.addHandler(stream_handler)
+
+ logstash_status, logstash_host, logstash_port, logstash_path = get_logstash_config(
+ )
+
+ if logstash_status:
+ add_logstash_handler(logger, logstash_host, logstash_port, logstash_path)
+ logger.info("Logstash is enabled. Sending logs to %s:%d", logstash_host,
+ logstash_port)
logger.setLevel(log_level.upper() if log_level else logging.DEBUG)
return logger
@@ -68,7 +100,7 @@ def set_usr_logger(logger_name: str) -> logging.Logger:
logging.RootLogger] = logging.getLogger(logger_name)
log_file: str = os.path.join(TUNA_LOG_DIR, logger_name + ".log")
fmt: logging.Formatter = logging.Formatter(
- '%(lineno)d - %(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ '%(lineno)d - %(asctime)s - %(name)s - %(message)s')
file_handler: logging.FileHandler = logging.FileHandler(log_file, mode='a')
file_handler.setFormatter(fmt)
file_handler.setLevel(log_level.upper() if log_level else logging.INFO)
@@ -77,5 +109,14 @@ def set_usr_logger(logger_name: str) -> logging.Logger:
stream_handler.setLevel(logging.INFO)
lgr.addHandler(file_handler)
lgr.addHandler(stream_handler)
+
+ logstash_status, logstash_host, logstash_port, logstash_path = get_logstash_config(
+ )
+
+ if logstash_status:
+ add_logstash_handler(lgr, logstash_host, logstash_port, logstash_path)
+ lgr.info("Logstash is enabled. Sending logs to %s:%d", logstash_host,
+ logstash_port)
+
lgr.setLevel(log_level.upper() if log_level else logging.DEBUG)
return lgr
diff --git a/tuna/utils/miopen_utility.py b/tuna/utils/machine_utility.py
similarity index 98%
rename from tuna/utils/miopen_utility.py
rename to tuna/utils/machine_utility.py
index 1e75c8a58..8be18c13a 100644
--- a/tuna/utils/miopen_utility.py
+++ b/tuna/utils/machine_utility.py
@@ -32,7 +32,7 @@
from tuna.utils.logger import setup_logger
from tuna.utils.db_utility import session_retry
-LOGGER = setup_logger('miopen_utility')
+LOGGER = setup_logger('machine_utility')
def load_machines(args, logger=LOGGER):
diff --git a/tuna/utils/utility.py b/tuna/utils/utility.py
index 9e6aa2a03..d07eae32c 100644
--- a/tuna/utils/utility.py
+++ b/tuna/utils/utility.py
@@ -138,9 +138,23 @@ def get_mmi_env_vars(env_vars={}):
return env_vars
+# pylint: disable=too-few-public-methods
+# pylint: disable=too-many-instance-attributes
class SimpleDict:
"""empty object"""
+ def __init__(self, **kwargs):
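+    """Build attributes from kwargs; nested dicts are recursively converted to SimpleDict"""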
+ for key, value in kwargs.items():
+ if isinstance(value, dict):
+ setattr(self, key, self.from_dict(value))
+ else:
+ setattr(self, key, value)
+
+ @classmethod
+ def from_dict(cls, dict_obj):
+ """recreate object from dict"""
+ return cls(**dict_obj)
+
def to_dict(self, ommit_ts=True, ommit_valid=False):
"""return dict copy of object"""
ret = {}
@@ -164,3 +178,23 @@ def to_dict(self, ommit_ts=True, ommit_valid=False):
ret.pop('insert_ts')
return ret
+
+
+def serialize_job_config_row(elem):
+ """Serialize job row from DB, including its foreign keys"""
+ config_dict = {}
+ for key, value in elem[1].to_dict().items():
+ if isinstance(value, SimpleDict):
+ config_dict[key] = value.to_dict()
+ else:
+ config_dict[key] = value
+
+ return (elem[0].to_dict(), config_dict)
+
+
+def serialize_chunk(chunk):
+ """Serialize a list of tuple(job, configs) rows"""
+ result = []
+ for elem in chunk:
+ result.append(serialize_job_config_row(elem))
+ return result
diff --git a/tuna/worker_interface.py b/tuna/worker_interface.py
index b8bbeee73..854c45d65 100644
--- a/tuna/worker_interface.py
+++ b/tuna/worker_interface.py
@@ -42,7 +42,7 @@
import string
from io import StringIO
from time import sleep
-from typing import List, Tuple, Union, Set, Callable, Optional
+from typing import List, Tuple, Union, Set, Optional, Any, Dict
from sqlalchemy.exc import IntegrityError, OperationalError, NoInspectionAvailable
from sqlalchemy.inspection import inspect
@@ -57,6 +57,7 @@
from tuna.connection import Connection
from tuna.utils.utility import SimpleDict
from tuna.utils.logger import set_usr_logger
+from tuna.db.tuna_tables import JobMixin
class WorkerInterface(Process):
@@ -72,9 +73,9 @@ def __init__(self, **kwargs):
super().__init__()
allowed_keys: Set[str] = set([
- 'machine', 'gpu_id', 'num_procs', 'barred', 'bar_lock', 'envmt',
- 'reset_interval', 'job_queue', 'job_queue_lock', 'result_queue',
- 'result_queue_lock', 'label', 'fetch_state', 'end_jobs', 'session_id'
+ 'machine', 'gpu_id', 'num_procs', 'bar_lock', 'envmt', 'reset_interval',
+ 'job_queue', 'job_queue_lock', 'label', 'fetch_state', 'end_jobs',
+ 'session_id', 'job', 'config'
])
self.reset_interval: bool = None
@@ -83,18 +84,17 @@ def __init__(self, **kwargs):
#multiprocess vars
self.gpu_id: int = None
self.num_procs = None
- self.barred = None
self.bar_lock = Lock()
self.job_queue = None
self.job_queue_lock = Lock()
- self.result_queue = None
- self.result_queue_lock = Lock()
self.end_jobs = None
#job detail vars
self.envmt: List = []
- self.fetch_state: List = ['new']
+ self.fetch_state = set()
self.label: str = None
self.session_id: int = None
+ self.job: SimpleDict = None
+ self.config: dict = None
for key, value in kwargs.items():
if key in allowed_keys:
@@ -106,22 +106,23 @@ def __init__(self, **kwargs):
self.set_db_tables()
self.hostname: str = self.machine.hostname
- self.claim_num: int = self.num_procs.value * 3
+ self.claim_num: int = 1
self.last_reset: datetime = datetime.now()
dir_name: str = os.path.join(TUNA_LOG_DIR,
type(self).__name__,
f"{self.hostname}_{self.machine.port}p")
- if not os.path.exists(dir_name):
- os.makedirs(dir_name)
+ try:
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
+ except FileExistsError:
+ pass
logger_name: str = os.path.join(dir_name, str(self.gpu_id))
self.logger = set_usr_logger(logger_name)
connect_db()
- self.job: SimpleDict = SimpleDict()
-
try:
self.job_attr: List[str] = [
column.name for column in inspect(self.dbt.job_table).c
@@ -135,7 +136,7 @@ def __init__(self, **kwargs):
#also set cnx here in case WorkerInterface exec_command etc called directly
self.cnx: Connection = self.machine.connect(chk_abort_file)
- def step(self) -> bool:
+ def step(self) -> Optional[Dict[Any, Any]]: #type: ignore[override]
"""Regular run loop operation, to be overloaded in class specialization """
raise NotImplementedError("Not implemented")
@@ -148,6 +149,7 @@ def reset_machine(self) -> None:
self.machine.restart_server()
self.last_reset = datetime.now()
+ #deprecated
def compose_work_objs(self, session: DbSession,
conds: List[str]) -> List[Tuple[SimpleDict, ...]]:
"""Query a job list for update"""
@@ -165,6 +167,7 @@ def compose_work_objs(self, session: DbSession,
return [(job,) for job in entries]
+ #deprecated
def get_job_objs(self, session: DbSession,
find_state: str) -> List[Tuple[SimpleDict, ...]]:
"""Get list of job objects"""
@@ -175,16 +178,18 @@ def get_job_objs(self, session: DbSession,
conds.append(f"reason='{self.label}'")
conds.append(f"retries<{MAX_JOB_RETRIES}")
- conds.append(f"state='{find_state}'")
+ conds.append("state in (\"" + find_state + "\")")
entries = self.compose_work_objs(session, conds)
return entries
+ #deprecated
def queue_end_reset(self) -> None:
"""resets end queue flag"""
with self.bar_lock:
self.end_jobs.value = 0
+ #deprecated
def check_jobs_found(self, job_rows: List[SimpleDict], find_state: str,
imply_end: bool) -> bool:
"""check for end of jobs"""
@@ -198,6 +203,7 @@ def check_jobs_found(self, job_rows: List[SimpleDict], find_state: str,
return False
return True
+ #deprecated
def get_job_from_tuple(
self, job_tuple: Tuple[SimpleDict, ...]) -> Optional[SimpleDict]:
"""find job table in a job tuple"""
@@ -210,9 +216,11 @@ def get_job_from_tuple(
return tble
return None
+ #deprecated
def get_job_tables(
self, job_rows: List[Tuple[SimpleDict, ...]]) -> List[SimpleDict]:
"""find job tables in query results"""
+ #pylint:disable=duplicate-code
if has_attr_set(job_rows[0], self.job_attr):
job_tables: List[SimpleDict] = job_rows
else:
@@ -225,6 +233,7 @@ def get_job_tables(
job_tables = [row[job_i] for row in job_rows]
return job_tables
+ #deprecated
def job_queue_push(self, job_rows: List[Tuple[SimpleDict, ...]]) -> None:
"""load job_queue with info for job ids"""
job: SimpleDict
@@ -234,12 +243,14 @@ def job_queue_push(self, job_rows: List[Tuple[SimpleDict, ...]]) -> None:
job = self.get_job_from_tuple(job_tuple)
self.logger.info("Put job %s %s %s", job.id, job.state, job.reason)
+ #deprecated
def job_queue_pop(self) -> None:
"""load job from top of job queue"""
self.job = self.job_queue.get(True, 1)[0]
self.logger.info("Got job %s %s %s", self.job.id, self.job.state,
self.job.reason)
+ #deprecated
#pylint: disable=too-many-branches
def get_job(self, find_state: str, set_state: str, imply_end: bool) -> bool:
"""Interface function to get new job for builder/evaluator"""
@@ -302,6 +313,11 @@ def get_job(self, find_state: str, set_state: str, imply_end: bool) -> bool:
NUM_SQL_RETRIES, self.hostname, self.gpu_id)
return False
+ def set_job(self, job: JobMixin):
+ """Set worker job"""
+ self.job = job
+ self.job.gpu_id = self.gpu_id
+
#TODO_: This should take a session obj as an input to remove the creation of an extraneous
# session
def set_job_state(self,
@@ -353,6 +369,7 @@ def exec_command(self, cmd: str) -> Tuple[int, str, StringIO]:
if (ret_code != 0 or not out) and err:
self.logger.info('Error executing cmd: %s \n code: %u err: %s', cmd,
ret_code, err.read())
+ err.seek(0)
return ret_code, strout, err
@@ -369,22 +386,10 @@ def exec_docker_cmd(self, cmd: str) -> Tuple[int, str, StringIO]:
if (ret_code != 0 or not out) and err:
self.logger.info('Error executing cmd: %s \n code: %u err: %s', cmd,
ret_code, err.read())
+ err.seek(0)
return ret_code, strout, err
- def get_miopen_v(self) -> str:
- """Interface function to get new branch hash"""
- commit_hash: str
- _, commit_hash, _ = self.exec_docker_cmd(
- "cat /opt/rocm/include/miopen/version.h "
- "| grep MIOPEN_VERSION_TWEAK | cut -d ' ' -f 3")
- if "No such file" in commit_hash:
- _, commit_hash, _ = self.exec_docker_cmd(
- "cat /opt/rocm/miopen/include/miopen/version.h "
- "| grep MIOPEN_VERSION_TWEAK | cut -d ' ' -f 3")
- self.logger.info('Got branch commit hash: %s', commit_hash)
- return commit_hash
-
def get_rocm_v(self) -> str:
"""Interface function to get rocm version info"""
rocm_ver: str
@@ -399,57 +404,8 @@ def check_env(self) -> bool:
raise ValueError(
f'session rocm_v {self.dbt.session.rocm_v} does not match env rocm_v {env_rocm_v}'
)
- env_miopen_v: str = self.get_miopen_v()
- if self.dbt.session.miopen_v != env_miopen_v:
- raise ValueError(
- f'session miopen_v {self.dbt.session.miopen_v} does not match env miopen_v {env_miopen_v}'
- )
-
return True
- def set_barrier(self, funct: Callable, with_timeout: bool) -> bool:
- """Setting time barrier for Process to define execution timeout"""
- if self.barred.value == 0:
- # this is the first proc to reach the barrier
- with self.bar_lock:
- self.barred.value += 1
- self.logger.info('Waiting for other instances to pause')
- wait_cnt: int = 0
- timeout: bool = False
- while self.barred.value < self.num_procs.value:
- sleep(10)
- if with_timeout and self.barred.value == 1:
- wait_cnt += 1
- timeout = True
- if wait_cnt > 180:
- break
- if timeout:
- self.logger.warning(
- 'Timed out waiting for hung process, proceeding ... ')
- else:
- self.logger.info('Finished waiting for instances to pause')
- funct()
- with self.bar_lock:
- self.barred.value = 0
- return True
-
- return False
-
- def check_wait_barrier(self) -> bool:
- """Checking time barrier"""
- self.logger.info('Checking barrier')
- if self.barred.value != 0:
- self.logger.info('Blocked procs found')
- self.logger.info('Current barrier count: %s', self.barred.value)
- with self.bar_lock:
- self.barred.value += 1
- self.logger.warning('Waiting for processes to finish')
- while self.barred.value != 0:
- sleep(60)
- self.logger.warning('Finished waiting for processes')
- return True
- return False
-
def reset_job_state(self) -> None:
"""Helper function to reset job state during signal interrupt"""
#also filter pending states eg compiled_pend
@@ -472,24 +428,22 @@ def reset_job_state(self) -> None:
except queue.Empty:
break
- def run(self) -> bool: #type: ignore[override]
+ def run(self) -> dict: #type: ignore
"""
Main run function of WorkerInterface Process
#type: ignore[override] - parent class returns None type.
"""
+ ret = None
self.machine.set_logger(self.logger)
usage: float
try:
self.cnx = self.machine.connect(chk_abort_file)
while True:
- self.check_wait_barrier()
if chk_abort_file(self.machine.id, self.logger, self.machine.arch):
- with self.bar_lock:
- self.num_procs.value -= 1
- return False
+ return None #type: ignore
# re-establish node connection
usage = 0
@@ -497,15 +451,11 @@ def run(self) -> bool: #type: ignore[override]
usage = self.machine.getusedspace()
except (socket.timeout, socket.error):
usage = 0
- if not usage:
- self.set_barrier(self.reset_machine, True)
- continue
if usage > 90:
self.logger.warning('Used space overflow detected')
- self.set_barrier(lambda: (), True)
- continue
+ return False #type: ignore
# the step member is defined in the derived class
- ret: bool = self.step()
+ ret = self.step()
self.logger.info("proc %s step %s", self.gpu_id, ret)
if not ret:
self.logger.warning('No more steps, quitting...')
@@ -514,12 +464,12 @@ def run(self) -> bool: #type: ignore[override]
if hasattr(self, "any_failed") and self.any_failed:
sys.exit(1)
return True
+ return ret #type: ignore
except KeyboardInterrupt as err:
self.logger.error('%s', err)
self.reset_job_state()
- with self.bar_lock:
- self.num_procs.value -= 1
- return False
+
+ return ret #type: ignore
def run_command(self, cmd: str) -> Tuple[int, str]:
"""Run cmd and return ret_code"""
@@ -533,6 +483,7 @@ def run_command(self, cmd: str) -> Tuple[int, str]:
self.logger.error('Error executing command: %s', cmd)
if err:
err_str: str = err.read()
+ out = err_str
self.logger.error('%s : %s', ret_code, err_str)
if "disk I/O error" in err_str:
self.logger.error('fin retry : %u', i)
diff --git a/vars/utils.groovy b/vars/utils.groovy
index 50a67afff..484bb0338 100644
--- a/vars/utils.groovy
+++ b/vars/utils.groovy
@@ -42,25 +42,30 @@ def buildSchema(){
sh "${cmd} -e ${drop_sql}"
sh "${cmd} -e ${create_sql}"
sh "./tuna/miopen/db/build_schema.py"
+ sh "./tuna/example/build_schema.py"
}
def getDockerName(backend)
{
- def docker_registry = "${headnode}:5000"
- def tuna_docker_name = "${docker_registry}/ci-tuna:${branch_id}_${backend}"
+ def tuna_docker_name = "${docker_registry}:ci-tuna_${branch_id}_${backend}"
return tuna_docker_name
}
def buildDockers(){
- def tuna_docker_hipnogpu = docker.build(getDockerName("HIPNOGPU"), " --build-arg BACKEND=HIPNOGPU .")
- tuna_docker_hipnogpu.push()
- def tuna_docker_hip = docker.build(getDockerName("HIP"), " --build-arg BACKEND=HIP .")
- tuna_docker_hip.push()
+ docker.withRegistry('', "$DOCKER_CRED"){
+ def tuna_docker_hipnogpu = docker.build(getDockerName("HIPNOGPU"), " --build-arg BACKEND=HIPNOGPU .")
+ tuna_docker_hipnogpu.push()
+ def tuna_docker_hip = docker.build(getDockerName("HIP"), " --build-arg BACKEND=HIP .")
+ tuna_docker_hip.push()
+ }
}
def getDocker(backend){
- def tuna_docker = docker.image(getDockerName(backend))
- tuna_docker.pull()
+ def tuna_docker
+ docker.withRegistry('', "$DOCKER_CRED"){
+ tuna_docker = docker.image(getDockerName(backend))
+ tuna_docker.pull()
+ }
return tuna_docker
}
@@ -131,9 +136,9 @@ def finApplicability(){
env.PYTHONPATH=env.WORKSPACE
env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
- sh "./tuna/go_fish.py miopen --init_session -l new_session --arch gfx908 --num_cu 120"
+ sh "./tuna/go_fish.py miopen --init_session -l new_session --arch ${arch} --num_cu ${num_cu}"
def sesh1 = 1 //runsql("select id from session order by id asc limit 1")
- sh "./tuna/go_fish.py miopen --init_session -l new_session2 --arch gfx908 --num_cu 120"
+ sh "./tuna/go_fish.py miopen --init_session -l new_session2 --arch ${arch} --num_cu ${num_cu}"
def sesh2 = 2 //runsql("select id from session order by id desc limit 1")
sh "./tuna/go_fish.py miopen import_configs --add_model Alexnet --md_version 1"
@@ -158,6 +163,7 @@ def finApplicability(){
error("Unable to get applicability from Fin for convolution")
}
+ /*
sh "./tuna/go_fish.py miopen import_configs -t recurrent_${branch_id}_bn --mark_recurrent -f utils/configs/batch_norm.txt -C batch_norm --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
runsql("TRUNCATE table bn_solver_applicability")
def num_bn = runsql("SELECT count(*) from bn_config;")
@@ -168,11 +174,11 @@ def finApplicability(){
println "Count(*) bn_solver_applicability table: ${num_sapp_bn}"
if (num_sapp_bn.toInteger() == 0){
error("Unable to get applicability from Fin for batch norm")
- }
+ }*/
}
}
-def finFindCompile(){
+def finFindCompileEnqueue(){
def tuna_docker = getDocker("HIPNOGPU")
tuna_docker.inside("--network host --dns 8.8.8.8 ") {
env.TUNA_DB_HOSTNAME = "${db_host}"
@@ -184,50 +190,40 @@ def finFindCompile(){
env.gateway_port = "${gateway_port}"
env.gateway_user = "${gateway_user}"
env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
+ env.TUNA_CELERY_BROKER_HOST="${db_host}"
def sesh1 = runsql("select id from session order by id asc limit 1")
+ celery_log="${env.WORKSPACE}/tuna/${branch_id}_find_compile_celery_log.log"
+ sh "touch ${celery_log}"
sh "./tuna/go_fish.py miopen import_configs -t recurrent_${branch_id} --mark_recurrent -f utils/recurrent_cfgs/alexnet_4jobs.txt --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
- def num_cfg = runsql("SELECT count(*) from conv_config;")
- println "Count(*) conv_config table: ${num_cfg}"
runsql("delete from conv_job;")
runsql("alter table conv_job AUTO_INCREMENT=1;")
- sh "./tuna/go_fish.py miopen load_job -l finFind_${branch_id} --all_configs --fin_steps \"miopen_find_compile,miopen_find_eval\" --session_id ${sesh1} ${job_lim}"
+
+ sh "./tuna/go_fish.py miopen import_configs -t recurrent_${branch_id} --mark_recurrent -f utils/configs/conv_configs_NHWC.txt --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
+
+ sh "./tuna/go_fish.py miopen import_configs -t recurrent_${branch_id} --mark_recurrent -f utils/configs/conv_configs_NCHW.txt --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
+ def num_cfg = runsql("SELECT count(*) from conv_config;")
+ println "Count(*) conv_config table: ${num_cfg}"
+ sh "./tuna/go_fish.py miopen load_job -l finFind_${branch_id} -t recurrent_${branch_id} --fin_steps \"miopen_find_compile,miopen_find_eval\" --session_id ${sesh1} ${job_lim}"
+
+ sh "printenv"
def num_jobs = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}';").toInteger()
- sh "./tuna/go_fish.py miopen --fin_steps miopen_find_compile -l finFind_${branch_id} --session_id ${sesh1}"
+ def pid = sh(script: "celery -A tuna.celery_app.celery_app worker -l debug --logfile=${celery_log} -n tuna_${branch_id} -Q compile_q_${db_name}_sess_${sesh1} & echo \$!", returnStdout: true).trim()
+ sh "cat ${celery_log}"
+
+ sh "printenv"
+ sh "./tuna/go_fish.py miopen --fin_steps miopen_find_compile -l finFind_${branch_id} --session_id ${sesh1} --enqueue_only"
+
+ sh "kill -9 ${pid}"
+ sh "cat ${celery_log}"
def num_compiled_jobs = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}' AND state = 'compiled';").toInteger()
sh "echo ${num_compiled_jobs} == ${num_jobs}"
if (num_compiled_jobs != num_jobs){
error("Unable to compile find jobs using Fin")
}
-
- sh "./tuna/go_fish.py miopen import_configs -t recurrent_${branch_id}_nhwc --mark_recurrent -f utils/configs/conv_configs_NHWC.txt --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
- def num_cfg_nhwc = runsql("SELECT count(*) from conv_config;")
- println "Count(*) conv_config table: ${num_cfg_nhwc}"
- //runsql("delete from conv_job;")
- //runsql("alter table conv_job AUTO_INCREMENT=1;")
- sh "./tuna/go_fish.py miopen load_job -l finFind_${branch_id}_nhwc -t recurrent_${branch_id}_nhwc --fin_steps \"miopen_find_compile,miopen_find_eval\" --session_id ${sesh1} ${job_lim}"
- def num_jobs_nhwc = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}_nhwc';").toInteger()
- sh "./tuna/go_fish.py miopen --fin_steps miopen_find_compile -l finFind_${branch_id}_nhwc --session_id ${sesh1}"
- def num_compiled_jobs_nhwc = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}_nhwc' AND state = 'compiled';").toInteger()
- sh "echo ${num_compiled_jobs_nhwc} == ${num_jobs_nhwc}"
- if (num_compiled_jobs_nhwc != num_jobs_nhwc){
- error("Unable to compile find jobs using Fin")
- }
- sh "./tuna/go_fish.py miopen import_configs -t recurrent_${branch_id}_nchw --mark_recurrent -f utils/configs/conv_configs_NCHW.txt --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
- def num_cfg_nchw = runsql("SELECT count(*) from conv_config;")
- println "Count(*) conv_config table: ${num_cfg_nchw}"
- sh "./tuna/go_fish.py miopen load_job -l finFind_${branch_id}_nchw -t recurrent_${branch_id}_nchw --fin_steps \"miopen_find_compile,miopen_find_eval\" --session_id ${sesh1} ${job_lim}"
- def num_jobs_nchw = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}_nchw';").toInteger()
- sh "./tuna/go_fish.py miopen --fin_steps miopen_find_compile -l finFind_${branch_id}_nchw --session_id ${sesh1}"
- def num_compiled_jobs_nchw = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}_nchw' AND state = 'compiled';").toInteger()
- sh "echo ${num_compiled_jobs_nchw} == ${num_jobs_nchw}"
- if (num_compiled_jobs_nchw != num_jobs_nchw){
- error("Unable to compile find jobs using Fin")
- }
}
}
-
def finFindEval(){
def tuna_docker = getDocker("HIP")
tuna_docker.inside("--network host --dns 8.8.8.8 ${docker_args}") {
@@ -240,54 +236,53 @@ def finFindEval(){
env.gateway_user = "${gateway_user}"
env.PYTHONPATH=env.WORKSPACE
env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
+ env.TUNA_CELERY_BROKER_HOST="${db_host}"
def sesh1 = runsql("select id from session order by id asc limit 1")
+ def pids = []
def num_jobs = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}' AND state = 'compiled';").toInteger()
- sh "./tuna/go_fish.py miopen --fin_steps miopen_find_eval -l finFind_${branch_id} --session_id ${sesh1}"
+
+ def num_gpus = sh(script: "/opt/rocm/bin/rocminfo | grep ${arch}:sramecc+:xnack | wc -l", returnStdout: true).trim()
+ num_gpus = num_gpus as Integer
+ sh "echo #GPUs: ${num_gpus}"
+ def gpu_list = (0..(num_gpus-1)).toList()
+ sh "echo ${gpu_list}"
+ def counter = 0
+ def pid_list = []
+
+ sh "printenv"
+    // \046 is the octal escape for '&' used in the redirection below
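+    // one celery eval worker is launched per GPU, each limited to a single worker process (-c 1)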
+ gpu_list.each{
+ celery_log="${env.WORKSPACE}/tuna/${branch_id}_find_eval_celery_log_${counter}.log"
+ sh "touch ${celery_log}"
+ def proc_id = sh(script: "celery -A tuna.celery_app.celery_app worker -l debug --logfile=${celery_log} -n tuna_${branch_id}_gpu_id_${counter} -Q eval_q_${db_name}_sess_${sesh1} -c 1 2>\0461 1>/dev/null & echo \$!", returnStdout: true).trim()
+ sh "cat ${celery_log}"
+ pid_list.add(proc_id)
+ counter++
+ }
+
+ sh "./tuna/go_fish.py miopen --fin_steps miopen_find_eval -l finFind_${branch_id} --session_id ${sesh1} --enqueue_only"
+ //killing off celery workers by pid
+ pid_list.each{
+ try{
+ sh "kill -9 ${it}"
+ } catch (Exception err) {
+ sh "echo ${err}"
+ }
+ }
+
def num_evaluated_jobs = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}' AND state = 'evaluated';").toInteger()
sh "echo ${num_evaluated_jobs} == ${num_jobs}"
if (num_evaluated_jobs != num_jobs){
error("Unable to evaluate find jobs using Fin")
}
+
def MIOPEN_BRANCH = runsql("SELECT miopen_v from session WHERE id=1;")
def fdb_file = sh(script: "./tuna/go_fish.py miopen export_db -a ${arch} -n ${num_cu} -f --session_id ${sesh1}", returnStdout: true)
- archiveArtifacts "${fdb_file}"
+ archiveArtifacts "${fdb_file}"
def kdb_file = sh(script: "./tuna/go_fish.py miopen export_db -a ${arch} -n ${num_cu} -k --session_id ${sesh1}", returnStdout: true)
archiveArtifacts "${kdb_file}"
- def num_jobs_nhwc = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}_nhwc' AND state = 'compiled';").toInteger()
- sh "./tuna/go_fish.py miopen --fin_steps miopen_find_eval -l finFind_${branch_id}_nhwc --session_id ${sesh1}"
- def fdb_file_nhwc = sh(script: "./tuna/go_fish.py miopen export_db -a ${arch} -n ${num_cu} -f --session_id ${sesh1} --filename fdb_nhwc", returnStdout: true)
- def num_evaluated_jobs_nhwc = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}_nhwc' AND state = 'evaluated';").toInteger()
- sh "echo ${num_evaluated_jobs_nhwc} == ${num_jobs_nhwc}"
- if (num_evaluated_jobs_nhwc != num_jobs_nhwc){
- error("Unable to evaluate find jobs using Fin")
- }
-
- archiveArtifacts "${fdb_file_nhwc}"
- def kdb_file_nhwc = sh(script: "./tuna/go_fish.py miopen export_db -a ${arch} -n ${num_cu} -k --session_id ${sesh1} --filename kdb_nhwc", returnStdout: true)
- archiveArtifacts "${kdb_file_nhwc}"
-
- def num_jobs_nchw = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}_nchw' AND state = 'compiled';").toInteger()
- sh "./tuna/go_fish.py miopen --fin_steps miopen_find_eval -l finFind_${branch_id}_nchw --session_id ${sesh1}"
- def fdb_file_nchw = sh(script: "./tuna/go_fish.py miopen export_db -a ${arch} -n ${num_cu} -f --session_id ${sesh1}", returnStdout: true)
- def num_evaluated_jobs_nchw = runsql("SELECT count(*) from conv_job WHERE reason = 'finFind_${branch_id}_nchw' AND state = 'evaluated';").toInteger()
- sh "echo ${num_evaluated_jobs_nchw} == ${num_jobs_nchw}"
- if (num_evaluated_jobs_nchw != num_jobs_nchw){
- error("Unable to evaluate find jobs using Fin")
- }
-
- archiveArtifacts "${fdb_file_nchw}"
- def kdb_file_nchw = sh(script: "./tuna/go_fish.py miopen export_db -a ${arch} -n ${num_cu} -k --session_id ${sesh1}", returnStdout: true)
- archiveArtifacts "${kdb_file_nchw}"
- }
-}
-def buildTunaDocker(){
- // The purpose of this job is to ensure that the Tuna Docker is uptodate on the eval/build machine for the CI jobs
- checkout scm
- def tuna_docker = docker.build("ci-tuna:${branch_id}")
- tuna_docker.inside("--network host "){
- sh "pwd"
}
}
@@ -326,7 +321,7 @@ def loadJobTest() {
assert out_bn.toInteger() > 0
sh "./tuna/go_fish.py miopen load_job -t batch_norm_test -l batch_norm_test -C batch_norm --session_id ${sesh2}"
out_bn = runsql("SELECT count(*) FROM bn_job WHERE reason='batch_norm_test' and session=${sesh2} ;")
- assert out_bn.toInteger() > 0
+ //assert out_bn.toInteger() > 0
//reset jobs and test load solver
runsql("DELETE FROM conv_job;")
@@ -358,13 +353,15 @@ def solverAnalyticsTest(){
// install SolverAnalytics
sh "rm -rf SolverAnalytics"
sh "git clone https://${FIN_TOKEN}:x-oauth-basic@github.com/ROCmSoftwarePlatform/SolverAnalytics.git"
- sh "pip3 install --default-timeout=100000 -r SolverAnalytics/requirements.txt"
+ sh "cd SolverAnalytics; git checkout sp/solver_changes; git pull;"
+    //lower version in requirements file causing issues in CI
+ //sh "pip3 install --default-timeout=100000 -r SolverAnalytics/requirements.txt"
// run SolverAnalytics tests
sh "python3 ./SolverAnalytics/tests/clean_finddb_test.py"
sh "python3 ./SolverAnalytics/tests/cli_test.py"
sh "python3 ./SolverAnalytics/tests/generate_analytics_test.py"
- sh "python3 ./SolverAnalytics/tests/get_finddb_test.py"
+ //sh "python3 ./SolverAnalytics/tests/get_finddb_test.py"
sh "python3 ./SolverAnalytics/tests/utils_test/df_tools_test.py"
sh "python3 ./SolverAnalytics/tests/utils_test/fdb_key_utils_test.py"
sh "python3 ./SolverAnalytics/tests/utils_test/helpers_test.py"
@@ -382,39 +379,52 @@ def perfCompile() {
env.gateway_ip = "${gateway_ip}"
env.gateway_port = "${gateway_port}"
env.gateway_user = "${gateway_user}"
- env.TUNA_DOCKER_NAME="ci-tuna:${branch_id}"
+ env.TUNA_DOCKER_NAME="ci-tuna_${branch_id}"
env.PYTHONPATH=env.WORKSPACE
env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
+ env.TUNA_CELERY_BROKER_HOST="${db_host}"
runsql("DELETE FROM conv_job;")
def sesh1 = runsql("select id from session order by id asc limit 1")
+ celery_log="${env.WORKSPACE}/tuna/${branch_id}_perf_compile_celery_log.log"
+ sh "touch ${celery_log}"
+ def pid = sh(script: "celery -A tuna.celery_app.celery_app worker -l debug -E --detach --logfile=${celery_log} -n tuna_${branch_id} -Q compile_q_${db_name}_sess_${sesh1} & echo \$!", returnStdout: true).trim()
+ sh "echo ${pid}"
+ sh "cat ${celery_log}"
sh "./tuna/go_fish.py miopen import_configs -t alexnet_${branch_id} --mark_recurrent -f utils/recurrent_cfgs/alexnet_4jobs.txt --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
sh "./tuna/go_fish.py miopen load_job -t alexnet_${branch_id} -l alexnet_${branch_id} --session_id ${sesh1} --fin_steps miopen_perf_compile,miopen_perf_eval ${job_lim}"
// Get the number of jobs
def num_jobs = runsql("SELECT count(*) from conv_job where state = 'new' and reason = 'alexnet_${branch_id}'");
- sh "./tuna/go_fish.py miopen --fin_steps miopen_perf_compile -l alexnet_${branch_id} --session_id ${sesh1}"
+ sh "./tuna/go_fish.py miopen --fin_steps miopen_perf_compile -l alexnet_${branch_id} --session_id ${sesh1} --enqueue_only"
+ sh "kill -9 ${pid}"
def compiled_jobs = runsql("SELECT count(*) from conv_job where state = 'compiled' and reason = 'alexnet_${branch_id}';")
if(compiled_jobs.toInteger() == 0)
{
error("Unable to compile any jobs for alexnet")
}
+ def pid2 = sh(script: "celery -A tuna.celery_app.celery_app worker -l debug -E --detach --logfile=${celery_log} -n tuna_${branch_id} -Q compile_q_${db_name}_sess_${sesh1} & echo \$!", returnStdout: true).trim()
+ sh "echo ${pid2}"
+ sh "cat ${celery_log}"
+
sh "./tuna/go_fish.py miopen import_configs -t conv_${branch_id}_v2 --mark_recurrent -f utils/configs/conv_configs_NHWC.txt --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
sh "./tuna/go_fish.py miopen import_configs -t conv_${branch_id}_v2 --mark_recurrent -f utils/configs/conv_configs_NCHW.txt --model Resnet50 --md_version 1 --framework Pytorch --fw_version 1"
sh "./tuna/go_fish.py miopen load_job -t conv_${branch_id}_v2 -l conv_${branch_id}_v2 --session_id ${sesh1} --fin_steps miopen_perf_compile,miopen_perf_eval ${job_lim}"
// Get the number of jobs
def num_conv_jobs = runsql("SELECT count(*) from conv_job where state = 'new' and reason = 'conv_${branch_id}_v2'");
- sh "./tuna/go_fish.py miopen --fin_steps miopen_perf_compile -l conv_${branch_id}_v2 --session_id ${sesh1}"
+ sh "./tuna/go_fish.py miopen --fin_steps miopen_perf_compile -l conv_${branch_id}_v2 --session_id ${sesh1} --enqueue_only"
+ sh "kill -9 ${pid2}"
def compiled_conv_jobs = runsql("SELECT count(*) from conv_job where state = 'compiled' and reason = 'conv_${branch_id}_v2';")
if(compiled_conv_jobs.toInteger() == 0)
{
error("Unable to compile any conv jobs")
}
echo "${compiled_conv_jobs}"
+
}
}
-def perfEval_gfx908() {
+def perfEval() {
def tuna_docker = getDocker("HIP")
tuna_docker.inside("--network host --dns 8.8.8.8 ${docker_args} ") {
env.TUNA_DB_HOSTNAME = "${db_host}"
@@ -424,21 +434,90 @@ def perfEval_gfx908() {
env.gateway_ip = "${gateway_ip}"
env.gateway_port = "${gateway_port}"
env.gateway_user = "${gateway_user}"
- env.TUNA_DOCKER_NAME="ci-tuna:${branch_id}"
+ env.TUNA_DOCKER_NAME="ci-tuna_${branch_id}"
env.PYTHONPATH=env.WORKSPACE
env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
+ env.TUNA_CELERY_BROKER_HOST="${db_host}"
def sesh1 = runsql("select id from session order by id asc limit 1")
def compiled_jobs = runsql("SELECT count(*) from conv_job where state = 'compiled' and reason = 'alexnet_${branch_id}';")
- sh "./tuna/go_fish.py miopen --fin_steps miopen_perf_eval -l alexnet_${branch_id} --session_id ${sesh1}"
+ def num_gpus = sh(script: "/opt/rocm/bin/rocminfo | grep ${arch}:sramecc+:xnack | wc -l", returnStdout: true).trim()
+ num_gpus = num_gpus as Integer
+ sh "echo #GPUs: ${num_gpus}"
+ def gpu_list = (1..(num_gpus-1)).toList()
+ sh "echo ${gpu_list}"
+ def counter = 0
+ def pid_list = []
+ def celery_log_list = []
+
+ sh "printenv"
+ // &=046
+ gpu_list.each{
+ celery_log="${env.WORKSPACE}/tuna/${branch_id}_perf_eval_celery_log_${counter}.log"
+ celery_log_list.add(celery_log)
+ sh "touch ${celery_log}"
+ def proc_id = sh(script: "celery -A tuna.celery_app.celery_app worker -l debug --logfile=${celery_log} -n tuna_${branch_id}_gpu_id_${counter} -Q eval_q_${db_name}_sess_${sesh1} -c 1 2>\0461 1>/dev/null & echo \$!", returnStdout: true).trim()
+ //sh "cat ${celery_log}"
+ pid_list.add(proc_id)
+ counter++
+ }
+
+ sh "./tuna/go_fish.py miopen --fin_steps miopen_perf_eval -l alexnet_${branch_id} --session_id ${sesh1} --enqueue_only"
def eval_jobs = runsql("SELECT count(*) from conv_job where state = 'evaluated' and reason = 'alexnet_${branch_id}';")
if(eval_jobs.toInteger() != compiled_jobs.toInteger())
{
error("Unable to eval all jobs for alexnet")
}
+ pid_list.each{
+ try{
+ sh "kill -9 ${it}"
+ } catch (Exception err) {
+ sh "echo ${err}"
+ }
+ }
+
+ celery_log_list.each{
+ try{
+ sh "cat ${it}"
+ } catch (Exception err) {
+ sh "echo ${err}"
+ }
+ }
+
def compiled_conv_jobs = runsql("SELECT count(*) from conv_job where reason = 'conv_${branch_id}_v2' and state = 'compiled';")
- sh "./tuna/go_fish.py miopen --fin_steps miopen_perf_eval -l conv_${branch_id}_v2 --session_id ${sesh1}"
+
+ counter = 0
+ pid_list = []
+ celery_log_list = []
+
+ gpu_list.each{
+ celery_log="${env.WORKSPACE}/tuna/${branch_id}_perf_eval_celery_log_${counter}.log"
+ celery_log_list.add(celery_log)
+ sh "touch ${celery_log}"
+ def proc_id = sh(script: "celery -A tuna.celery_app.celery_app worker -l debug --logfile=${celery_log} -n tuna_${branch_id}_gpu_id_${counter} -Q eval_q_${db_name}_sess_${sesh1} -c 1 2>\0461 1>/dev/null & echo \$!", returnStdout: true).trim()
+ pid_list.add(proc_id)
+ counter++
+ }
+ sh "./tuna/go_fish.py miopen --fin_steps miopen_perf_eval -l conv_${branch_id}_v2 --session_id ${sesh1} --enqueue_only"
+
+
+ pid_list.each{
+ try{
+ sh "kill -9 ${it}"
+ } catch (Exception err) {
+ sh "echo ${err}"
+ }
+ }
+
+ celery_log_list.each{
+ try{
+ sh "cat ${it}"
+ } catch (Exception err) {
+ sh "echo ${err}"
+ }
+ }
+
def eval_conv_jobs = runsql("SELECT count(*) from conv_job where reason = 'conv_${branch_id}_v2' and state = 'evaluated';")
def errored_conv_jobs = runsql("SELECT count(*) from conv_job where reason = 'conv_${branch_id}_v2' and state = 'errored';")
if(eval_conv_jobs.toInteger() != compiled_conv_jobs.toInteger())
@@ -452,6 +531,7 @@ def perfEval_gfx908() {
def last_gold_v = runsql("SELECT max(golden_miopen_v) from conv_golden;")
def next_gold_v = last_gold_v.toInteger() + 1
sh "./tuna/go_fish.py miopen update_golden --session_id ${sesh1} --golden_v ${next_gold_v} --base_golden_v ${last_gold_v}"
+
def golden_entries = runsql("SELECT count(*) from conv_golden where session= ${sesh1};")
def fdb_entries = runsql("SELECT count(*) from conv_golden where session= ${sesh1};")
if(golden_entries.toInteger() != fdb_entries.toInteger())
@@ -473,7 +553,7 @@ def pytestSuite1() {
env.gateway_ip = "${gateway_ip}"
env.gateway_port = "${gateway_port}"
env.gateway_user = "${gateway_user}"
- env.TUNA_DOCKER_NAME="ci-tuna:${branch_id}_pytest1"
+ env.TUNA_DOCKER_NAME="ci-tuna_${branch_id}_pytest1"
env.PYTHONPATH=env.WORKSPACE
env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
addMachine(arch, num_cu, machine_ip, machine_local_ip, username, pwd, port)
@@ -503,6 +583,8 @@ def pytestSuite1() {
sh "python3 -m coverage run -a -m pytest tests/test_importconfigs_rocmlir.py -s"
sh "python3 -m coverage run -a -m pytest tests/test_load_job_rocmlir.py -s"
sh "python3 -m coverage run -a -m pytest tests/test_rocmlir.py -s"
+ sh "python3 -m coverage run -a -m pytest tests/test_helper.py -s"
+ sh "python3 -m coverage run -a -m pytest tests/test_mituna_interface.py -s"
// The OBMC host used in the following test is down
// sh "pytest tests/test_mmi.py "
}
@@ -521,7 +603,7 @@ def pytestSuite2() {
env.gateway_ip = "${gateway_ip}"
env.gateway_port = "${gateway_port}"
env.gateway_user = "${gateway_user}"
- env.TUNA_DOCKER_NAME="ci-tuna:${branch_id}_pytest2"
+ env.TUNA_DOCKER_NAME="ci-tuna_${branch_id}_pytest2"
env.PYTHONPATH=env.WORKSPACE
env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
@@ -532,14 +614,15 @@ def pytestSuite2() {
// test fin builder and test fin builder conv in sequence
sh "python3 -m coverage run -a -m pytest tests/test_worker.py -s"
sh "TUNA_LOGLEVEL=INFO python3 -m coverage run -a -m pytest tests/test_fin_builder.py -s"
+ sh "TUNA_LOGLEVEL=INFO python3 -m coverage run -a -m pytest tests/test_celery.py -s"
}
sh "coverage report -m"
}
}
-def pytestSuite3AndCoverage(current_run, main_branch) {
+def pytestSuite3() {
def tuna_docker = getDocker("HIP")
- tuna_docker.inside("--network host --dns 8.8.8.8") {
+ tuna_docker.inside("--network host --dns 8.8.8.8 ${docker_args} ") {
env.TUNA_DB_HOSTNAME = "${db_host}"
env.TUNA_DB_NAME="${db_name}"
env.TUNA_DB_USER_NAME="${db_user}"
@@ -549,12 +632,35 @@ def pytestSuite3AndCoverage(current_run, main_branch) {
env.gateway_user = "${gateway_user}"
env.PYTHONPATH=env.WORKSPACE
env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
+
+ //addMachine(arch, num_cu, machine_ip, machine_local_ip, username, pwd, port)
+
sshagent (credentials: ['bastion-ssh-key']) {
sh "python3 -m coverage run -a -m pytest tests/test_fin_evaluator.py -s"
sh "python3 -m coverage run -a -m pytest tests/test_update_golden.py -s"
}
sh "coverage report -m"
+ }
+}
+
+
+def Coverage(current_run, main_branch) {
+ def tuna_docker = getDocker("HIP")
+ tuna_docker.inside("--network host --dns 8.8.8.8") {
+ env.TUNA_DB_HOSTNAME = "${db_host}"
+ env.TUNA_DB_NAME="${db_name}"
+ env.TUNA_DB_USER_NAME="${db_user}"
+ env.TUNA_DB_PASSWORD="${db_password}"
+ env.gateway_ip = "${gateway_ip}"
+ env.gateway_port = "${gateway_port}"
+ env.gateway_user = "${gateway_user}"
+ env.PYTHONPATH=env.WORKSPACE
+ env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
+ sh "coverage report -m"
sh "python3 -m coverage json"
+ sh "coverage html"
+ sh "tar -cjf htmlcov.bz2.tar htmlcov/"
+ archiveArtifacts "htmlcov.bz2.tar"
if (current_run == main_branch) {
sh "python3 tests/covscripts/coverage.py ${main_branch}"
archiveArtifacts artifacts: "${env.COVERAGE_ARTIFACT_FILE_NAME}", allowEmptyArchive: true, fingerprint: true
@@ -579,7 +685,8 @@ def runFormat() {
checkout scm
def tuna_docker = getDocker("HIP")
tuna_docker.inside("") {
- sh "yapf -d -r --style='{based_on_style: google, indent_width: 2}' tuna/ tests/ alembic/"
+ //a yapf bug causes it to complain when an aioredis await is present
+ sh "yapf -d -r --style='{based_on_style: google, indent_width: 2}' tuna/ tests/ alembic/ --exclude=tests/test_celery.py"
}
}
}
@@ -589,7 +696,7 @@ def runLint() {
checkout scm
def tuna_docker = getDocker("HIP")
tuna_docker.inside("") {
- sh "cd tuna && pylint -f parseable --max-args=8 --ignore-imports=no --indent-string=' ' *.py miopen/*.py example/*.py rocmlir/*.py"
+ sh "cd tuna && pylint -f parseable --max-args=8 --ignore-imports=no --indent-string=' ' *.py miopen/*.py example/*.py rocmlir/*.py utils/*.py miopen/celery_tuning/*.py"
sh "cd tuna && find miopen/scripts/ -type f -name '*.py' | xargs pylint -f parseable --max-args=8 --ignore-imports=no --indent-string=' '"
sh "cd tuna && find miopen/driver/ -type f -name '*.py' | xargs pylint -f parseable --max-args=8 --ignore-imports=no --indent-string=' '"
sh "cd tuna && find miopen/worker/ -type f -name '*.py' | xargs pylint -f parseable --max-args=8 --ignore-imports=no --indent-string=' '"
@@ -609,11 +716,10 @@ def runLint() {
sh "mypy tuna/miopen/subcmd/export_db.py --ignore-missing-imports --follow-imports=skip"
sh "mypy tuna/miopen/subcmd/update_golden.py --ignore-missing-imports --follow-imports=skip"
sh "mypy tuna/miopen/parse_miopen_args.py --ignore-missing-imports --follow-imports=skip"
- sh "mypy tuna/miopen/driver/convolution.py --ignore-missing-imports"
+ sh "mypy tuna/miopen/driver/convolution.py --ignore-missing-imports --follow-imports=skip"
sh "mypy tuna/yaml_parser.py --ignore-missing-imports --follow-imports=skip"
- sh "mypy tuna/flask_example.py --ignore-missing-imports --follow-imports=skip"
sh "mypy tuna/go_fish.py --ignore-missing-imports --follow-imports=skip"
- sh "mypy tuna/miopen/driver/batchnorm.py --ignore-missing-imports"
+ sh "mypy tuna/miopen/driver/batchnorm.py --ignore-missing-imports --follow-imports=skip"
sh "mypy tuna/miopen/worker/fin_class.py --ignore-missing-imports --follow-imports=skip"
sh "mypy tuna/miopen/worker/fin_eval.py --ignore-missing-imports --follow-imports=skip"
sh "mypy tuna/miopen/worker/fin_utils.py --ignore-missing-imports --follow-imports=skip"
@@ -659,7 +765,11 @@ def getJobReason()
def killContainer() {
- sh "srun --no-kill -p ${partition} -N 1-10 -l bash -c 'docker container list | grep ${tuna_docker_name} | sed \"s# #^#g\" | tr -s ^ | cut -d ^ -f 6 | xargs -I _ docker container kill _'"
+ def tuna_docker_name = getDockerName("${backend}")
+ sh "docker container list | grep ${tuna_docker_name} | sed \"s# #^#g\" | tr -s ^ | cut -d ^ -f 6 | xargs -I _ docker kill --signal=\"SIGINT\" _"
+ sh "docker container list | grep ${tuna_docker_name} | sed \"s# #^#g\" | tr -s ^ | cut -d ^ -f 6 | xargs -I _ docker wait _"
+ sh "docker system prune -f"
+ //sh "srun --no-kill -p ${partition} -N 1-10 -l bash -c 'docker container list | grep ${tuna_docker_name} | sed \"s# #^#g\" | tr -s ^ | cut -d ^ -f 6 | xargs -I _ docker container kill _'"
sh "srun --no-kill -p ${partition} -N 1-10 -l bash -c 'docker system prune -f'"
}
@@ -706,36 +816,43 @@ def LoadJobs()
def docker_run_args = "--network host --dns 8.8.8.8 -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password} -e gateway_ip=${gateway_ip} -e gateway_port=${gateway_port} -e gateway_user=${gateway_user} -e TUNA_LOGLEVEL=${params.tuna_loglevel}"
sh "echo ${build_args}"
- def tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
- tuna_docker.inside("${docker_run_args}") {
- env.PYTHONPATH=env.WORKSPACE
- env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
- env.TUNA_LOGLEVEL="${tuna_loglevel}"
-
- echo "./tuna/go_fish.py miopen load_job --session_id ${params.session_id} ${script_args}"
- sh "python3 ./tuna/go_fish.py miopen load_job --session_id ${params.session_id} ${script_args}"
+ docker.withRegistry('', "$DOCKER_CRED"){
+ def tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
+ tuna_docker.inside("${docker_run_args}") {
+ env.PYTHONPATH=env.WORKSPACE
+ env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
+ env.TUNA_LOGLEVEL="${tuna_loglevel}"
+
+ echo "./tuna/go_fish.py miopen load_job --session_id ${params.session_id} ${script_args}"
+ sh "python3 ./tuna/go_fish.py miopen load_job --session_id ${params.session_id} ${script_args}"
+ }
+ tuna_docker.push()
}
- tuna_docker.push()
}
def getSessionVals(session_id)
{
- String res = runsql("select arch, num_cu, rocm_v, miopen_v from session where id=${session_id};")
+ String res = runsql("select arch, num_cu, rocm_v, miopen_v, docker from session where id=${session_id};")
- def arch = res.split("[ \t]+")[0]
- def num_cu = res.split("[ \t]+")[1]
- def rocm_v = res.split("[ \t]+")[2]
- def miopen_v = res.split("[ \t]+")[3]
- echo "$arch $num_cu $rocm_v $miopen_v"
+ res_arr = res.split("[ \t]+")
+ def arch = res_arr[0]
+ def num_cu = res_arr[1]
+ def rocm_v = res_arr[2]
+ def miopen_v = res_arr[3]
+ def base_image = ""
+ if(res_arr.size() > 4)
+ base_image = res_arr[4]
+ echo "$arch $num_cu $rocm_v $miopen_v $base_image"
- def partition = "${arch}_${num_cu}"
+ def gfx_target = "${arch}_${num_cu}"
def osdb_bkc_version = ''
def rocm_version = ''
def subv_i = rocm_v.indexOf('-')
def ver_len = rocm_v.length() - subv_i - 1
- if(ver_len > 3)
+ if(base_image != ''){} // base image supplied; skip ROCm/OSDB version parsing
+ else if(ver_len > 3)
{
osdb_bkc_version=rocm_v.substring(subv_i+1)
}
@@ -761,27 +878,37 @@ def getSessionVals(session_id)
miopen_v = miopen_v.substring(0, subv_i)
}
- return [partition, osdb_bkc_version, rocm_version, miopen_v]
+ return [gfx_target, osdb_bkc_version, rocm_version, miopen_v, base_image]
}
+def getBuildArgs(){
+ (gfx_target, osdb_bkc_version, rocm_version, miopen_v, base_image) = getSessionVals(params.session_id)
-def applicUpdate(){
- def tuna_docker_name = getDockerName("${backend}")
- def tuna_docker
- (_, osdb_bkc_version, rocm_version, miopen_v) = getSessionVals(params.session_id)
-
- def build_args = " --network host --build-arg ROCMVERSION=${rocm_version} --build-arg OSDB_BKC_VERSION=${osdb_bkc_version} --build-arg BACKEND=${backend} --build-arg MIOPEN_BRANCH=${miopen_v} --build-arg DB_NAME=${params.db_name} --build-arg DB_USER_NAME=${params.db_user} --build-arg DB_USER_PASSWORD=${params.db_password} --build-arg DB_HOSTNAME=${params.db_host} --build-arg MIOPEN_USE_MLIR=${params.use_mlir}"
-
- if(params.base_image != '')
+ def arch = gfx_target.split("_")[0]
+ def build_args = " --network host --build-arg ROCMVERSION=${rocm_version} --build-arg OSDB_BKC_VERSION=${osdb_bkc_version} --build-arg BACKEND=${backend} --build-arg MIOPEN_BRANCH=${miopen_v} --build-arg DB_NAME=${params.db_name} --build-arg DB_USER_NAME=${params.db_user} --build-arg DB_USER_PASSWORD=${params.db_password} --build-arg DB_HOSTNAME=${params.db_host} --build-arg MIOPEN_USE_MLIR=${params.use_mlir} --build-arg ARCH_TARGET=${arch}"
+ if(base_image != '')
{
- build_args = build_args + " --build-arg BASEIMAGE=${params.base_image} --build-arg ROCM_PRE=1"
+ build_args = build_args + " --build-arg BASEIMAGE=${base_image}"
+ ci_str = "rocm/miopen:ci_"
+ if(ci_str != base_image.substring(0, ci_str.length()))
+ {
+ build_args = build_args + " --build-arg BUILD_MIOPEN_DEPS=1"
+ }
}
sh "echo ${build_args}"
- tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
- tuna_docker.push()
+ return [build_args, gfx_target]
+}
- def docker_args = "--network host --dns 8.8.8.8 -e TUNA_DB_HOSTNAME=${db_host} -e TUNA_DB_NAME=${params.db_name} -e TUNA_DB_USER_NAME=${db_user} -e TUNA_DB_PASSWORD=${db_password} -e gateway_ip=${gateway_ip} -e gateway_port=${gateway_port} -e gateway_user=${gateway_user} -e TUNA_LOGLEVEL=${params.tuna_loglevel}"
+def applicUpdate(){
+ (build_args, partition) = getBuildArgs()
+ def tuna_docker_name = getDockerName("${backend}")
+ docker.withRegistry('', "$DOCKER_CRED"){
+ def tuna_docker
+
+ tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
+ tuna_docker.push()
+ }
def use_tag = ''
if(params.config_tag != '')
@@ -791,7 +918,7 @@ def applicUpdate(){
if(params.UPDATE_SOLVERS)
{
- sh "srun --no-kill -p build-only -N 1 -l bash -c 'docker run ${docker_args} ${tuna_docker_name} ./tuna/go_fish.py miopen --update_solvers'"
+ sh "srun --no-kill -p build-only -N 1 -l bash -c 'echo ${env.CREDS_PSW} | HOME=/home/slurm docker login -u ${env.CREDS_USR} --password-stdin && HOME=/home/slurm docker run ${docker_args} ${tuna_docker_name} ./tuna/go_fish.py miopen --update_solvers'"
def num_solvers = runsql("SELECT count(*) from solver;")
println "Number of solvers: ${num_solvers}"
if (num_solvers.toInteger() == 0){
@@ -800,41 +927,35 @@ def applicUpdate(){
}
if(params.UPDATE_APPLICABILITY)
{
- sh "srun --no-kill -p build-only -N 1 -l bash -c 'docker run ${docker_args} ${tuna_docker_name} ./tuna/go_fish.py miopen --update_applicability --session_id ${params.session_id} ${use_tag}'"
+ sh "srun --no-kill -p ${partition} -N 1 -l bash -c 'echo ${env.CREDS_PSW} | HOME=/home/slurm docker login -u ${env.CREDS_USR} --password-stdin && HOME=/home/slurm docker run ${docker_args} ${tuna_docker_name} ./tuna/go_fish.py miopen --update_applicability --session_id ${params.session_id} ${use_tag}'"
def num_sapp = runsql("SELECT count(*) from conv_solver_applicability where session=${params.session_id};")
println "Session ${params.session_id} applicability: ${num_sapp}"
if (num_sapp.toInteger() == 0){
error("Unable to get applicability from Fin")
}
}
-
}
def compile()
{
+ (build_args, _) = getBuildArgs()
def tuna_docker_name = getDockerName("${backend}")
- def tuna_docker
- (_, osdb_bkc_version, rocm_version, miopen_v) = getSessionVals(params.session_id)
+ docker.withRegistry('', "$DOCKER_CRED"){
+ def tuna_docker
- def build_args = " --network host --build-arg ROCMVERSION=${rocm_version} --build-arg OSDB_BKC_VERSION=${osdb_bkc_version} --build-arg BACKEND=${backend} --build-arg MIOPEN_BRANCH=${miopen_v} --build-arg DB_NAME=${params.db_name} --build-arg DB_USER_NAME=${params.db_user} --build-arg DB_USER_PASSWORD=${params.db_password} --build-arg DB_HOSTNAME=${params.db_host} --build-arg MIOPEN_USE_MLIR=${params.use_mlir}"
+ tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
- if(params.base_image != '')
- {
- build_args = build_args + " --build-arg BASEIMAGE=${params.base_image} --build-arg ROCM_PRE=1"
+ tuna_docker.inside("--network host --dns 8.8.8.8 ") {
+ env.PYTHONPATH=env.WORKSPACE
+ env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
+ env.TUNA_LOGLEVEL="${tuna_loglevel}"
+ sh "pwd"
+ }
+ // push the image
+ tuna_docker.push()
}
- sh "echo ${build_args}"
- tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
-
- tuna_docker.inside("--network host --dns 8.8.8.8 ") {
- env.PYTHONPATH=env.WORKSPACE
- env.PATH="${env.WORKSPACE}/tuna:${env.PATH}"
- env.TUNA_LOGLEVEL="${tuna_loglevel}"
- sh "pwd"
- }
- // push the image
- tuna_docker.push()
env_list = params.env.split(' ')
for(item in env_list)
{
@@ -871,27 +992,22 @@ def compile()
}
// Run the jobs on the cluster
- sh "srun --no-kill -p ${partition} -N 1-10 -l bash -c 'docker run ${docker_args} ${tuna_docker_name} python3 /tuna/tuna/go_fish.py miopen ${compile_cmd} --session_id ${params.session_id}'"
+ sh "docker run ${docker_args} ${tuna_docker_name} python3 /tuna/tuna/go_fish.py miopen ${compile_cmd} --session_id ${params.session_id} --enqueue_only &"
+ sh "srun --no-kill -p ${partition} -N 1-10 -l bash -c 'echo ${env.CREDS_PSW} | HOME=/home/slurm docker login -u ${env.CREDS_USR} --password-stdin && HOME=/home/slurm docker run ${docker_args} ${tuna_docker_name} python3 /tuna/tuna/go_fish.py miopen ${compile_cmd} --session_id ${params.session_id}'"
}
def evaluate(params)
{
+ (build_args, partition) = getBuildArgs()
def tuna_docker_name = getDockerName("${backend}")
- def tuna_docker
- (partition, osdb_bkc_version, rocm_version, miopen_v) = getSessionVals(params.session_id)
-
- def build_args = " --network host --build-arg ROCMVERSION=${rocm_version} --build-arg OSDB_BKC_VERSION=${osdb_bkc_version} --build-arg BACKEND=HIP --build-arg MIOPEN_BRANCH=${miopen_v} --build-arg DB_NAME=${params.db_name} --build-arg DB_USER_NAME=${params.db_user} --build-arg DB_USER_PASSWORD=${params.db_password} --build-arg DB_HOSTNAME=${params.db_host} --build-arg MIOPEN_USE_MLIR=${params.use_mlir}"
- if(params.base_image != '')
- {
- build_args = build_args + " --build-arg BASEIMAGE=${params.base_image} --build-arg ROCM_PRE=1"
+ docker.withRegistry('', "$DOCKER_CRED"){
+ def tuna_docker
+ tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
+ tuna_docker.push()
}
- sh "echo ${build_args}"
- tuna_docker = docker.build("${tuna_docker_name}", "${build_args} ." )
- tuna_docker.push()
-
env_list = params.env.split(' ')
for(item in env_list)
{
@@ -921,13 +1037,14 @@ def evaluate(params)
eval_cmd += ' --dynamic_solvers_only'
}
- sh "srun --no-kill -p ${partition} -N 1-10 -l bash -c 'docker run ${docker_args} ${tuna_docker_name} python3 /tuna/tuna/go_fish.py miopen ${eval_cmd} --session_id ${params.session_id}'"
+ sh "docker run ${docker_args} ${tuna_docker_name} python3 /tuna/tuna/go_fish.py miopen ${eval_cmd} --session_id ${params.session_id} --enqueue_only &"
+ sh "srun --no-kill -p ${partition} -N 1-10 -l bash -c 'echo ${env.CREDS_PSW} | HOME=/home/slurm docker login -u ${env.CREDS_USR} --password-stdin && HOME=/home/slurm docker run ${docker_args} ${tuna_docker_name} python3 /tuna/tuna/go_fish.py miopen ${eval_cmd} --session_id ${params.session_id}'"
}
def doxygen() {
node {
checkout scm
- def tuna_docker = docker.build("ci-tuna:${branch_id}", " .")
+ def tuna_docker = docker.build("ci-tuna_${branch_id}", " .")
tuna_docker.inside("") {
sh "cd doc && doxygen Doxyfile"
def empty = sh returnStdout: true, script: "ls doc | wc -l"
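Note on the tuning flow introduced above: the perfCompile/perfEval stages now split work into an
--enqueue_only producer (go_fish.py pushes jobs onto a per-session queue such as
compile_q_<db_name>_sess_<id> or eval_q_<db_name>_sess_<id>) and detached Celery workers that
consume from that queue, one worker per GPU for eval. The following is a minimal sketch of that
producer/worker split using only stock Celery APIs; the broker URL, queue name and the
tune_one_job task are illustrative placeholders, not the actual tuna.celery_app implementation.

    # sketch_celery_split.py -- illustration only, assumptions noted inline
    from celery import Celery

    # assumption: a Redis broker reachable on the DB host (cf. TUNA_CELERY_BROKER_HOST)
    app = Celery("sketch", broker="redis://localhost:6379/0")

    @app.task
    def tune_one_job(job_id):
        # placeholder for the per-job compile/eval work
        return job_id

    def enqueue(job_ids, queue):
        # producer side: conceptually what --enqueue_only does, pushing tasks to a named queue
        for jid in job_ids:
            tune_one_job.apply_async(args=[jid], queue=queue)

    if __name__ == "__main__":
        # worker side: roughly `celery -A sketch worker -Q <queue> -c 1`
        app.worker_main(argv=["worker", "-l", "info",
                              "-Q", "compile_q_example_sess_1", "-c", "1"])

In the Jenkinsfile this pairing is driven from shell: the worker is started with --detach and a
per-branch node name, go_fish.py enqueues with --enqueue_only, and the worker PID is killed once
the queue drains.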
From cb771ae7788680d0d027e102a043f8188558cf64 Mon Sep 17 00:00:00 2001
From: Djordje Antic
Date: Thu, 31 Jul 2025 10:21:11 +0200
Subject: [PATCH 2/7] Update requirements.txt
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 1b29fc500..e76865ef2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -66,7 +66,7 @@ yamllint==1.29.0
yapf==0.40.2
zipp==3.19.1
coverage==7.0.5
-python-logstash-async==3.0.0
+python-logstash-async
mysql-connector-python
prometheus_flask_exporter
tenacity
From dbf92288f2f010abba9e62ec3339259815912759 Mon Sep 17 00:00:00 2001
From: Djordje Antic
Date: Thu, 31 Jul 2025 09:24:23 -0400
Subject: [PATCH 3/7] logstash async 3.0.0
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index e76865ef2..1b29fc500 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -66,7 +66,7 @@ yamllint==1.29.0
yapf==0.40.2
zipp==3.19.1
coverage==7.0.5
-python-logstash-async
+python-logstash-async==3.0.0
mysql-connector-python
prometheus_flask_exporter
tenacity
From 8c9d8e6578b027746994700b93749dd76ca7c33a Mon Sep 17 00:00:00 2001
From: Djordje Antic
Date: Thu, 31 Jul 2025 09:26:40 -0400
Subject: [PATCH 4/7] Try 2.2.1 for logstash async
---
docs/sphinx/requirements.txt | 2 +-
requirements.txt | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt
index b8c28cdf8..62d40eb4a 100644
--- a/docs/sphinx/requirements.txt
+++ b/docs/sphinx/requirements.txt
@@ -66,7 +66,7 @@ yamllint==1.29.0
yapf==0.40.2
zipp==3.19.1
coverage==7.0.5
-python-logstash-async==3.0.0
+python-logstash-async==2.2.1
mysql-connector-python
prometheus_flask_exporter
tenacity
diff --git a/requirements.txt b/requirements.txt
index 1b29fc500..59192008b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -66,7 +66,7 @@ yamllint==1.29.0
yapf==0.40.2
zipp==3.19.1
coverage==7.0.5
-python-logstash-async==3.0.0
+python-logstash-async==2.2.1
mysql-connector-python
prometheus_flask_exporter
tenacity
From 7bf15979233dc571d2facb6835716a6c03a73b78 Mon Sep 17 00:00:00 2001
From: Djordje Antic
Date: Fri, 1 Aug 2025 12:40:53 +0200
Subject: [PATCH 5/7] Update requirements.txt
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 59192008b..1b29fc500 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -66,7 +66,7 @@ yamllint==1.29.0
yapf==0.40.2
zipp==3.19.1
coverage==7.0.5
-python-logstash-async==2.2.1
+python-logstash-async==3.0.0
mysql-connector-python
prometheus_flask_exporter
tenacity
From 0cf24ea82595d8d833983cdf61ed36d4a7abea82 Mon Sep 17 00:00:00 2001
From: Djordje Antic
Date: Fri, 1 Aug 2025 12:41:47 +0200
Subject: [PATCH 6/7] Update requirements.txt
---
docs/sphinx/requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt
index 62d40eb4a..b8c28cdf8 100644
--- a/docs/sphinx/requirements.txt
+++ b/docs/sphinx/requirements.txt
@@ -66,7 +66,7 @@ yamllint==1.29.0
yapf==0.40.2
zipp==3.19.1
coverage==7.0.5
-python-logstash-async==2.2.1
+python-logstash-async==3.0.0
mysql-connector-python
prometheus_flask_exporter
tenacity
From d7e38345899949ff49db2e9f7c9f7b0710c21927 Mon Sep 17 00:00:00 2001
From: Djordje Antic
Date: Fri, 1 Aug 2025 15:46:22 -0400
Subject: [PATCH 7/7] Remove SimpleCSVMixin
---
tuna/rocmlir/rocmlir_tables.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/tuna/rocmlir/rocmlir_tables.py b/tuna/rocmlir/rocmlir_tables.py
index 9eaf9a027..499c8547b 100644
--- a/tuna/rocmlir/rocmlir_tables.py
+++ b/tuna/rocmlir/rocmlir_tables.py
@@ -360,7 +360,7 @@ def get_configurations(self, filename):
return configs
-class ResultsMixin(SimpleCSVMixin): # pylint: disable=too-many-instance-attributes
+class ResultsMixin(): # pylint: disable=too-many-instance-attributes
"""Collects the results of tuning."""
def __init__(self, **kwargs):
@@ -440,7 +440,7 @@ class GEMMJob(BASE, JobMixin):
index=True)
-class GEMMConfig(BASE, SimpleCSVMixin):
+class GEMMConfig(BASE):
"""Represents GEMM config table"""
__tablename__ = "rocmlir_gemm_config"
@@ -609,7 +609,7 @@ class AttentionJob(BASE, JobMixin):
index=True)
-class AttentionConfig(BASE, SimpleCSVMixin):
+class AttentionConfig(BASE):
"""Represents Attention config table"""
__tablename__ = "rocmlir_attention_config"