From 2610b15078ceee4cb49e2521da5c8844822c2939 Mon Sep 17 00:00:00 2001 From: Reid Kawaja <74506315+reidkwja@users.noreply.github.com> Date: Fri, 12 Dec 2025 14:27:36 -0600 Subject: [PATCH 1/6] init - adding tests for subcmd folder --- tests/test_export_db_branches.py | 324 ++++++++++++++++++++++++++ tests/test_import_configs_branches.py | 249 ++++++++++++++++++++ tests/test_load_job_branches.py | 166 +++++++++++++ tests/test_merge_db_branches.py | 116 +++++++++ 4 files changed, 855 insertions(+) create mode 100644 tests/test_export_db_branches.py create mode 100644 tests/test_import_configs_branches.py create mode 100644 tests/test_load_job_branches.py create mode 100644 tests/test_merge_db_branches.py diff --git a/tests/test_export_db_branches.py b/tests/test_export_db_branches.py new file mode 100644 index 00000000..37715213 --- /dev/null +++ b/tests/test_export_db_branches.py @@ -0,0 +1,324 @@ +############################################################################### +# +# MIT License +# +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +############################################################################### +"""Additional branch coverage for export_db.""" +import argparse +import base64 +import json +import os +import sqlite3 +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import tuna.miopen.subcmd.export_db as export_db +from tuna.utils.db_utility import DB_Type +from tuna.miopen.utils.config_type import ConfigType + + +class _FakeColumn: + """Lightweight column stub that records equality checks.""" + + def __init__(self, name): + self.name = name + + def __eq__(self, other): + return SimpleNamespace(value=other) + + def in_(self, _): + return self + + +class _FakeQuery: + """Query stub that carries simple state.""" + + def __init__(self, entries=None, value=None): + self._entries = entries or [] + self.value = value + + def all(self): + return self._entries + + def filter(self, cond): + val = getattr(cond, 'value', None) + return _FakeQuery(self._entries, val) + + def distinct(self): + return self + + def subquery(self): + return self + + def order_by(self, *_, **__): + return self + + +def _fake_session_factory(num_cu_values): + """Build a DbSession replacement that returns provided num_cu values.""" + + class _FakeSession: + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def query(self, *_): + return _FakeQuery(entries=[(val,) for val in num_cu_values]) + + def commit(self): + return None + + return _FakeSession + + +def test_arg_export_db_requires_arch(monkeypatch): + args = argparse.Namespace(golden_v='1.0', arch=None) + logger = MagicMock() + export_db.arg_export_db(args, logger) + logger.error.assert_called_once() + + +def test_get_filename_variants(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + # filename provided + fname = export_db.get_filename('gfx900', None, 'custom', True, + DB_Type.FIND_DB) + assert fname.endswith('custom.OpenCL.fdb.txt') + # num_cu missing falls back to arch + fname = export_db.get_filename('gfx900', None, None, False, DB_Type.PERF_DB) + assert fname.endswith('gfx900.db.txt') + # num_cu over 64 encoded as hex + fname = export_db.get_filename('gfx900', 128, None, False, DB_Type.FIND_DB) + assert fname.endswith('gfx90080.HIP.fdb.txt') + # kernel db extension + fname = export_db.get_filename('gfx900', 32, None, False, DB_Type.KERN_DB) + assert fname.endswith('.kdb') + + +def test_fin_net_cfg_job_non_convolution(monkeypatch): + logger = MagicMock() + jobs = export_db.fin_net_cfg_job([], logger, ConfigType.batch_norm) + logger.error.assert_called_once() + assert jobs == [] + + +def test_add_entry_to_solvers_duplicate(): + solvers = {'k1': {'solver_a': 123}} + entry = SimpleNamespace(fdb_key='k1', solver='solver_a', update_ts=456) + added = export_db.add_entry_to_solvers(entry, solvers, MagicMock()) + assert added is False + assert solvers['k1']['solver_a'] == 123 + + +def test_build_miopen_fdb_trims(monkeypatch): + monkeypatch.setattr(export_db, 'require_id_solvers', lambda: None) + monkeypatch.setattr(export_db, 'ID_SOLVER_MAP', { + 's0': 0, + 's1': 1, + 's2': 2, + 's3': 3, + 's4': 4 + }) + + class _Entry(SimpleNamespace): + pass + + entries = [] + for idx in range(5): + entries.append((_Entry(fdb_key='key', + solver=f's{idx}', + kernel_time=idx, + workspace_sz=0, + update_ts=idx), None)) + + query = _FakeQuery(entries=entries) + result = export_db.build_miopen_fdb(query, MagicMock()) + assert 'key' in result + assert len(result['key']) == 4 + assert all(isinstance(x, _Entry) for x in result['key']) + + +def test_write_kdb_handles_extensions_and_dedup(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + logger = MagicMock() + blob = base64.b64encode(b'abc') + kern1 = SimpleNamespace(kernel_name='k', + kernel_args='arg', + kernel_blob=blob, + kernel_hash='h1', + uncompressed_size=1, + kernel_group=1) + kern2 = SimpleNamespace(kernel_name='k', + kernel_args='arg', + kernel_blob=blob, + kernel_hash='h1', + uncompressed_size=1, + kernel_group=1) + kern3 = SimpleNamespace(kernel_name='k.mlir', + kernel_args='arg3', + kernel_blob=blob, + kernel_hash='h3', + uncompressed_size=1, + kernel_group=2) + fname = export_db.write_kdb('gfx900', + 64, [kern1, kern2, kern3], + logger, + filename='out') + conn = sqlite3.connect(fname) + cur = conn.cursor() + cur.execute("SELECT kernel_name, kernel_args FROM kern_db ORDER BY id") + rows = cur.fetchall() + cur.close() + conn.close() + # duplicate filtered, mlir keeps args as-is without -mcpu + assert len(rows) == 2 + assert rows[0][0] == 'k.o' + assert '-mcpu=gfx900' in rows[0][1] + assert rows[1][0] == 'k.mlir.o' + assert '-mcpu=' not in rows[1][1] + + +def test_build_miopen_fdb_skews(monkeypatch): + # ensure deterministic solver map for sorting + monkeypatch.setattr(export_db, 'require_id_solvers', lambda: None) + monkeypatch.setattr(export_db, 'ID_SOLVER_MAP', {'s': 0}) + + args = SimpleNamespace() + args.src_table = SimpleNamespace(num_cu=_FakeColumn('num_cu'), + id=_FakeColumn('id')) + query = _FakeQuery(entries=[(SimpleNamespace(id=1), None)]) + + def fake_build(miopen_query, logger): + return {f"k_{miopen_query.value}": ['entry']} + + fake_session = _fake_session_factory([64, 32]) + monkeypatch.setattr(export_db, 'DbSession', fake_session) + monkeypatch.setattr(export_db, 'build_miopen_fdb', fake_build) + + res = export_db.build_miopen_fdb_skews(args, query, MagicMock()) + assert set(res.keys()) == {'k_64_cu64', 'k_32_cu32'} + + +def test_export_kdb_uses_skews(monkeypatch): + args = SimpleNamespace(arch='gfx', num_cu=None, filename=None, opencl=False) + dbt = SimpleNamespace() + logger = MagicMock() + monkeypatch.setattr(export_db, 'get_fdb_query', lambda *_: _FakeQuery()) + monkeypatch.setattr(export_db, 'build_miopen_fdb_skews', + lambda *_, **__: {'k': ['v']}) + monkeypatch.setattr(export_db, 'build_miopen_kdb', lambda *_: [1, 2, 3]) + monkeypatch.setattr(export_db, 'write_kdb', lambda *_, **__: 'kdb.out') + result = export_db.export_kdb(dbt, args, logger, skew_fdbs=True) + assert result == 'kdb.out' + + +def test_build_miopen_pdb(monkeypatch): + monkeypatch.setattr(export_db, 'require_id_solvers', lambda: None) + monkeypatch.setattr(export_db, 'ID_SOLVER_MAP', {'s1': 1}) + args = SimpleNamespace() + query_entries = [ + (SimpleNamespace(fdb_key='key1', + solver='s1', + kernel_time=1, + workspace_sz=0, + update_ts=1, + params='p1'), SimpleNamespace(id=1)), + (SimpleNamespace(fdb_key='dup', + solver='s1', + kernel_time=2, + workspace_sz=0, + update_ts=2, + params='p1'), SimpleNamespace(id=1)), + (SimpleNamespace(fdb_key='key2', + solver='s1', + kernel_time=3, + workspace_sz=0, + update_ts=3, + params='p2'), SimpleNamespace(id=2)), + ] + query = _FakeQuery(entries=query_entries) + monkeypatch.setattr(export_db, 'fin_db_key', lambda *_: {1: 'db1', 2: 'db2'}) + res = export_db.build_miopen_pdb(query, MagicMock()) + assert set(res.keys()) == {'db1', 'db2'} + # duplicate solver entry skipped, leaving one record for db1 + assert len(res['db1']) == 1 + + +def test_write_pdb_orders_by_solver(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(export_db, 'require_id_solvers', lambda: None) + monkeypatch.setattr(export_db, 'ID_SOLVER_MAP', {'s1': 1, 's2': 2}) + perf_db = { + 'db': [ + SimpleNamespace(solver='s2', params='two'), + SimpleNamespace(solver='s1', params='one') + ] + } + path = export_db.write_pdb('gfx', 1, False, perf_db, filename='perf.txt') + with open(path, 'r', encoding='utf-8') as fp: + line = fp.read().strip() + assert line == 'db=1:one;2:two' + + +def test_export_pdb_txt(monkeypatch): + args = SimpleNamespace(arch='gfx', num_cu=1, opencl=False, filename=None) + dbt = SimpleNamespace() + logger = MagicMock() + monkeypatch.setattr(export_db, 'get_pdb_query', lambda *_: _FakeQuery()) + monkeypatch.setattr(export_db, 'build_miopen_pdb', + lambda *_: {'db': ['entry']}) + monkeypatch.setattr(export_db, 'write_pdb', lambda *_, **__: 'perf.out') + res = export_db.export_pdb_txt(dbt, args, logger) + assert res == 'perf.out' + + +def test_run_export_db_routes(monkeypatch, capsys): + + class _DummyTables: + + def __init__(self, session_id=None, **_): + self.session = SimpleNamespace(arch='gfxA', num_cu=64) + self.find_db_table = 'find' + self.golden_table = 'golden' + self.session_id = session_id + + args = argparse.Namespace(session_id=1, + golden_v='1.0', + find_db=True, + kern_db=False, + perf_db=False, + arch=None, + num_cu=None, + opencl=False, + filename=None) + logger = MagicMock() + monkeypatch.setattr(export_db, 'MIOpenDBTables', _DummyTables) + monkeypatch.setattr(export_db, 'export_fdb', lambda *_: 'out.find') + export_db.run_export_db(args, logger) + captured = capsys.readouterr() + assert 'out.find' in captured.out diff --git a/tests/test_import_configs_branches.py b/tests/test_import_configs_branches.py new file mode 100644 index 00000000..a19bace3 --- /dev/null +++ b/tests/test_import_configs_branches.py @@ -0,0 +1,249 @@ +############################################################################### +# +# MIT License +# +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +############################################################################### +"""Branch coverage for import_configs.""" +import argparse +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import tuna.miopen.subcmd.import_configs as import_configs +from tuna.miopen.utils.config_type import ConfigType +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm.exc import NoResultFound + + +class _FakeSession: + """Simple DbSession stand‑in for unit tests.""" + + def __init__(self, query_result=None, raise_on_commit=None): + self._query_result = query_result or [] + self._raise_on_commit = raise_on_commit + self.added = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def query(self, *_, **__): + + class _Query: + + def __init__(self, result): + self._result = result + + def filter(self, *_): + return self + + def one(self): + if isinstance(self._result, Exception): + raise self._result + return SimpleNamespace(id=self._result) + + def all(self): + return self._result + + return _Query(self._query_result) + + def merge(self, obj): + self.added.append(obj) + if isinstance(self._raise_on_commit, IntegrityError): + raise self._raise_on_commit + return obj + + def add(self, obj): + self.added.append(obj) + if isinstance(self._raise_on_commit, IntegrityError): + raise self._raise_on_commit + + def commit(self): + if isinstance(self._raise_on_commit, IntegrityError): + raise self._raise_on_commit + + def rollback(self): + return None + + +def test_create_query_variants(): + assert import_configs.create_query(None, True, 1) == { + 'config': 1, + 'recurrent': 1 + } + assert import_configs.create_query('tag', False, 2) == { + 'config': 2, + 'tag': 'tag' + } + assert import_configs.create_query('tag', True, 3) == { + 'config': 3, + 'tag': 'tag', + 'recurrent': 1 + } + + +def test_set_import_cfg_batches(monkeypatch): + args = argparse.Namespace(batches='1,2,3', batch_list=None) + import_configs.set_import_cfg_batches(args) + assert args.batch_list == [1, 2, 3] + args = argparse.Namespace(batches=None, batch_list=None) + import_configs.set_import_cfg_batches(args) + assert args.batch_list == [] + + +def test_process_config_line_tag_only(monkeypatch): + driver = MagicMock() + args = argparse.Namespace(tag_only=True) + counts = {} + dbt = MagicMock() + logger = MagicMock() + monkeypatch.setattr(import_configs, 'tag_config_v2', lambda *_, **__: True) + result = import_configs.process_config_line_v2(driver, args, counts, dbt, + logger) + assert result is False + + +def test_parse_line_batches(monkeypatch): + called_batches = [] + + class _Driver(MagicMock): + batchsize = None + + def _process(driver, args, counts, dbt, logger): + called_batches.append(driver.batchsize) + return True + + monkeypatch.setattr(import_configs, 'DriverConvolution', _Driver) + monkeypatch.setattr(import_configs, 'process_config_line_v2', _process) + args = argparse.Namespace(config_type=ConfigType.convolution, + command=None, + batch_list=[2, 4], + tag_only=False) + import_configs.parse_line(args, 'cmd', {}, MagicMock(), MagicMock()) + assert called_batches == [2, 4] + + +def test_add_model_integrity_error(monkeypatch): + fake_session = _FakeSession(raise_on_commit=IntegrityError('dup')) + monkeypatch.setattr(import_configs, 'DbSession', lambda: fake_session) + args = argparse.Namespace(add_model='m', md_version=1) + logger = MagicMock() + assert import_configs.add_model(args, logger) is False + + +def test_get_database_id_not_found(monkeypatch): + fake_session = _FakeSession(query_result=NoResultFound()) + monkeypatch.setattr(import_configs, 'DbSession', lambda: fake_session) + dbt = MagicMock() + logger = MagicMock() + mid, fid = import_configs.get_database_id('fw', 1, 'model', 1.0, dbt, logger) + assert mid == -1 and fid == -1 + + +@pytest.mark.parametrize("mid,fid", [(None, 1), (1, None)]) +def test_add_benchmark_missing_ids(monkeypatch, mid, fid): + monkeypatch.setattr(import_configs, 'get_database_id', + lambda *_args, **_kwargs: (mid, fid)) + args = argparse.Namespace(framework='fw', + fw_version=1, + model='m', + md_version=1.0, + config_type=ConfigType.convolution, + driver='cmd', + file_name=None, + add_model=False, + add_framework=False, + add_benchmark=True, + gpu_count=1) + dbt = MagicMock() + logger = MagicMock() + assert import_configs.add_benchmark(args, dbt, logger) is False + + +def test_check_import_benchmark_args_raises(): + args = argparse.Namespace(add_model=True, + md_version=None, + add_benchmark=True, + model=None, + framework=None, + gpu_count=None, + md_version2=None, + fw_version=None, + driver=None, + file_name=None) + with pytest.raises(ValueError): + import_configs.check_import_benchmark_args(args) + + +def test_run_import_configs_flag_paths(monkeypatch): + args = argparse.Namespace(config_type=ConfigType.convolution, + print_models=True, + add_model=False, + add_framework=False, + add_benchmark=False, + batches=None, + batch_list=[], + tag=None, + tag_only=False, + file_name=None, + mark_recurrent=False) + logger = MagicMock() + monkeypatch.setattr(import_configs, 'MIOpenDBTables', MagicMock()) + monkeypatch.setattr(import_configs, 'check_import_benchmark_args', + lambda *_: None) + called = {} + monkeypatch.setattr(import_configs, 'print_models', + lambda *_: called.setdefault('print', True)) + assert import_configs.run_import_configs(args, logger) is True + + # cover add_model and add_framework branch + args.print_models = False + args.add_model = True + args.add_framework = True + monkeypatch.setattr(import_configs, 'add_model', + lambda *_: called.setdefault('model', True)) + monkeypatch.setattr(import_configs, 'add_frameworks', + lambda *_: called.setdefault('framework', True)) + assert import_configs.run_import_configs(args, logger) is True + + # cover add_benchmark branch + args.add_model = False + args.add_framework = False + args.add_benchmark = True + monkeypatch.setattr(import_configs, 'add_benchmark', + lambda *_: called.setdefault('benchmark', True)) + assert import_configs.run_import_configs(args, logger) is True + + # fall through to import_cfgs path + args.add_benchmark = False + args.print_models = False + args.batches = '8,16' + monkeypatch.setattr( + import_configs, 'import_cfgs', lambda *_args, **_kwargs: { + 'cnt_configs': 0, + 'cnt_tagged_configs': set() + }) + assert import_configs.run_import_configs(args, logger) is True diff --git a/tests/test_load_job_branches.py b/tests/test_load_job_branches.py new file mode 100644 index 00000000..0c302218 --- /dev/null +++ b/tests/test_load_job_branches.py @@ -0,0 +1,166 @@ +############################################################################### +# +# MIT License +# +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +############################################################################### +"""Branch coverage for load_job.""" +import argparse +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import tuna.miopen.subcmd.load_job as load_job +from tuna.miopen.utils.config_type import ConfigType +from tuna.miopen.utils.metadata import ALG_SLV_MAP, TENSOR_PRECISION + + +class _FakeQuery: + """Minimal query stub.""" + + def __init__(self, result=None): + self._result = result or [] + self.filters = [] + + def filter(self, cond): + self.filters.append(cond) + return self + + def all(self): + return self._result + + def subquery(self): + return self + + +class _FakeSession: + """DbSession replacement to avoid hitting the database.""" + + def __init__(self, result=None): + self.result = result or [] + self.executed = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def query(self, *_, **__): + return _FakeQuery(self.result) + + def execute(self, stmt): + self.executed.append(stmt) + return [] + + def add(self, *_): + return None + + def commit(self): + return None + + def rollback(self): + return None + + +def test_arg_solvers_uses_algo(monkeypatch): + algo = next(iter(ALG_SLV_MAP.keys())) + args = argparse.Namespace(solvers=None, algo=algo) + logger = MagicMock() + result = load_job.arg_solvers(args, logger) + assert result.solvers[0][0] in ALG_SLV_MAP[algo] + + +def test_test_tag_name_missing(monkeypatch): + monkeypatch.setattr(load_job, 'DbSession', lambda: _FakeSession(result=[])) + with pytest.raises(ValueError): + load_job.test_tag_name('missing', MagicMock()) + + +def test_config_query_filters(monkeypatch): + cmd_key = next(iter(TENSOR_PRECISION.keys())) + session = _FakeSession(result=[(1,)]) + args = argparse.Namespace(tag='foo', cmd=cmd_key) + dbt = SimpleNamespace( + config_table=SimpleNamespace(id=1, + valid=1, + input_t=SimpleNamespace(data_type=None)), + config_tags_table=SimpleNamespace(config=1, tag='foo'), + ) + query = load_job.config_query(args, session, dbt) + assert isinstance(query, _FakeQuery) + + +def test_compose_query_with_filters(monkeypatch): + args = argparse.Namespace(session_id=1, + solvers=[('s', 1)], + tunable=True, + config_type=ConfigType.batch_norm, + only_dynamic=True) + dbt = SimpleNamespace(solver_app=SimpleNamespace(config=1, + session=1, + solver=1), + config_table=SimpleNamespace(id=1), + job_table=MagicMock()) + session = _FakeSession(result=[(1, 'solver')]) + query = load_job.compose_query(args, session, dbt, _FakeQuery(result=[1])) + assert isinstance(query, _FakeQuery) + + +def test_add_jobs_empty_results(monkeypatch, caplog): + logger = MagicMock() + dbt = SimpleNamespace(job_table=MagicMock()) + args = argparse.Namespace(label='lbl', + fin_steps=None, + session_id=1, + solvers=[('', None)], + tag=None, + cmd=None, + tunable=False, + config_type=ConfigType.convolution, + only_dynamic=False) + monkeypatch.setattr(load_job, 'DbSession', lambda: _FakeSession(result=[])) + monkeypatch.setattr(load_job, 'config_query', lambda *_: _FakeQuery()) + monkeypatch.setattr(load_job, 'compose_query', lambda *_: _FakeQuery()) + count = load_job.add_jobs(args, dbt, logger) + assert count == 0 + logger.error.assert_called_once() + + +def test_run_load_job_tag_error(monkeypatch, capsys): + args = argparse.Namespace(tag='missing', + solvers=None, + algo=None, + fin_steps=None, + config_type=ConfigType.convolution, + session_id=1, + label='lbl', + only_dynamic=False, + tunable=False) + logger = MagicMock() + monkeypatch.setattr(load_job, 'test_tag_name', lambda *_: + (_ for _ in ()).throw(ValueError('missing'))) + monkeypatch.setattr(load_job, 'add_jobs', lambda *_: 0) + load_job.run_load_job(args, logger) + captured = capsys.readouterr() + assert 'New jobs added' in captured.out diff --git a/tests/test_merge_db_branches.py b/tests/test_merge_db_branches.py new file mode 100644 index 00000000..c0c57873 --- /dev/null +++ b/tests/test_merge_db_branches.py @@ -0,0 +1,116 @@ +############################################################################### +# +# MIT License +# +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +############################################################################### +"""Additional merge_db branch coverage.""" +import sqlite3 + +import pytest + +import tuna.miopen.subcmd.merge_db as merge_db + + +def test_parse_args_requires_type(monkeypatch): + monkeypatch.setattr(merge_db.argparse._sys, 'argv', + ['prog', '-m', 'master', '-t', 'target']) + with pytest.raises(SystemExit): + merge_db.parse_args() + + +@pytest.mark.parametrize("master,key,vals,keep_keys", [ + (None, 'k', {}, False), + ({}, '', {}, False), + ({}, 'k', None, False), +]) +def test_target_merge_validation(master, key, vals, keep_keys): + with pytest.raises(ValueError): + merge_db.target_merge(master, key, vals, keep_keys) + + +def test_is_float(): + assert merge_db.is_float('1.0') + assert merge_db.is_float('3') is True + assert merge_db.is_float('bad') is False + + +def test_merge_text_file_copy_only(tmp_path): + master = tmp_path / "gfx900.HIP.fdb.txt" + target = tmp_path / "target.HIP.fdb.txt" + master.write_text("a=1:1.0\n", encoding="utf-8") + target.write_text("b=2:2.0\n", encoding="utf-8") + res = merge_db.merge_text_file(str(master), + copy_only=True, + keep_keys=False, + target_file=str(target)) + assert res is None + + +def test_merge_sqlite_bin_cache(tmp_path): + dest = tmp_path / "gfx900.kdb" + src = tmp_path / "src.kdb" + for path in [dest, src]: + conn = sqlite3.connect(path) + cur = conn.cursor() + cur.execute( + "CREATE TABLE kern_db (kernel_name TEXT, kernel_args TEXT, kernel_blob BLOB, kernel_hash TEXT, uncompressed_size INT, PRIMARY KEY(kernel_name, kernel_args))" + ) + conn.commit() + cur.close() + conn.close() + + # preload destination with one row to force duplicate path + conn = sqlite3.connect(dest) + conn.execute("INSERT INTO kern_db VALUES('k','a',x'00','h',1)") + conn.commit() + conn.close() + + conn_src = sqlite3.connect(src) + conn_src.execute("INSERT INTO kern_db VALUES('k','a',x'00','h',1)") + conn_src.commit() + conn_src.close() + + conn_dest = sqlite3.connect(dest) + merge_db.merge_sqlite_bin_cache(conn_dest, [str(src)]) + cur = conn_dest.cursor() + cur.execute("SELECT count(*) FROM kern_db") + count = cur.fetchone()[0] + cur.close() + conn_dest.close() + assert count == 1 + + +def test_get_file_list_filters(tmp_path): + master_dir = tmp_path / "dir" + master_dir.mkdir() + (master_dir / "gfx803_36.HIP.fdb.txt").write_text("", encoding="utf-8") + (master_dir / "ignore.txt").write_text("", encoding="utf-8") + args = merge_db.argparse.Namespace(master_file=str(master_dir), + find_db=True, + bin_cache=False, + perf_db=False, + copy_only=False, + keep_keys=False, + target_file="") + files = merge_db.get_file_list(args) + assert len(files) == 1 From 00b2963bd045d4435010b079ae5e15770109e4ad Mon Sep 17 00:00:00 2001 From: Reid Kawaja <74506315+reidkwja@users.noreply.github.com> Date: Fri, 12 Dec 2025 14:35:32 -0600 Subject: [PATCH 2/6] init - adding tests to CI --- vars/utils.groovy | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vars/utils.groovy b/vars/utils.groovy index 20e85b25..78347557 100644 --- a/vars/utils.groovy +++ b/vars/utils.groovy @@ -607,11 +607,13 @@ def pytestSuite1() { sshagent (credentials: ['bastion-ssh-key']) { sh "coverage erase" sh "python3 -m coverage run -a -m pytest tests/test_export_db.py -s" + sh "python3 -m coverage run -a -m pytest tests/test_export_db_branches.py -s" sh "python3 -m coverage run -a -m pytest tests/test_abort_file.py -s" sh "python3 -m coverage run -a -m pytest tests/test_analyze_parse_db.py -s" sh "python3 -m coverage run -a -m pytest tests/test_connection.py -s" // builder then evaluator in sequence sh "python3 -m coverage run -a -m pytest tests/test_importconfigs.py -s" + sh "python3 -m coverage run -a -m pytest tests/test_import_configs_branches.py -s" sh "python3 -m coverage run -a -m pytest tests/test_machine.py -s" sh "python3 -m coverage run -a -m pytest tests/test_dbBase.py -s" sh "python3 -m coverage run -a -m pytest tests/test_driver.py -s" @@ -652,12 +654,14 @@ def pytestSuite1() { sh "python3 -m coverage run -a -m pytest tests/test_example_lib_extended.py -s" sh "python3 -m coverage run -a -m pytest tests/test_yaml_parser.py -s" sh "python3 -m coverage run -a -m pytest tests/test_load_job.py -s" + sh "python3 -m coverage run -a -m pytest tests/test_load_job_branches.py -s" sh "python3 -m coverage run -a -m pytest tests/test_add_session_rocmlir.py -s" sh "python3 -m coverage run -a -m pytest tests/test_importconfigs_rocmlir.py -s" sh "python3 -m coverage run -a -m pytest tests/test_load_job_rocmlir.py -s" sh "python3 -m coverage run -a -m pytest tests/test_rocmlir.py -s" sh "python3 -m coverage run -a -m pytest tests/test_helper.py -s" sh "python3 -m coverage run -a -m pytest tests/test_mituna_interface.py -s" + sh "python3 -m coverage run -a -m pytest tests/test_merge_db_branches.py -s" // The OBMC host used in the following test is down // sh "pytest tests/test_mmi.py " } From 60197074469540d69845489ed883f5200e26d5ef Mon Sep 17 00:00:00 2001 From: Reid Kawaja <74506315+reidkwja@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:09:45 -0600 Subject: [PATCH 3/6] fixes: export_db tests --- tests/test_export_db_branches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_export_db_branches.py b/tests/test_export_db_branches.py index 37715213..96a25d23 100644 --- a/tests/test_export_db_branches.py +++ b/tests/test_export_db_branches.py @@ -265,8 +265,8 @@ def test_build_miopen_pdb(monkeypatch): monkeypatch.setattr(export_db, 'fin_db_key', lambda *_: {1: 'db1', 2: 'db2'}) res = export_db.build_miopen_pdb(query, MagicMock()) assert set(res.keys()) == {'db1', 'db2'} - # duplicate solver entry skipped, leaving one record for db1 - assert len(res['db1']) == 1 + # entries sharing a db_key are preserved even with different fdb_keys + assert len(res['db1']) == 2 def test_write_pdb_orders_by_solver(monkeypatch, tmp_path): From 8c00c1d591f9accb211a8a2c2326d3829689aff8 Mon Sep 17 00:00:00 2001 From: Reid Kawaja <74506315+reidkwja@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:22:35 -0600 Subject: [PATCH 4/6] fixes: import configs --- tests/test_import_configs_branches.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_import_configs_branches.py b/tests/test_import_configs_branches.py index a19bace3..eada76ab 100644 --- a/tests/test_import_configs_branches.py +++ b/tests/test_import_configs_branches.py @@ -147,7 +147,8 @@ def _process(driver, args, counts, dbt, logger): def test_add_model_integrity_error(monkeypatch): - fake_session = _FakeSession(raise_on_commit=IntegrityError('dup')) + fake_session = _FakeSession( + raise_on_commit=IntegrityError('stmt', 'params', Exception('dup'))) monkeypatch.setattr(import_configs, 'DbSession', lambda: fake_session) args = argparse.Namespace(add_model='m', md_version=1) logger = MagicMock() From 05b644f50f48a98a7b8519bb0facf61f99135870 Mon Sep 17 00:00:00 2001 From: Reid Kawaja <74506315+reidkwja@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:36:39 -0600 Subject: [PATCH 5/6] fixes: load job branches --- tests/test_load_job_branches.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/test_load_job_branches.py b/tests/test_load_job_branches.py index 0c302218..f0f6869b 100644 --- a/tests/test_load_job_branches.py +++ b/tests/test_load_job_branches.py @@ -83,6 +83,16 @@ def rollback(self): return None +class _FakeColumn: + """Column stub supporting in_ operator used in filters.""" + + def __init__(self, value=None): + self.value = value + + def in_(self, _): + return self + + def test_arg_solvers_uses_algo(monkeypatch): algo = next(iter(ALG_SLV_MAP.keys())) args = argparse.Namespace(solvers=None, algo=algo) @@ -102,10 +112,11 @@ def test_config_query_filters(monkeypatch): session = _FakeSession(result=[(1,)]) args = argparse.Namespace(tag='foo', cmd=cmd_key) dbt = SimpleNamespace( - config_table=SimpleNamespace(id=1, - valid=1, - input_t=SimpleNamespace(data_type=None)), - config_tags_table=SimpleNamespace(config=1, tag='foo'), + config_table=SimpleNamespace( + id=_FakeColumn(), + valid=1, + input_t=SimpleNamespace(data_type=_FakeColumn())), + config_tags_table=SimpleNamespace(config=_FakeColumn(), tag='foo'), ) query = load_job.config_query(args, session, dbt) assert isinstance(query, _FakeQuery) @@ -117,11 +128,12 @@ def test_compose_query_with_filters(monkeypatch): tunable=True, config_type=ConfigType.batch_norm, only_dynamic=True) - dbt = SimpleNamespace(solver_app=SimpleNamespace(config=1, - session=1, - solver=1), - config_table=SimpleNamespace(id=1), - job_table=MagicMock()) + dbt = SimpleNamespace(solver_app=SimpleNamespace(config=_FakeColumn(), + session=_FakeColumn(), + solver=_FakeColumn(), + applicable=_FakeColumn()), + config_table=SimpleNamespace(id=_FakeColumn()), + job_table=SimpleNamespace(__tablename__='job')) session = _FakeSession(result=[(1, 'solver')]) query = load_job.compose_query(args, session, dbt, _FakeQuery(result=[1])) assert isinstance(query, _FakeQuery) @@ -129,7 +141,7 @@ def test_compose_query_with_filters(monkeypatch): def test_add_jobs_empty_results(monkeypatch, caplog): logger = MagicMock() - dbt = SimpleNamespace(job_table=MagicMock()) + dbt = SimpleNamespace(job_table=SimpleNamespace(__tablename__='job')) args = argparse.Namespace(label='lbl', fin_steps=None, session_id=1, From 7c6debd8bd9ede0186e978bda4253978f1b84c2f Mon Sep 17 00:00:00 2001 From: Reid Kawaja <74506315+reidkwja@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:48:11 -0600 Subject: [PATCH 6/6] fixes: merge db --- tests/test_merge_db_branches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_merge_db_branches.py b/tests/test_merge_db_branches.py index c0c57873..2af444f7 100644 --- a/tests/test_merge_db_branches.py +++ b/tests/test_merge_db_branches.py @@ -55,7 +55,7 @@ def test_is_float(): def test_merge_text_file_copy_only(tmp_path): - master = tmp_path / "gfx900.HIP.fdb.txt" + master = tmp_path / "gfx90040.HIP.fdb.txt" # gfx900 with 0x40 (64 CU) target = tmp_path / "target.HIP.fdb.txt" master.write_text("a=1:1.0\n", encoding="utf-8") target.write_text("b=2:2.0\n", encoding="utf-8") @@ -103,7 +103,7 @@ def test_merge_sqlite_bin_cache(tmp_path): def test_get_file_list_filters(tmp_path): master_dir = tmp_path / "dir" master_dir.mkdir() - (master_dir / "gfx803_36.HIP.fdb.txt").write_text("", encoding="utf-8") + (master_dir / "gfx900_40.HIP.fdb.txt").write_text("", encoding="utf-8") (master_dir / "ignore.txt").write_text("", encoding="utf-8") args = merge_db.argparse.Namespace(master_file=str(master_dir), find_db=True,