diff --git a/Dockerfile b/Dockerfile index b043df81..9e7dd23e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG OSDB_BKC_VERSION= ARG HASVER=${ROCMVERSION:+$ROCMVERSION} ARG HASVER=${HASVER:-$OSDB_BKC_VERSION} -ARG BASEIMAGE=rocm/miopen:ci_3708da +ARG BASEIMAGE=rocm/miopen:ci_7c45f0 ARG UBUNTU=ubuntu:22.04 #use UBUNTU with rocm version set @@ -18,6 +18,8 @@ FROM $USEIMAGE as dtuna-ver-0 #args before from are wiped ARG ROCMVERSION= ARG OSDB_BKC_VERSION= +# pass through baseimage for later use +ARG BASEIMAGE RUN test -d /opt/rocm*; \ if [ $? -eq 0 ] ; then \ @@ -71,7 +73,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --all apt-utils \ build-essential \ cmake \ - clang-format-12 \ + clang-format \ curl \ doxygen \ gdb \ @@ -124,7 +126,9 @@ ARG MIOPEN_DIR=$ROCM_LIBS_DIR/projects/miopen RUN git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries.git $ROCM_LIBS_DIR WORKDIR $MIOPEN_DIR RUN git sparse-checkout set projects/miopen -ARG MIOPEN_BRANCH=4940cf3ec +# not sure what this commit is, using latest develop for now +ARG MIOPEN_BRANCH=5564e20238 +# ARG MIOPEN_BRANCH=develop RUN git pull && git checkout $MIOPEN_BRANCH ARG PREFIX=/opt/rocm @@ -209,3 +213,26 @@ RUN python3 setup.py install # reset WORKDIR to /tuna WORKDIR /tuna + +# save BASEIMAGE as env variable +ENV BASEIMAGE=${BASEIMAGE} + +# install mysql-server and mysql-client +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + mysql-server \ + mysql-client + +# install redis-server +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + redis-server + +# install RabbitMQ server +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + rabbitmq-server + +# install iproute2 +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + iproute2 + +# clean up apt cache +RUN apt-get clean && rm -rf /var/lib/apt/lists/* \ No newline at end of file diff --git a/tests/test_celery.py b/tests/test_celery.py index 1e4aa01d..4e2a20e8 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -51,6 +51,7 @@ from tuna.miopen.worker.fin_utils import compose_config_obj, fin_job from tuna.miopen.utils.lib_helper import get_worker + @pytest.mark.asyncio async def test_celery_workers(): miopen = MIOpen() diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index a20e1a23..ab55d23c 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -60,7 +60,8 @@ from tuna.miopen.db.triggers import drop_miopen_triggers, get_miopen_triggers from tuna.miopen.utils.config_type import ConfigType from tuna.miopen.db.tables import MIOpenDBTables -#from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue + +# from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue from tuna.miopen.utils.json_to_sql import process_fdb_w_kernels, process_tuning_data from tuna.miopen.utils.json_to_sql import process_pdb_compile from tuna.miopen.utils.json_to_sql import clean_cache_table @@ -88,130 +89,165 @@ def parse_args(self): # pylint: disable=too-many-statements """Function to parse arguments""" parser = setup_arg_parser( - 'Run Performance Tuning on a certain architecture', [ - TunaArgs.ARCH, TunaArgs.NUM_CU, TunaArgs.VERSION, - TunaArgs.CONFIG_TYPE, TunaArgs.SESSION_ID, TunaArgs.MACHINES, - TunaArgs.REMOTE_MACHINE, TunaArgs.LABEL, TunaArgs.RESTART_MACHINE, - TunaArgs.DOCKER_NAME, 
TunaArgs.SHUTDOWN_WORKERS, - TunaArgs.ENQUEUE_ONLY - ]) + "Run Performance Tuning on a certain architecture", + [ + TunaArgs.ARCH, + TunaArgs.NUM_CU, + TunaArgs.VERSION, + TunaArgs.CONFIG_TYPE, + TunaArgs.SESSION_ID, + TunaArgs.MACHINES, + TunaArgs.REMOTE_MACHINE, + TunaArgs.LABEL, + TunaArgs.RESTART_MACHINE, + TunaArgs.DOCKER_NAME, + TunaArgs.SHUTDOWN_WORKERS, + TunaArgs.ENQUEUE_ONLY, + ], + ) parser.add_argument( - '--find_mode', - dest='find_mode', + "--find_mode", + dest="find_mode", type=int, default=1, - help='Set the MIOPEN_FIND_MODE environment variable for MIOpen', - choices=['1', '3']) - parser.add_argument('--ticket', - dest='ticket', - type=str, - default=None, - help='Specify tuning ticket number') + help="Set the MIOPEN_FIND_MODE environment variable for MIOpen", + choices=["1", "3"], + ) + parser.add_argument( + "--ticket", + dest="ticket", + type=str, + default=None, + help="Specify tuning ticket number", + ) parser.add_argument( - '--solver_id', + "--solver_id", type=int, - dest='solver_id', + dest="solver_id", default=None, - help='Specify solver_id. Use --list_solvers to see options') - parser.add_argument('--dynamic_solvers_only', - dest='dynamic_solvers_only', - action='store_true', - default=False, - help='Only tune dynamic solvers.') + help="Specify solver_id. Use --list_solvers to see options", + ) + parser.add_argument( + "--dynamic_solvers_only", + dest="dynamic_solvers_only", + action="store_true", + default=False, + help="Only tune dynamic solvers.", + ) parser.add_argument( - '-B', - '--blacklist', - dest='blacklist', + "-B", + "--blacklist", + dest="blacklist", type=str, default=None, - help='MIOpen blacklist algorithm, if multiple then comma separate') - parser.add_argument('-i', - '--reset_interval', - type=int, - dest='reset_interval', - required=False, - help='Restart interval for job in hours.') + help="MIOpen blacklist algorithm, if multiple then comma separate", + ) parser.add_argument( - '--gpu_lim', - dest='gpu_lim', + "-i", + "--reset_interval", + type=int, + dest="reset_interval", + required=False, + help="Restart interval for job in hours.", + ) + parser.add_argument( + "--gpu_lim", + dest="gpu_lim", type=int, default=None, - help='Limit the number of gpu workers created by Tuna, index from 0') + help="Limit the number of gpu workers created by Tuna, index from 0", + ) parser.add_argument( - '-R', - '--rich_data', - dest='rich_data', - action='store_true', + "-R", + "--rich_data", + dest="rich_data", + action="store_true", default=False, - help='record intermediate parameter results from perf tuning') + help="record intermediate parameter results from perf tuning", + ) subcommands = parser.add_subcommands(required=False) - subcommands.add_subcommand('import_configs', + subcommands.add_subcommand("import_configs", get_import_cfg_parser(), required=False) - subcommands.add_subcommand('load_job', + subcommands.add_subcommand("load_job", get_load_job_parser(), required=False) - subcommands.add_subcommand('export_db', + subcommands.add_subcommand("export_db", get_export_db_parser(), required=False) - subcommands.add_subcommand('update_golden', + subcommands.add_subcommand("update_golden", get_update_golden_parser(), required=False) group = parser.add_mutually_exclusive_group() - group.add_argument('--add_tables', - dest='add_tables', - action='store_true', - help='Add MIOpen library specific tables') - - group.add_argument('--init_session', - action='store_true', - dest='init_session', - help='Set up a new tuning session.') group.add_argument( - 
'--fin_steps', + "--add_tables", + dest="add_tables", + action="store_true", + help="Add MIOpen library specific tables", + ) + + group.add_argument( + "--init_session", + action="store_true", + dest="init_session", + help="Set up a new tuning session.", + ) + group.add_argument( + "--fin_steps", type=str, - dest='fin_steps', - help='Specify fin steps. Multiple steps should be comma separated.') - group.add_argument('--list_solvers', - action='store_true', - dest='list_solvers', - help='List of solvers from the solver table') + dest="fin_steps", + help="Specify fin steps. Multiple steps should be comma separated.", + ) + group.add_argument( + "--list_solvers", + action="store_true", + dest="list_solvers", + help="List of solvers from the solver table", + ) # JD: implement the following two using fin_steps - group.add_argument('--update_solvers', - dest='update_solvers', - action='store_true', - help='Update the list of solvers in the database') - group.add_argument('--update_applicability', - dest='update_applicability', - action='store_true', - help='Update the applicability table in the database') - group.add_argument('-s', - '--status', - dest='check_status', - action='store_true', - default=False, - help='Check the status of machines') - - group.add_argument('-e', - '--exec', - dest='execute_cmd', - type=str, - default=None, - help='execute on each machine') + group.add_argument( + "--update_solvers", + dest="update_solvers", + action="store_true", + help="Update the list of solvers in the database", + ) + group.add_argument( + "--update_applicability", + dest="update_applicability", + action="store_true", + help="Update the applicability table in the database", + ) + group.add_argument( + "-s", + "--status", + dest="check_status", + action="store_true", + default=False, + help="Check the status of machines", + ) + + group.add_argument( + "-e", + "--exec", + dest="execute_cmd", + type=str, + default=None, + help="execute on each machine", + ) self.args = parser.parse_args() if self.args.config_type is None: self.args.config_type = ConfigType.convolution - #overwritte common lib args with subcommand args value + # overwritte common lib args with subcommand args value if self.args.subcommand is not None: self.overwrite_common_args() @@ -221,16 +257,16 @@ def parse_args(self): if self.args.list_solvers: print_solvers() - raise CustomError('Printing solvers...') + raise CustomError("Printing solvers...") - if self.args.fin_steps and self.args.subcommand != 'load_job': + if self.args.fin_steps and self.args.subcommand != "load_job": self.check_fin_args(parser) self.set_prefix() if self.args.find_mode is None and not (self.args.check_status or self.args.restart_machine or self.args.execute_cmd): - parser.error('find_mode must be specified for a tuning run') + parser.error("find_mode must be specified for a tuning run") if self.args.blacklist: self.check_blacklist(parser) @@ -238,8 +274,13 @@ def parse_args(self): args_check(self.args, parser) fin_session_steps = [ - 'miopen_find_compile', 'miopen_find_eval', 'miopen_perf_compile', - 'miopen_perf_eval', 'get_applicability', 'find_compile', 'find_eval' + "miopen_find_compile", + "miopen_find_eval", + "miopen_perf_compile", + "miopen_perf_eval", + "get_applicability", + "find_compile", + "find_eval", ] has_fin = False if self.args.fin_steps: @@ -255,14 +296,14 @@ def parse_args(self): def set_prefix(self): """Set redis key prefix""" if isinstance(self.args.fin_steps, Iterable): - steps_str = ('-').join(x for x in self.args.fin_steps) - 
self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_"\ - f"{steps_str}" + steps_str = ("-").join(x for x in self.args.fin_steps) + self.prefix = (f"d_{self.db_name}_sess_{self.args.session_id}_" + f"{steps_str}") else: steps_str = self.args.fin_steps[0] self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" - self.logger.info('redis prefix: %s', self.prefix) + self.logger.info("redis prefix: %s", self.prefix) def overwrite_common_args(self): """Overwrite common MIOpen_lib args with subcommand args""" @@ -274,12 +315,12 @@ def overwrite_common_args(self): def check_fin_args(self, parser): """! Helper function for fin args - @param parser The command line argument parser + @param parser The command line argument parser """ valid_fin_steps = list(k for k in FinStep.__members__) - if ',' in self.args.fin_steps: - parser.error('Multiple fin_steps currently not supported') - f_steps = self.args.fin_steps.split(',') + if "," in self.args.fin_steps: + parser.error("Multiple fin_steps currently not supported") + f_steps = self.args.fin_steps.split(",") self.args.fin_steps = f_steps for step in self.args.fin_steps: if step not in valid_fin_steps: @@ -288,40 +329,40 @@ def check_fin_args(self, parser): def check_blacklist(self, parser): """! Helper function - @param parser The command line argument parser - @return ret Boolean value - """ - self.args.blacklist = self.args.blacklist.split(',') + @param parser The command line argument parser + @return ret Boolean value + """ + self.args.blacklist = self.args.blacklist.split(",") for sol in self.args.blacklist: if sol not in MIOPEN_ALG_LIST: parser.error("Incorrect blacklist value") def do_fin_work(self, gpu, f_vals): """! Helper function to execute job independendent fin work - @param gpu Unique ID of the GPU - @param f_vals Dict containing runtime information - """ + @param gpu Unique ID of the GPU + @param f_vals Dict containing runtime information + """ kwargs = self.get_kwargs(gpu, f_vals) fin_worker = FinClass(**kwargs) if self.args.update_solvers: if not fin_worker.get_solvers(): - self.logger.error('No solvers returned from Fin class') + self.logger.error("No solvers returned from Fin class") return True def launch_worker(self, gpu_idx, f_vals, worker_lst): """! Function to launch worker - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - @param worker_lst List containing worker instances - @return ret Boolean value - """ + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + @param worker_lst List containing worker instances + @return ret Boolean value + """ # pylint: disable=too-many-branches worker = None kwargs = self.get_kwargs(gpu_idx, f_vals) if self.args.update_applicability: - kwargs['fin_steps'] = ['applicability'] + kwargs["fin_steps"] = ["applicability"] worker = FinClass(**kwargs) worker.start() worker_lst.append(worker) @@ -330,8 +371,13 @@ def launch_worker(self, gpu_idx, f_vals, worker_lst): worker = FinClass(**kwargs) ret = False if self.args.check_status: - if not super().check_status(worker, f_vals["b_first"], gpu_idx, - f_vals["machine"], self.args.docker_name): + if not super().check_status( + worker, + f_vals["b_first"], + gpu_idx, + f_vals["machine"], + self.args.docker_name, + ): ret = True elif self.args.init_session: Session().add_new_session(self.args, worker) @@ -345,8 +391,8 @@ def launch_worker(self, gpu_idx, f_vals, worker_lst): def compose_worker_list(self, machines): # pylint: disable=too-many-branches """! 
Helper function to compose worker_list - @param machines List of machines to execute on - """ + @param machines List of machines to execute on + """ worker_lst = [] fin_work_done = False for machine in machines: @@ -354,9 +400,9 @@ def compose_worker_list(self, machines): machine.restart_server(wait=False) continue - #fin_steps should only contain one step + # fin_steps should only contain one step worker_ids = None - if self.args.fin_steps and 'eval' in self.args.fin_steps[0]: + if self.args.fin_steps and "eval" in self.args.fin_steps[0]: worker_ids = machine.get_avail_gpus() if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): worker_ids = range(self.args.gpu_lim) @@ -366,7 +412,7 @@ def compose_worker_list(self, machines): if self.args.update_applicability: f_vals = super().get_f_vals(machine, [1]) kwargs = self.get_kwargs(0, f_vals) - kwargs['fin_steps'] = ['applicability'] + kwargs["fin_steps"] = ["applicability"] worker = FinClass(**kwargs) query = worker.query_cfgs(self.args.label) cfg_rows = query.all() @@ -388,7 +434,7 @@ def compose_worker_list(self, machines): break for gpu_idx in worker_ids: - self.logger.info('launch mid %u, proc %u', machine.id, gpu_idx) + self.logger.info("launch mid %u, proc %u", machine.id, gpu_idx) if not self.launch_worker(gpu_idx, f_vals, worker_lst): break @@ -396,10 +442,10 @@ def compose_worker_list(self, machines): def add_tables(self): """! Function to create new DB tables - @return Bool - """ + @return Bool + """ ret_t = create_tables(get_miopen_tables()) - self.logger.info('DB creation successful: %s', ret_t) + self.logger.info("DB creation successful: %s", ret_t) recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) return True @@ -414,19 +460,20 @@ def run(self): self.add_tables() return None - if self.args.subcommand is not None and self.args.subcommand == 'import_configs': + if (self.args.subcommand is not None and + self.args.subcommand == "import_configs"): run_import_configs(self.args.import_configs, self.logger) return None - if self.args.subcommand is not None and self.args.subcommand == 'load_job': + if self.args.subcommand is not None and self.args.subcommand == "load_job": run_load_job(self.args.load_job, self.logger) return None - if self.args.subcommand is not None and self.args.subcommand == 'export_db': + if self.args.subcommand is not None and self.args.subcommand == "export_db": run_export_db(self.args.export_db, self.logger) return None - if self.args.subcommand is not None and self.args.subcommand == 'update_golden': + if self.args.subcommand is not None and self.args.subcommand == "update_golden": run_update_golden(self.args.update_golden, self.logger) return None @@ -435,8 +482,7 @@ def run(self): return res def get_envmt(self): - """! Function to construct environment var - """ + """! Function to construct environment var""" envmt = ["MIOPEN_LOG_LEVEL=4"] envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") @@ -447,58 +493,66 @@ def get_envmt(self): if self.args.blacklist: bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) - for bk_var in bk_str.split(','): + for bk_var in bk_str.split(","): envmt.append(bk_var) return envmt def get_kwargs(self, gpu_idx, f_vals, tuning=False): """! 
Helper function to set up kwargs for worker instances - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - @param tuning Boolean that indicates if kwargs are for a tuning step - @return kwargs Dictionary - """ + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + @param tuning Boolean that indicates if kwargs are for a tuning step + @return kwargs Dictionary + """ kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) - kwargs['fin_steps'] = self.args.fin_steps - kwargs['dynamic_solvers_only'] = self.args.dynamic_solvers_only - kwargs['config_type'] = self.args.config_type - kwargs['reset_interval'] = self.args.reset_interval + kwargs["fin_steps"] = self.args.fin_steps + kwargs["dynamic_solvers_only"] = self.args.dynamic_solvers_only + kwargs["config_type"] = self.args.config_type + kwargs["reset_interval"] = self.args.reset_interval return kwargs def get_job_list(self, session, find_state, claim_num): """! Get list of jobs - @param session DB session - @param find_state DB job state - @param claim_num Number of DB jobs to pick up - @return List of DB jobs + @param session DB session + @param find_state DB job state + @param claim_num Number of DB jobs to pick up + @return List of DB jobs - """ - job_list = self.get_job_objs(session, find_state, self.args.label, self.dbt, - self.get_job_attr(), claim_num, - self.args.fin_steps) + """ + job_list = self.get_job_objs( + session, + find_state, + self.args.label, + self.dbt, + self.get_job_attr(), + claim_num, + self.args.fin_steps, + ) return job_list - def get_job_objs(self, - session: DbSession, - find_state: list, - label: str, - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None) -> List[SimpleDict]: + def get_job_objs( + self, + session: DbSession, + find_state: list, + label: str, + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: """! 
Get list of job objects - @param session DB session - @param find_state DB job state - @param label DB job reason - @param dbt Class representing all DB tables associated with this class - @param job_attr List of DB job columns - @param claim_num Number of DB jobs to pick up - @param fin_steps List of MIFin steps - @return List of DB jobs - """ + @param session DB session + @param find_state DB job state + @param label DB job reason + @param dbt Class representing all DB tables associated with this class + @param job_attr List of DB job columns + @param claim_num Number of DB jobs to pick up + @param fin_steps List of MIFin steps + @return List of DB jobs + """ entries: List[Tuple[SimpleDict, ...]] conds: List[str] = [f"session={dbt.session.id}", "valid=1"] @@ -506,38 +560,42 @@ def get_job_objs(self, conds.append(f"reason='{label}'") conds.append(f"retries<{self.max_job_retries}") - conds.append("state in (" + str(find_state).strip('{').strip('}') + ")") + conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") entries = self.compose_work_objs(session, conds, dbt, job_attr, claim_num, fin_steps) return entries - def compose_work_objs(self, - session: DbSession, - conds: List[str], - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None) -> List[SimpleDict]: + def compose_work_objs( + self, + session: DbSession, + conds: List[str], + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: """! Query a job list for update - @param session DB session - @param conds List of conditions for DB job WHERE clause - @param dbt Class representing all DB tables associated with this class - @param job_attr List of DB job columns - @param fin_steps List of MIFin steps - @return List of MIFin work objects - """ + @param session DB session + @param conds List of conditions for DB job WHERE clause + @param dbt Class representing all DB tables associated with this class + @param job_attr List of DB job columns + @param fin_steps List of MIFin steps + @return List of MIFin work objects + """ job_entries = [] if fin_steps: conds.append(f"fin_step like '%{fin_steps[0]}%'") else: conds.append("fin_step='not_fin'") - cond_str = ' AND '.join(conds) + cond_str = " AND ".join(conds) if cond_str: cond_str = f"WHERE {cond_str}" if claim_num: - cond_str += f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" + cond_str += ( + f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" + ) else: cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" @@ -549,23 +607,23 @@ def compose_work_objs(self, def compose_work_objs_fin(self, session, job_entries, dbt) -> List[Tuple[SimpleDict, SimpleDict]]: """! 
Return jobs for fin work - @param session DB session - @param job_entries List of DB jobs - @param dbt Class representing all DB tables associated with this class - @return ret Job tuple - """ + @param session DB session + @param job_entries List of DB jobs + @param dbt Class representing all DB tables associated with this class + @return ret Job tuple + """ ret = [] cfg_rel = { key: { - 'key': list(val.local_columns)[0].name, - 'ftble': str(list(val.remote_side)[0]).split('.', maxsplit=1)[0], - 'fkey': str(list(val.remote_side)[0]).split('.')[1] + "key": list(val.local_columns)[0].name, + "ftble": str(list(val.remote_side)[0]).split(".", maxsplit=1)[0], + "fkey": str(list(val.remote_side)[0]).split(".")[1], } for key, val in inspect(dbt.config_table).relationships.items() } if job_entries: - id_str = ','.join({str(job.config) for job in job_entries}) + id_str = ",".join({str(job.config) for job in job_entries}) cfg_cond_str = f"where valid=1 and id in ({id_str})" cfg_attr = [column.name for column in inspect(dbt.config_table).c] cfg_entries = gen_select_objs(session, cfg_attr, @@ -583,30 +641,32 @@ def compose_work_objs_fin(self, session, job_entries, def attach_tensors(self, session, cfg_rel, cfg_entries): """! Attach tensor relationship information to config entries - @param session DB session - @param cfg_rel DB Config col value - @param cfg_entries List of DB Config entries - @return cfg_entries List of DB Config entries with attached tensors (foreign keys) + @param session DB session + @param cfg_rel DB Config col value + @param cfg_entries List of DB Config entries + @return cfg_entries List of DB Config entries with attached tensors (foreign keys) - """ + """ for key, val in cfg_rel.items(): rel_attr = [ column.name - for column in inspect(get_class_by_tablename(val['ftble'])).c + for column in inspect(get_class_by_tablename(val["ftble"])).c ] - val['fattr'] = rel_attr + val["fattr"] = rel_attr for cfg in cfg_entries: for key, val in cfg_rel.items(): - rel_val = getattr(cfg, val['key']) + rel_val = getattr(cfg, val["key"]) rel_cond_str = f"where {val['fkey']}={rel_val}" setattr( - cfg, key, - gen_select_objs(session, val['fattr'], val['ftble'], - rel_cond_str)[0]) + cfg, + key, + gen_select_objs(session, val["fattr"], val["ftble"], + rel_cond_str)[0], + ) return cfg_entries - #deprecated + # deprecated def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]], job_attr: List[str]) -> List[SimpleDict]: """Find job tables in query results""" @@ -626,15 +686,16 @@ def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]], def update_operation(self): """! 
Update the workers type that this library needs""" if self.args.fin_steps: - if 'miopen_find_compile' in self.args.fin_steps \ - or 'miopen_perf_compile' in self.args.fin_steps: - self.fetch_state.add('new') - self.set_state = 'compile_start' + if ("miopen_find_compile" in self.args.fin_steps or + "miopen_perf_compile" in self.args.fin_steps): + self.fetch_state.add("new") + self.set_state = "compile_start" self.operation = Operation.COMPILE - elif 'miopen_find_eval' in self.args.fin_steps or 'miopen_perf_eval' in self.args.fin_steps: - self.fetch_state.add('new') - self.fetch_state.add('compiled') - self.set_state = 'eval_start' + elif ("miopen_find_eval" in self.args.fin_steps or + "miopen_perf_eval" in self.args.fin_steps): + self.fetch_state.add("new") + self.fetch_state.add("compiled") + self.set_state = "eval_start" self.operation = Operation.EVAL if self.args.update_applicability: @@ -642,8 +703,8 @@ def update_operation(self): def has_tunable_operation(self): """! Check if its a tuning loop operation - @return Bool value that represents if operation is tuning - """ + @return Bool value that represents if operation is tuning + """ if self.args is None: self.parse_args() if self.args.subcommand and "load_job" in self.args.subcommand: @@ -657,8 +718,8 @@ def has_tunable_operation(self): @lru_cache(1) def get_fdb_attr(self): """! Get find_db table attrs - @return fdb_attr find_db table attributes without timestamps - """ + @return fdb_attr find_db table attributes without timestamps + """ fdb_attr = None fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] fdb_attr.remove("insert_ts") @@ -668,8 +729,8 @@ def get_fdb_attr(self): @lru_cache(1) def get_tuning_data_attr(self): """! Get tuning_data table attrs - @return tuning_data_attr tuning_data table attributes without timestamps - """ + @return tuning_data_attr tuning_data table attributes without timestamps + """ tuning_data_attr = None tuning_data_attr = [ column.name for column in inspect(self.dbt.tuning_data_table).c @@ -680,10 +741,10 @@ def get_tuning_data_attr(self): def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): """! Return list of serialize jobs - @param session DB session - @param batch_jobs List of DB jobs - @return DB jobs, serialized - """ + @param session DB session + @param batch_jobs List of DB jobs + @return DB jobs, serialized + """ entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) return serialize_chunk(entries) @@ -696,15 +757,15 @@ def build_context( tuning_data_attr = self.get_tuning_data_attr() for job, config in serialized_jobs: context = { - 'job': job, - 'config': config, - 'operation': self.operation, - 'arch': self.dbt.session.arch, - 'num_cu': self.dbt.session.num_cu, - 'kwargs': kwargs, - 'rich_data': self.args.rich_data, - 'fdb_attr': fdb_attr, - 'tuning_data_attr': tuning_data_attr + "job": job, + "config": config, + "operation": self.operation, + "arch": self.dbt.session.arch, + "num_cu": self.dbt.session.num_cu, + "kwargs": kwargs, + "rich_data": self.args.rich_data, + "fdb_attr": fdb_attr, + "tuning_data_attr": tuning_data_attr, } context_list.append(context) @@ -712,48 +773,55 @@ def build_context( def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): """! 
Enqueue job (context) for queue:q_name - @param context Context for Celery job - @param q_name Custom Celery queue name - @param task_id Custom Redis Key - """ - - #hacky way to get the Q_NAME to the task decorator for interpreter to decorate the - #function with correct q_name arg - #if import is moved to top it will result in circular imports - Q_NAME = q_name #pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name - from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue #pylint: disable=import-outside-toplevel - - return celery_enqueue.apply_async((context,), - task_id=('-').join([self.prefix, - uuid()]), - queue=q_name, - reply_to=q_name) + @param context Context for Celery job + @param q_name Custom Celery queue name + @param task_id Custom Redis Key + """ + + # hacky way to get the Q_NAME to the task decorator for interpreter to decorate the + # function with correct q_name arg + # if import is moved to top it will result in circular imports + Q_NAME = q_name # pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name + from tuna.miopen.celery_tuning.celery_tasks import ( + celery_enqueue,) # pylint: disable=import-outside-toplevel + + return celery_enqueue.apply_async( + (context,), + task_id=("-").join([self.prefix, uuid()]), + queue=q_name, + reply_to=q_name, + ) def process_compile_results(self, session, fin_json, context): """! Process result from fin_build worker - @param session DB session - @param fin_json MIFin results for job - @param context Context for Celery job - @return Boolean value - """ - job = SimpleDict(**context['job']) + @param session DB session + @param fin_json MIFin results for job + @param context Context for Celery job + @return Boolean value + """ + job = SimpleDict(**context["job"]) pending = [] solver_id_map = get_solver_ids() failed_job = False - result_str = '' + result_str = "" status = None try: if fin_json: - if 'success' in fin_json and fin_json["success"] is False: + if "success" in fin_json and fin_json["success"] is False: status = [fin_json] else: - if 'miopen_find_compile_result' in fin_json: - status = process_fdb_w_kernels(session, fin_json, - copy.deepcopy(context), self.dbt, - context['fdb_attr'], pending) - - elif 'miopen_perf_compile_result' in fin_json: + if "miopen_find_compile_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + ) + + elif "miopen_perf_compile_result" in fin_json: status = process_pdb_compile(session, fin_json, job, self.dbt, solver_id_map) @@ -761,22 +829,22 @@ def process_compile_results(self, session, fin_json, context): failed_job = not success except (OperationalError, IntegrityError) as err: - self.logger.warning('FinBuild: Unable to update Database %s', err) + self.logger.warning("FinBuild: Unable to update Database %s", err) session.rollback() failed_job = True except DataError as err: self.logger.warning( - 'FinBuild: Invalid data, likely large workspace. DB Error: %s', err) + "FinBuild: Invalid data, likely large workspace. 
DB Error: %s", err) session.rollback() failed_job = True if failed_job: - set_job_state(session, job, self.dbt, 'errored', False, result=result_str) + set_job_state(session, job, self.dbt, "errored", False, result=result_str) else: set_job_state(session, job, self.dbt, - 'compiled', + "compiled", False, result=result_str) @@ -784,73 +852,89 @@ def process_compile_results(self, session, fin_json, context): def process_eval_results(self, session, fin_json, context): """! Process fin_json result - @param session DB session - @param fin_json MIFin results for job - @param context Context for Celery job - @return Boolean value - """ - job = SimpleDict(**context['job']) + @param session DB session + @param fin_json MIFin results for job + @param context Context for Celery job + @return Boolean value + """ + job = SimpleDict(**context["job"]) failed_job = True - result_str = '' + result_str = "" pending = [] - orig_state = 'compiled' + orig_state = "compiled" try: if fin_json: - if 'success' in fin_json and fin_json["success"] is False: + if "success" in fin_json and fin_json["success"] is False: status = [fin_json] else: - if 'miopen_find_eval_result' in fin_json: - status = process_fdb_w_kernels(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['fdb_attr'], - pending, - result_str='miopen_find_eval_result', - check_str='evaluated') - elif 'miopen_perf_eval_result' in fin_json: - status = process_fdb_w_kernels(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['fdb_attr'], - pending, - result_str='miopen_perf_eval_result', - check_str='evaluated') + if "miopen_find_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_find_eval_result", + check_str="evaluated", + ) + elif "miopen_perf_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) if context["rich_data"]: - status = process_tuning_data(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['tuning_data_attr'], - pending, - result_str='miopen_perf_eval_result', - check_str='evaluated') + status = process_tuning_data( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["tuning_data_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) success, result_str = get_fin_result(status) failed_job = not success if failed_job: - if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): #pylint: disable=no-member - self.logger.warning('max job retries exhausted, setting to errored') - set_job_state(session, job, self.dbt, 'errored', result=result_str) + if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): # pylint: disable=no-member + self.logger.warning("max job retries exhausted, setting to errored") + set_job_state(session, job, self.dbt, "errored", result=result_str) else: - self.logger.warning('resetting job state to %s, incrementing retries', + self.logger.warning("resetting job state to %s, incrementing retries", orig_state) - set_job_state(session, - job, - self.dbt, - orig_state, - increment_retries=True, - result=result_str) + set_job_state( + session, + job, + self.dbt, + orig_state, + increment_retries=True, + result=result_str, + ) else: self.logger.info("\n\n Setting job state to evaluated") - set_job_state(session, job, self.dbt, 'evaluated', 
result=result_str) + set_job_state(session, job, self.dbt, "evaluated", result=result_str) clean_cache_table(self.dbt, job) except (OperationalError, IntegrityError) as err: - self.logger.warning('FinBuild: Unable to update Database %s', err) + self.logger.warning("FinBuild: Unable to update Database %s", err) session.rollback() - set_job_state(session, job, self.dbt, 'errored', result=result_str) + set_job_state(session, job, self.dbt, "errored", result=result_str) return True + + def extract_job_id_from_context(self, context): + """Extract job ID from MIOpen celery task context""" + try: + # Extract job ID from the job context + return context.get("job", {}).get("id") + except (AttributeError, KeyError): + return None diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 140d1542..f383a106 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -61,9 +61,9 @@ job_counter_lock = threading.Lock() -class MITunaInterface(): #pylint:disable=too-many-instance-attributes,too-many-public-methods - """ Interface class extended by libraries. The purpose of this class is to define - common functionalities. """ +class MITunaInterface: # pylint:disable=too-many-instance-attributes,too-many-public-methods + """Interface class extended by libraries. The purpose of this class is to define + common functionalities.""" def __init__(self, library=Library.MIOPEN) -> None: @@ -77,16 +77,22 @@ def __init__(self, library=Library.MIOPEN) -> None: self.max_job_retries = 10 self.dbt = None self.operation = None - self.db_name = os.environ['TUNA_DB_NAME'] + self.db_name = os.environ["TUNA_DB_NAME"] self.prefix = None + # Track jobs claimed by this specific instance when in distributor mode + self.claimed_job_ids = set() + self.completed_job_ids = set() + # if less than 25% of the jobs are remaining, we can grab more jobs + self.progress_factor = 0.25 + def check_docker(self, worker: WorkerInterface, dockername="miopentuna") -> bool: """! Checking for docker - @param worker The worker interface instance - @param dockername The name of the docker - """ + @param worker The worker interface instance + @param dockername The name of the docker + """ out2: ChannelFile _, out2, _ = worker.exec_command("sudo docker info") while not out2.channel.exit_status_ready(): @@ -102,34 +108,44 @@ def check_docker(self, for line in out.readlines(): if line is not None: if line.find(dockername) != -1: - self.logger.warning('%s docker image exists', dockername) + self.logger.warning("%s docker image exists", dockername) return True if line is None: - self.logger.warning('%s docker image does not exist', dockername) + self.logger.warning("%s docker image does not exist", dockername) return False return False - def check_status(self, - worker: WorkerInterface, - b_first: int, - gpu_idx: int, - machine: Machine, - dockername: str = "miopentuna") -> bool: + def check_status( + self, + worker: WorkerInterface, + b_first: int, + gpu_idx: int, + machine: Machine, + dockername: str = "miopentuna", + ) -> bool: """! 
Function to check gpu_status - @param worker The worker interface instance - @param b_first Flag to keep track of visited GPU - @param gpu_idx Unique ID of the GPU - @param machine The machine instance - @param dockername The name of the docker - """ + @param worker The worker interface instance + @param b_first Flag to keep track of visited GPU + @param gpu_idx Unique ID of the GPU + @param machine The machine instance + @param dockername The name of the docker + """ if machine.chk_gpu_status(worker.gpu_id): - self.logger.info('Machine: (%s, %u) GPU_ID: %u OK', machine.hostname, - machine.port, gpu_idx) + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u OK", + machine.hostname, + machine.port, + gpu_idx, + ) else: - self.logger.info('Machine: (%s, %u) GPU_ID: %u ERROR', machine.hostname, - machine.port, gpu_idx) + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u ERROR", + machine.hostname, + machine.port, + gpu_idx, + ) if not b_first: return False @@ -146,10 +162,10 @@ def check_status(self, for line in out.readlines(): if line is not None: if line.find(dockername) != -1: - self.logger.warning('%s docker image exists', dockername) + self.logger.warning("%s docker image exists", dockername) break else: - self.logger.warning('%s docker image does not exist', dockername) + self.logger.warning("%s docker image does not exist", dockername) return True @@ -163,16 +179,16 @@ def get_num_procs(self, machine: Machine) -> List: num_procs: int env: Dict[str, Any] env = get_env_vars() - if env['slurm_cpus'] > 0: - num_procs = int(env['slurm_cpus']) + if env["slurm_cpus"] > 0: + num_procs = int(env["slurm_cpus"]) else: - num_procs = int(machine.get_num_cpus() * .6) + num_procs = int(machine.get_num_cpus() * 0.6) worker_ids = list(range(num_procs)) if len(worker_ids) == 0: - self.logger.error('num_procs must be bigger than zero to launch worker') - self.logger.error('Cannot launch worker on machine: %s', machine.id) + self.logger.error("num_procs must be bigger than zero to launch worker") + self.logger.error("Cannot launch worker on machine: %s", machine.id) worker_ids = [] return worker_ids @@ -181,14 +197,14 @@ def get_f_vals(self, machine: Machine, worker_ids: range, tuning=False) -> Dict[str, Any]: - #pylint:disable=unused-argument + # pylint:disable=unused-argument """Determine kwargs for worker_interface""" f_vals: Dict[str, Any] f_vals = self.compose_f_vals(machine) - f_vals['envmt'] = self.get_envmt() + f_vals["envmt"] = self.get_envmt() if not tuning: - f_vals["num_procs"] = Value('i', len(worker_ids)) + f_vals["num_procs"] = Value("i", len(worker_ids)) return f_vals @@ -198,20 +214,20 @@ def get_envmt(self): def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: """! 
Compose dict for WorkerInterface constructor - @param args The command line arguments - @param machine Machine instance - """ + @param args The command line arguments + @param machine Machine instance + """ f_vals: Dict[str, Any] = {} f_vals["b_first"] = True - #adding non-serializable obj when not running through celery + # adding non-serializable obj when not running through celery if not tuning: f_vals["machine"] = machine f_vals["bar_lock"] = Lock() - #multiprocess queue for jobs, shared on machine + # multiprocess queue for jobs, shared on machine f_vals["job_queue"] = mpQueue() f_vals["job_queue_lock"] = Lock() - f_vals["end_jobs"] = Value('i', 0) + f_vals["end_jobs"] = Value("i", 0) return f_vals @@ -220,21 +236,21 @@ def get_kwargs(self, f_vals: Dict[str, Any], tuning=False) -> Dict[str, Any]: """! Helper function to set up kwargs for worker instances - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - """ + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + """ envmt: Dict[str, Any] = f_vals["envmt"].copy() kwargs: Dict[str, Any] = {} kwargs = { - 'gpu_id': gpu_idx, - 'envmt': envmt, - 'label': self.args.label, - 'docker_name': self.args.docker_name, - 'session_id': self.args.session_id + "gpu_id": gpu_idx, + "envmt": envmt, + "label": self.args.label, + "docker_name": self.args.docker_name, + "session_id": self.args.session_id, } - #adding non-serializable obj when not running through celery + # adding non-serializable obj when not running through celery if not tuning: kwargs["machine"] = f_vals["machine"] kwargs["job_queue"] = f_vals["job_queue"] @@ -251,19 +267,21 @@ def get_job_list(self, session, find_state, claim_num): """Get list of jobs""" raise NotImplementedError("Not implemented") - def get_jobs(self, - session: DbSession, - find_state: List[str], - set_state: str, - session_id: int, - claim_num: int = None, - no_update=False): + def get_jobs( + self, + session: DbSession, + find_state: List[str], + set_state: str, + session_id: int, + claim_num: int = None, + no_update=False, + ): """Interface function to get jobs based on session and find_state""" - #job_rows: List[SimpleDict] + # job_rows: List[SimpleDict] ids: list row: SimpleDict - self.logger.info('Fetching DB rows...') + self.logger.info("Fetching DB rows...") job_list = self.get_job_list(session, find_state, claim_num) if not self.check_jobs_found(job_list, find_state, session_id): @@ -274,16 +292,24 @@ def get_jobs(self, ids = [row.id for row in job_list] self.logger.info("%s jobs %s", find_state, ids) - self.logger.info('Updating job state to %s', set_state) - for job in job_list: - job.state = set_state - if self.dbt is not None: - query: str = gen_update_query(job, ['state'], - self.dbt.job_table.__tablename__) - else: - raise CustomError('DBTable must be set') + self.logger.info("Updating job state to %s", set_state) + + # OPTIMIZATION: Use bulk UPDATE instead of individual updates + if self.dbt is not None: + id_str = ','.join(map(str, ids)) + query = f""" + UPDATE {self.dbt.job_table.__tablename__} + SET state = '{set_state}' + WHERE id IN ({id_str}) + """ session.execute(query) + # Update local objects to reflect new state + for job in job_list: + job.state = set_state + else: + raise CustomError("DBTable must be set") + session.commit() return job_list @@ -295,68 +321,171 @@ def shutdown_workers(self): def cancel_consumer(self, queue): """Cancel consumers for queue""" try: - cmd = f"celery -A tuna.celery_app.celery_app 
control cancel_consumer {queue}" - subp = subprocess.Popen( #pylint: disable=consider-using-with + cmd = ( + f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" + ) + subp = subprocess.Popen( # pylint: disable=consider-using-with cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True, - universal_newlines=True) + universal_newlines=True, + ) - #filter the workers by session id - sess_str = "sess_" + queue.split('_')[-1] + # filter the workers by session id + sess_str = "sess_" + queue.split("_")[-1] stdout, _ = subp.stdout, subp.stderr while True: line = stdout.readline() if not line: break - #stop workers that were feeding from this queue + # stop workers that were feeding from this queue if "->" in line and sess_str in line: - hostname = line.split('->')[1].split()[0].split(':')[0] + hostname = line.split("->")[1].split()[0].split(":")[0] stop_named_worker(hostname) - except Exception as exp: #pylint: disable=broad-exception-caught + except Exception as exp: # pylint: disable=broad-exception-caught self.logger.warning( - 'Error occurred trying to cancel consumer for queue: %s ', queue) + "Error occurred trying to cancel consumer for queue: %s ", queue) self.logger.warning(exp) return False - self.logger.info('Sucessfully cancelled consumer for queue: %s', queue) + self.logger.info("Sucessfully cancelled consumer for queue: %s", queue) return True def celery_enqueue_call(self, context, q_name, task_id=False): """Wrapper function for celery enqueue func""" - raise NotImplementedError('Not implemented') + raise NotImplementedError("Not implemented") def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs""" - self.logger.info('Starting enqueue') - with DbSession() as session: - while True: - job_list = [] - #get all the jobs from mySQL - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, #pylint: disable=no-member - self.args.session_id, #pylint: disable=no-member - job_batch_size) - - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) - - for i in range(0, len(job_list), job_batch_size): - batch_jobs = job_list[i:min(i + job_batch_size, len(job_list))] - context_list = self.get_context_list(session, batch_jobs) - for context in context_list: - #calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - - self.logger.info('Job counter: %s', job_counter.value) - if not job_list: - self.logger.info('All tasks added to queue') - break + """Enqueue celery jobs with machine-specific progress tracking and error handling""" + self.logger.info("Starting enqueue") + current_batch_size = 0 + + max_retries = 3 + retry_delay = 5 # seconds + consecutive_empty_fetches = 0 + max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) + + while True: + # Retry loop for database operations + for attempt in range(max_retries): + try: + with DbSession() as session: + # Check if we should enqueue more jobs based on OUR progress + if current_batch_size > 0: + if not self.should_enqueue_more_jobs(session, current_batch_size): + self.logger.info( + "Waiting for our current batch to progress before enqueuing more" + ) + break # Exit retry loop, will wait and check again + + # Get jobs from database + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # pylint: disable=no-member + self.args.session_id, # pylint: disable=no-member + job_batch_size, + ) + + if not job_list: + consecutive_empty_fetches += 1 + self.logger.info('No jobs 
found (attempt %d/%d)', + consecutive_empty_fetches, max_empty_fetches) + + if consecutive_empty_fetches >= max_empty_fetches: + self.logger.info( + 'No new jobs after %d attempts. Exiting enqueue loop.', + max_empty_fetches) + return # Exit gracefully + + time.sleep(60) # Wait before next check + break # Break retry loop, continue main loop + + # Reset counter when jobs are found + consecutive_empty_fetches = 0 + + # Track the jobs we just claimed + new_job_ids = {job.id for job in job_list} + self.claimed_job_ids.update(new_job_ids) + + self.logger.info("Claimed jobs: %s", list(new_job_ids)) + + with job_counter_lock: + job_counter.value = job_counter.value + len(job_list) + + # Process all jobs in this batch + context_list = self.get_context_list(session, job_list) + for context in context_list: + try: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + except Exception as enqueue_err: # pylint: disable=broad-exception-caught + self.logger.error('Failed to enqueue job: %s', enqueue_err) + # Continue with other jobs rather than failing completely + continue + + current_batch_size = len(job_list) + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + current_batch_size, + ) + + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + break # Success, break retry loop + + except Exception as db_err: # pylint: disable=broad-exception-caught + self.logger.warning('Database error on attempt %d/%d: %s', + attempt + 1, max_retries, db_err) + if attempt < max_retries - 1: + time.sleep(retry_delay * (attempt + 1)) # Exponential backoff + else: + self.logger.error( + 'Max retries exceeded for database operation. Exiting.') + raise + + # If we broke out because we're waiting for progress, sleep before next check + if current_batch_size > 0 and (not job_list or not self.should_enqueue_more_jobs(None, current_batch_size)): + self.logger.info("Sleeping 60s before checking for more jobs...") + time.sleep(60) + continue + + # If we got here with no jobs, the consecutive_empty_fetches logic handled it + if not job_list: + continue + + def should_enqueue_more_jobs(self, session, current_batch_size): + """Check if we should enqueue more jobs based on THIS instance's progress""" + # Count only jobs claimed by this machine instance + our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) + + # Allow enqueuing when less than 25% of our claimed jobs are still in progress + progress_threshold = current_batch_size * self.progress_factor + + self.logger.info( + "Our jobs in progress: %d, completed: %d, threshold: %d", + our_in_progress_count, + len(self.completed_job_ids), + progress_threshold, + ) + + return our_in_progress_count < progress_threshold + + def cleanup_completed_jobs(self): + """Periodically clean up old job tracking data""" + # Keep sets from growing indefinitely + max_tracking_size = 10000 + if len(self.completed_job_ids) > max_tracking_size: + # Keep only the most recent completions + recent_completions = list(self.completed_job_ids)[-5000:] + self.completed_job_ids = set(recent_completions) + + # Remove old claimed jobs that are completed + self.claimed_job_ids -= set(recent_completions[:-1000]) async def cleanup_redis_results(self, prefix): """Remove stale redis results by key""" @@ -366,25 +495,25 @@ async def cleanup_redis_results(self, prefix): keys = [] cursor = "0" if prefix: - #a prefix is necessary when the need to different results in redis based on operation - 
#withough a prefix the redis key defaults to: "celery-task-meta-" - #with a prefix the key will look like: "celery-task-meta--" - #the prefix can be applied when filtering the redis keys as bellow + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow cursor, results = await redis.scan(cursor, match=f"*{prefix}*") else: - #no prefix, match any key + # no prefix, match any key cursor, results = await redis.scan(cursor, match="*") keys.extend(results) - self.logger.info('Found %s old results', len(results)) + self.logger.info("Found %s old results", len(results)) for key in keys: try: await redis.delete(key) except aioredis.exceptions.ResponseError as red_err: self.logger.error(red_err) - self.logger.info(key.decode('utf-8')) + self.logger.info(key.decode("utf-8")) continue - self.logger.info('Done removing old redis results for prefix: %s', prefix) + self.logger.info("Done removing old redis results for prefix: %s", prefix) return True @@ -399,30 +528,30 @@ async def consume(self, job_counter, prefix): keys = [] while cursor != 0: if prefix: - #a prefix is necessary when the need to different results in redis based on operation - #withough a prefix the redis key defaults to: "celery-task-meta-" - #with a prefix the key will look like: "celery-task-meta--" - #the prefix can be applied when filtering the redis keys as bellow + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow cursor, results = await redis.scan(cursor, match=f"*{prefix}*") else: - #no prefix, match any key + # no prefix, match any key cursor, results = await redis.scan(cursor, match="*") keys.extend(results) - self.logger.info('Found %s results', len(results)) + self.logger.info("Found %s results", len(results)) for key in keys: try: data = await redis.get(key) if data: - _ = await self.parse_result(data.decode('utf-8')) + _ = await self.parse_result(data.decode("utf-8")) await redis.delete(key) with job_counter_lock: job_counter.value = job_counter.value - 1 except aioredis.exceptions.ResponseError as red_err: self.logger.error(red_err) - self.logger.info(key.decode('utf-8')) + self.logger.info(key.decode("utf-8")) await asyncio.sleep(1) - self.logger.info('Job counter reached 0') + self.logger.info("Job counter reached 0") await redis.close() return True @@ -434,33 +563,33 @@ def prep_tuning(self): q_name = None if self.operation == Operation.COMPILE: q_name = get_q_name(self, op_compile=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" #pylint: disable=line-too-long + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" # pylint: disable=line-too-long else: q_name = get_q_name(self, op_eval=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" #pylint: disable=line-too-long + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" # pylint: 
disable=line-too-long - self.logger.info('celery Q name: %s', q_name) + self.logger.info("celery Q name: %s", q_name) if not self.args.enqueue_only: try: - self.logger.info('Launching celery workers for queue %s', q_name) + self.logger.info("Launching celery workers for queue %s", q_name) subp_list = launch_celery_worker(self.operation, cmd, self.args, True) - self.logger.info('Done launching celery workers') + self.logger.info("Done launching celery workers") if not subp_list: - raise CustomError('Could not launch celery worker') + raise CustomError("Could not launch celery worker") except kombu.exceptions.OperationalError as k_err: - self.logger.error('Redis error ocurred: %s', k_err) + self.logger.error("Redis error ocurred: %s", k_err) return False else: purge_queue([q_name]) return q_name, subp_list - #pylint: disable=too-many-locals + # pylint: disable=too-many-locals def tune(self, job_batch_size=1000): """tuning loop to spin out celery tasks""" if self.args.shutdown_workers: - self.logger.info('Shutting down all celery workers') + self.logger.info("Shutting down all celery workers") stop_active_workers() return True @@ -471,7 +600,7 @@ def tune(self, job_batch_size=1000): return False try: - #if enqueue_only is False, we launch the celery workers + # if enqueue_only is False, we launch the celery workers if not self.args.enqueue_only: for subp in subp_list: subp.wait() @@ -483,43 +612,49 @@ def tune(self, job_batch_size=1000): start = time.time() - #set job count to 1 until first job fetch is finished - job_counter = Value('i', 1) + # set job count to 1 until first job fetch is finished + job_counter = Value("i", 1) try: enqueue_proc = Process(target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name]) - #Start enqueue proc + # Start enqueue proc enqueue_proc.start() - #cleanup old results + # cleanup old results cleanup_proc = Process(target=self.async_wrap, args=(self.cleanup_redis_results, self.prefix)) cleanup_proc.start() cleanup_proc.join() - #start async consume thread, blocking + # start async consume thread, blocking consume_proc = Process(target=self.async_wrap, args=(self.consume, job_counter, self.prefix)) - self.logger.info('Starting consume thread') + self.logger.info("Starting consume thread") consume_proc.start() enqueue_proc.join() - #enqueue finished first fetch, remove hold on job_counter + # enqueue finished first fetch, remove hold on job_counter with job_counter_lock: job_counter.value = job_counter.value - 1 - #check for new jobs + # Progress-aware polling - shorter intervals, smarter enqueuing + poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 5)) + + # check for new jobs while consume_proc.is_alive(): enqueue_proc = Process(target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name]) enqueue_proc.start() enqueue_proc.join() - time.sleep(10) + time.sleep(poll_interval) # Shorter, configurable polling consume_proc.join() - except (KeyboardInterrupt, Exception) as exp: #pylint: disable=broad-exception-caught - self.logger.error('Error ocurred %s', exp) + except ( + KeyboardInterrupt, + Exception, + ) as exp: # pylint: disable=broad-exception-caught + self.logger.error("Error ocurred %s", exp) purge_queue([q_name]) self.cancel_consumer(q_name) self.reset_job_state_on_ctrl_c() @@ -528,7 +663,7 @@ def tune(self, job_batch_size=1000): self.cancel_consumer(q_name) end = time.time() - self.logger.info("Took {:0>8} to tune".format( #pylint: disable=consider-using-f-string + self.logger.info("Took {:0>8} to tune".format( # pylint: 
disable=consider-using-f-string str(timedelta(seconds=end - start)))) return True @@ -542,38 +677,40 @@ def async_wrap(self, async_func, *args): try: asyncio.run(self.async_callback(async_func, *args)) except KeyboardInterrupt: - self.logger.warning('Keyboard interrupt caught, terminating') + self.logger.warning("Keyboard interrupt caught, terminating") def reset_job_state_on_ctrl_c(self): """Reset job state for jobs in flight""" temp_obj = SimpleDict() - temp_obj.session_id = self.args.session_id #pylint: disable=invalid-name - attribs = ['state'] + temp_obj.session_id = self.args.session_id # pylint: disable=invalid-name + attribs = ["state"] temp_obj.state = 1 - self.logger.info('Resetting job state in DB for in flight jobs') + self.logger.info("Resetting job state in DB for in flight jobs") if self.operation == Operation.COMPILE: state = 16 elif self.operation == Operation.EVAL: state = 12 - query = gen_update_query(temp_obj, attribs, - self.dbt.job_table.__tablename__, - [('session', self.args.session_id), - ('state', state)]) + query = gen_update_query( + temp_obj, + attribs, + self.dbt.job_table.__tablename__, + [("session", self.args.session_id), ("state", state)], + ) with DbSession() as session: - #pylint: disable=duplicate-code + # pylint: disable=duplicate-code def callback() -> bool: session.execute(query) session.commit() return True - #pylint: enable=duplicate-code + # pylint: enable=duplicate-code assert session_retry(session, callback, lambda x: x(), self.logger) - self.logger.info('Sucessfully reset job state') + self.logger.info("Sucessfully reset job state") return True return False @@ -598,7 +735,7 @@ def check_jobs_found(self, job_rows: List[SimpleDict], find_state: List[Any], """check for end of jobs""" if not job_rows: # we are done - self.logger.warning('No %s jobs found, session %s', find_state, + self.logger.warning("No %s jobs found, session %s", find_state, session_id) return False return True @@ -624,7 +761,7 @@ def get_context_list(self, session, batch_jobs): context_list: List[dict] = None serialized_jobs = self.serialize_jobs(session, batch_jobs) - #build context for each celery task + # build context for each celery task context_list = self.build_context(serialized_jobs) return context_list @@ -635,22 +772,35 @@ async def parse_result(self, data): with DbSession() as session: try: - fin_json = data['result']['ret'] - context = data['result']['context'] + fin_json = data["result"]["ret"] + context = data["result"]["context"] + + # Extract job ID from context to track completion + job_id = self.extract_job_id_from_context(context) + if job_id and job_id in self.claimed_job_ids: + self.completed_job_ids.add(job_id) + self.logger.info("Marked job %s as completed", job_id) + except KeyError as kerr: self.logger.error(kerr) return False - self.logger.info('Parsing: %s', fin_json) + self.logger.info("Parsing: %s", fin_json) if self.operation == Operation.COMPILE: self.process_compile_results(session, fin_json, context) elif self.operation == Operation.EVAL: self.process_eval_results(session, fin_json, context) else: - raise CustomError('Unsupported tuning operation') + raise CustomError("Unsupported tuning operation") return True + def extract_job_id_from_context(self, context): + """Extract job ID from celery task context""" + # This needs to be implemented in the MIOpen subclass + # based on how job IDs are stored in the context + raise NotImplementedError("Subclass must implement job ID extraction") + def process_compile_results(self, session, fin_json, 
context): """Process result from fin_build worker""" raise NotImplementedError("Not implemented")
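
In miopen_lib.py, celery_enqueue_call now builds each task_id from the redis prefix set in set_prefix() plus a fresh uuid, which is what lets cleanup_redis_results() and consume() in mituna_interface.py select only this run's results with a glob match on the key. A small sketch of how those strings fit together is below; the db name, session id and fin step are placeholder values, uuid4 stands in for the uuid helper the diff uses, and "celery-task-meta-" is the Celery redis result backend's default key prefix noted in the diff's own comments.

from uuid import uuid4

# Placeholder values standing in for TUNA_DB_NAME, --session_id and --fin_steps.
db_name = "tuna_db"
session_id = 42
fin_steps = ["miopen_perf_compile"]

# Mirrors set_prefix(): d_<db>_sess_<session>_<steps>
steps_str = "-".join(fin_steps)
prefix = f"d_{db_name}_sess_{session_id}_{steps_str}"

# Mirrors celery_enqueue_call(): the custom task_id is "<prefix>-<uuid>".
task_id = "-".join([prefix, str(uuid4())])

# The redis result backend stores each result under "celery-task-meta-<task_id>",
# so a scan with match=f"*{prefix}*" picks up only this session/operation's keys.
redis_key = f"celery-task-meta-{task_id}"
match_pattern = f"*{prefix}*"

print(redis_key)
print(match_pattern)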
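
get_jobs() in mituna_interface.py now claims a whole batch with a single UPDATE ... WHERE id IN (...) instead of issuing one gen_update_query per row. The sketch below reproduces the shape of that statement against an in-memory sqlite3 table purely as an illustration; the real code executes the statement through a SQLAlchemy DbSession against the library's job table, and the row selection it follows relies on FOR UPDATE SKIP LOCKED, which sqlite does not provide.

import sqlite3

# Toy stand-in for the job table; ids 1-3 play the role of the claimed batch.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE job (id INTEGER PRIMARY KEY, state TEXT)")
conn.executemany("INSERT INTO job (id, state) VALUES (?, ?)",
                 [(i, "new") for i in range(1, 6)])

claimed_ids = [1, 2, 3]          # ids returned by the batched SELECT
set_state = "compile_start"

# One statement for the whole batch, mirroring the bulk UPDATE in get_jobs().
id_str = ",".join(map(str, claimed_ids))
conn.execute(f"UPDATE job SET state = '{set_state}' WHERE id IN ({id_str})")
conn.commit()

print(conn.execute("SELECT id, state FROM job ORDER BY id").fetchall())
# [(1, 'compile_start'), (2, 'compile_start'), (3, 'compile_start'), (4, 'new'), (5, 'new')]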
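
The new per-instance bookkeeping in MITunaInterface (claimed_job_ids, completed_job_ids, progress_factor) gates further enqueuing on how much of this machine's own batch has finished rather than on global queue depth: should_enqueue_more_jobs() only allows another fetch once fewer than 25% of the jobs this instance claimed are still in flight, and parse_result() feeds the completion side through extract_job_id_from_context(). A minimal standalone sketch of that gating logic follows, with a short driver in place of the celery round trip.

# Sketch of the per-instance progress gating; the set bookkeeping and the
# threshold check follow the diff, the demo at the bottom is illustrative.
class ProgressTracker:

    def __init__(self, progress_factor=0.25):
        self.claimed_job_ids = set()    # jobs this instance has enqueued
        self.completed_job_ids = set()  # jobs whose results have been parsed
        self.progress_factor = progress_factor

    def claim(self, job_ids):
        self.claimed_job_ids.update(job_ids)

    def mark_completed(self, job_id):
        if job_id in self.claimed_job_ids:
            self.completed_job_ids.add(job_id)

    def should_enqueue_more_jobs(self, current_batch_size):
        # enqueue again only when fewer than progress_factor of the claimed
        # jobs are still in flight
        in_progress = len(self.claimed_job_ids - self.completed_job_ids)
        return in_progress < current_batch_size * self.progress_factor


tracker = ProgressTracker()
batch = list(range(100))
tracker.claim(batch)
for job_id in batch[:80]:                      # pretend 80 results came back
    tracker.mark_completed(job_id)
print(tracker.should_enqueue_more_jobs(len(batch)))   # True: 20 in flight < 25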
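
enqueue_jobs() also wraps its database work in a bounded retry loop whose delay grows with each attempt (retry_delay * (attempt + 1), a linear rather than strictly exponential backoff). A generic sketch of that pattern is below; do_db_work and RuntimeError are hypothetical stand-ins for the DbSession block and the broad exception the diff catches.

import time

def with_retries(do_db_work, max_retries=3, retry_delay=5):
    # Retry do_db_work up to max_retries times with a growing delay.
    for attempt in range(max_retries):
        try:
            return do_db_work()
        except RuntimeError:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (attempt + 1))   # 5s, 10s, ... between tries
            else:
                raise


attempts = {"n": 0}

def flaky():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient DB error")
    return "ok"

print(with_retries(flaky, retry_delay=0.01))   # prints "ok" on the third attempt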