From 27a67bf3166e9a5e8542c46490eb9bfa57ed5432 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Wed, 8 Oct 2025 11:11:30 -0400 Subject: [PATCH 01/10] Added changes/additions to Dockerfile --- Dockerfile | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index b043df81..cc493e68 100644 --- a/Dockerfile +++ b/Dockerfile @@ -124,7 +124,9 @@ ARG MIOPEN_DIR=$ROCM_LIBS_DIR/projects/miopen RUN git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries.git $ROCM_LIBS_DIR WORKDIR $MIOPEN_DIR RUN git sparse-checkout set projects/miopen -ARG MIOPEN_BRANCH=4940cf3ec +# not sure what this commit is, using latest develop for now +# ARG MIOPEN_BRANCH=4940cf3ec +ARG MIOPEN_BRANCH=develop RUN git pull && git checkout $MIOPEN_BRANCH ARG PREFIX=/opt/rocm @@ -209,3 +211,26 @@ RUN python3 setup.py install # reset WORKDIR to /tuna WORKDIR /tuna + +# save BASEIMAGE as env variable +ENV BASEIMAGE=${BASEIMAGE} + +# install mysql-server and mysql-client +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + mysql-server \ + mysql-client + +# install redis-server +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + redis-server + +# install RabbitMQ server +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + rabbitmq-server + +# install iproute2 +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + iproute2 + +# clean up apt cache +RUN apt-get clean && rm -rf /var/lib/apt/lists/* \ No newline at end of file From 11462b3e604c61c806fadacb6e6902aee0163db7 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Thu, 16 Oct 2025 08:51:42 +0000 Subject: [PATCH 02/10] auto format --- tuna/mituna_interface.py | 1196 +++++++++++++++++++------------------- 1 file changed, 611 insertions(+), 585 deletions(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 140d1542..345a915b 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -61,600 +61,626 @@ job_counter_lock = threading.Lock() -class MITunaInterface(): #pylint:disable=too-many-instance-attributes,too-many-public-methods - """ Interface class extended by libraries. The purpose of this class is to define - common functionalities. """ - - def __init__(self, library=Library.MIOPEN) -> None: - - self.self: Library = self - - self.logger: logging.Logger = setup_logger(logger_name=library.value, - add_streamhandler=True) - self.args: argparse.Namespace - - self.fetch_state: set = set() - self.max_job_retries = 10 - self.dbt = None - self.operation = None - self.db_name = os.environ['TUNA_DB_NAME'] - self.prefix = None - - def check_docker(self, - worker: WorkerInterface, - dockername="miopentuna") -> bool: - """! Checking for docker - @param worker The worker interface instance - @param dockername The name of the docker - """ - out2: ChannelFile - _, out2, _ = worker.exec_command("sudo docker info") - while not out2.channel.exit_status_ready(): - self.logger.warning(out2.readline()) - if out2.channel.exit_status > 0: - self.logger.warning( - "docker not installed or failed to run with sudo .... 
") - return False - - out: StringIO = StringIO() - line: Optional[str] = None - _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}") - for line in out.readlines(): - if line is not None: - if line.find(dockername) != -1: - self.logger.warning('%s docker image exists', dockername) - return True - if line is None: - self.logger.warning('%s docker image does not exist', dockername) - return False - - return False - - def check_status(self, - worker: WorkerInterface, - b_first: int, - gpu_idx: int, - machine: Machine, - dockername: str = "miopentuna") -> bool: - """! Function to check gpu_status - @param worker The worker interface instance - @param b_first Flag to keep track of visited GPU - @param gpu_idx Unique ID of the GPU - @param machine The machine instance - @param dockername The name of the docker - """ - - if machine.chk_gpu_status(worker.gpu_id): - self.logger.info('Machine: (%s, %u) GPU_ID: %u OK', machine.hostname, - machine.port, gpu_idx) - else: - self.logger.info('Machine: (%s, %u) GPU_ID: %u ERROR', machine.hostname, - machine.port, gpu_idx) - - if not b_first: - return False - b_first = False - _, out, _ = worker.exec_command("docker info") - while not out.channel.exit_status_ready(): - pass - - if out.channel.exit_status > 0: - self.check_docker(worker, dockername) - else: - _, out, _ = worker.exec_command(f"docker images | grep {dockername}") - line: Optional[str] = None - for line in out.readlines(): - if line is not None: - if line.find(dockername) != -1: - self.logger.warning('%s docker image exists', dockername) - break +class MITunaInterface: # pylint:disable=too-many-instance-attributes,too-many-public-methods + """Interface class extended by libraries. The purpose of this class is to define + common functionalities.""" + + def __init__(self, library=Library.MIOPEN) -> None: + + self.self: Library = self + + self.logger: logging.Logger = setup_logger( + logger_name=library.value, add_streamhandler=True + ) + self.args: argparse.Namespace + + self.fetch_state: set = set() + self.max_job_retries = 10 + self.dbt = None + self.operation = None + self.db_name = os.environ["TUNA_DB_NAME"] + self.prefix = None + + def check_docker(self, worker: WorkerInterface, dockername="miopentuna") -> bool: + """! Checking for docker + @param worker The worker interface instance + @param dockername The name of the docker + """ + out2: ChannelFile + _, out2, _ = worker.exec_command("sudo docker info") + while not out2.channel.exit_status_ready(): + self.logger.warning(out2.readline()) + if out2.channel.exit_status > 0: + self.logger.warning("docker not installed or failed to run with sudo .... ") + return False + + out: StringIO = StringIO() + line: Optional[str] = None + _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}") + for line in out.readlines(): + if line is not None: + if line.find(dockername) != -1: + self.logger.warning("%s docker image exists", dockername) + return True + if line is None: + self.logger.warning("%s docker image does not exist", dockername) + return False + + return False + + def check_status( + self, + worker: WorkerInterface, + b_first: int, + gpu_idx: int, + machine: Machine, + dockername: str = "miopentuna", + ) -> bool: + """! 
Function to check gpu_status + @param worker The worker interface instance + @param b_first Flag to keep track of visited GPU + @param gpu_idx Unique ID of the GPU + @param machine The machine instance + @param dockername The name of the docker + """ + + if machine.chk_gpu_status(worker.gpu_id): + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u OK", + machine.hostname, + machine.port, + gpu_idx, + ) + else: + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u ERROR", + machine.hostname, + machine.port, + gpu_idx, + ) + + if not b_first: + return False + b_first = False + _, out, _ = worker.exec_command("docker info") + while not out.channel.exit_status_ready(): + pass + + if out.channel.exit_status > 0: + self.check_docker(worker, dockername) + else: + _, out, _ = worker.exec_command(f"docker images | grep {dockername}") + line: Optional[str] = None + for line in out.readlines(): + if line is not None: + if line.find(dockername) != -1: + self.logger.warning("%s docker image exists", dockername) + break + else: + self.logger.warning("%s docker image does not exist", dockername) + + return True + + def add_tables(self) -> bool: + """Add self specific tables""" + return self.add_tables() + + def get_num_procs(self, machine: Machine) -> List: + """Determine number of processes by compute capacity""" + worker_ids: List = [] + num_procs: int + env: Dict[str, Any] + env = get_env_vars() + if env["slurm_cpus"] > 0: + num_procs = int(env["slurm_cpus"]) else: - self.logger.warning('%s docker image does not exist', dockername) - - return True - - def add_tables(self) -> bool: - """Add self specific tables""" - return self.add_tables() - - def get_num_procs(self, machine: Machine) -> List: - """Determine number of processes by compute capacity""" - worker_ids: List = [] - num_procs: int - env: Dict[str, Any] - env = get_env_vars() - if env['slurm_cpus'] > 0: - num_procs = int(env['slurm_cpus']) - else: - num_procs = int(machine.get_num_cpus() * .6) - - worker_ids = list(range(num_procs)) - - if len(worker_ids) == 0: - self.logger.error('num_procs must be bigger than zero to launch worker') - self.logger.error('Cannot launch worker on machine: %s', machine.id) - worker_ids = [] - - return worker_ids - - def get_f_vals(self, - machine: Machine, - worker_ids: range, - tuning=False) -> Dict[str, Any]: - #pylint:disable=unused-argument - """Determine kwargs for worker_interface""" - f_vals: Dict[str, Any] - f_vals = self.compose_f_vals(machine) - f_vals['envmt'] = self.get_envmt() - - if not tuning: - f_vals["num_procs"] = Value('i', len(worker_ids)) - - return f_vals - - def get_envmt(self): - """Get runtime envmt""" - raise NotImplementedError("Not implemented") - - def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: - """! Compose dict for WorkerInterface constructor - @param args The command line arguments - @param machine Machine instance - """ - f_vals: Dict[str, Any] = {} - f_vals["b_first"] = True - - #adding non-serializable obj when not running through celery - if not tuning: - f_vals["machine"] = machine - f_vals["bar_lock"] = Lock() - #multiprocess queue for jobs, shared on machine - f_vals["job_queue"] = mpQueue() - f_vals["job_queue_lock"] = Lock() - f_vals["end_jobs"] = Value('i', 0) - - return f_vals - - def get_kwargs(self, - gpu_idx: int, - f_vals: Dict[str, Any], - tuning=False) -> Dict[str, Any]: - """! 
Helper function to set up kwargs for worker instances - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - """ - envmt: Dict[str, Any] = f_vals["envmt"].copy() - kwargs: Dict[str, Any] = {} - - kwargs = { - 'gpu_id': gpu_idx, - 'envmt': envmt, - 'label': self.args.label, - 'docker_name': self.args.docker_name, - 'session_id': self.args.session_id - } - - #adding non-serializable obj when not running through celery - if not tuning: - kwargs["machine"] = f_vals["machine"] - kwargs["job_queue"] = f_vals["job_queue"] - kwargs["job_queue_lock"] = f_vals["job_queue_lock"] - kwargs["num_procs"] = f_vals["num_procs"] - kwargs["bar_lock"] = f_vals["bar_lock"] - kwargs["end_jobs"] = f_vals["end_jobs"] - kwargs["job_queue"] = f_vals["job_queue"] - kwargs["job_queue_lock"] = f_vals["job_queue_lock"] - - return kwargs - - def get_job_list(self, session, find_state, claim_num): - """Get list of jobs""" - raise NotImplementedError("Not implemented") - - def get_jobs(self, - session: DbSession, - find_state: List[str], - set_state: str, - session_id: int, - claim_num: int = None, - no_update=False): - """Interface function to get jobs based on session and find_state""" - #job_rows: List[SimpleDict] - ids: list - row: SimpleDict - - self.logger.info('Fetching DB rows...') - job_list = self.get_job_list(session, find_state, claim_num) - - if not self.check_jobs_found(job_list, find_state, session_id): - return [] - - if no_update: - return job_list - - ids = [row.id for row in job_list] - self.logger.info("%s jobs %s", find_state, ids) - self.logger.info('Updating job state to %s', set_state) - for job in job_list: - job.state = set_state - if self.dbt is not None: - query: str = gen_update_query(job, ['state'], - self.dbt.job_table.__tablename__) - else: - raise CustomError('DBTable must be set') - session.execute(query) - - session.commit() - - return job_list - - def shutdown_workers(self): - """Shutdown all active celery workers regardless of queue""" - return stop_active_workers() - - def cancel_consumer(self, queue): - """Cancel consumers for queue""" - try: - cmd = f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" - subp = subprocess.Popen( #pylint: disable=consider-using-with - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - shell=True, - universal_newlines=True) - - #filter the workers by session id - sess_str = "sess_" + queue.split('_')[-1] - stdout, _ = subp.stdout, subp.stderr - while True: - line = stdout.readline() - if not line: - break - #stop workers that were feeding from this queue - if "->" in line and sess_str in line: - hostname = line.split('->')[1].split()[0].split(':')[0] - stop_named_worker(hostname) - - except Exception as exp: #pylint: disable=broad-exception-caught - self.logger.warning( - 'Error occurred trying to cancel consumer for queue: %s ', queue) - self.logger.warning(exp) - return False - - self.logger.info('Sucessfully cancelled consumer for queue: %s', queue) - - return True - - def celery_enqueue_call(self, context, q_name, task_id=False): - """Wrapper function for celery enqueue func""" - raise NotImplementedError('Not implemented') - - def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs""" - self.logger.info('Starting enqueue') - with DbSession() as session: - while True: - job_list = [] - #get all the jobs from mySQL - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, #pylint: disable=no-member - self.args.session_id, #pylint: 
disable=no-member - job_batch_size) - - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) - - for i in range(0, len(job_list), job_batch_size): - batch_jobs = job_list[i:min(i + job_batch_size, len(job_list))] - context_list = self.get_context_list(session, batch_jobs) - for context in context_list: - #calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - - self.logger.info('Job counter: %s', job_counter.value) - if not job_list: - self.logger.info('All tasks added to queue') - break - - async def cleanup_redis_results(self, prefix): - """Remove stale redis results by key""" - backend_port, backend_host = get_backend_env() - redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") - - keys = [] - cursor = "0" - if prefix: - #a prefix is necessary when the need to different results in redis based on operation - #withough a prefix the redis key defaults to: "celery-task-meta-" - #with a prefix the key will look like: "celery-task-meta--" - #the prefix can be applied when filtering the redis keys as bellow - cursor, results = await redis.scan(cursor, match=f"*{prefix}*") - else: - #no prefix, match any key - cursor, results = await redis.scan(cursor, match="*") - keys.extend(results) - self.logger.info('Found %s old results', len(results)) - for key in keys: - try: - await redis.delete(key) - except aioredis.exceptions.ResponseError as red_err: - self.logger.error(red_err) - self.logger.info(key.decode('utf-8')) - continue - - self.logger.info('Done removing old redis results for prefix: %s', prefix) - - return True - - async def consume(self, job_counter, prefix): - """Retrieve celery results from redis db""" - - backend_port, backend_host = get_backend_env() - redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") - - while job_counter.value > 0: - cursor = "0" - keys = [] - while cursor != 0: + num_procs = int(machine.get_num_cpus() * 0.6) + + worker_ids = list(range(num_procs)) + + if len(worker_ids) == 0: + self.logger.error("num_procs must be bigger than zero to launch worker") + self.logger.error("Cannot launch worker on machine: %s", machine.id) + worker_ids = [] + + return worker_ids + + def get_f_vals( + self, machine: Machine, worker_ids: range, tuning=False + ) -> Dict[str, Any]: + # pylint:disable=unused-argument + """Determine kwargs for worker_interface""" + f_vals: Dict[str, Any] + f_vals = self.compose_f_vals(machine) + f_vals["envmt"] = self.get_envmt() + + if not tuning: + f_vals["num_procs"] = Value("i", len(worker_ids)) + + return f_vals + + def get_envmt(self): + """Get runtime envmt""" + raise NotImplementedError("Not implemented") + + def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: + """! Compose dict for WorkerInterface constructor + @param args The command line arguments + @param machine Machine instance + """ + f_vals: Dict[str, Any] = {} + f_vals["b_first"] = True + + # adding non-serializable obj when not running through celery + if not tuning: + f_vals["machine"] = machine + f_vals["bar_lock"] = Lock() + # multiprocess queue for jobs, shared on machine + f_vals["job_queue"] = mpQueue() + f_vals["job_queue_lock"] = Lock() + f_vals["end_jobs"] = Value("i", 0) + + return f_vals + + def get_kwargs( + self, gpu_idx: int, f_vals: Dict[str, Any], tuning=False + ) -> Dict[str, Any]: + """! 
Helper function to set up kwargs for worker instances + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + """ + envmt: Dict[str, Any] = f_vals["envmt"].copy() + kwargs: Dict[str, Any] = {} + + kwargs = { + "gpu_id": gpu_idx, + "envmt": envmt, + "label": self.args.label, + "docker_name": self.args.docker_name, + "session_id": self.args.session_id, + } + + # adding non-serializable obj when not running through celery + if not tuning: + kwargs["machine"] = f_vals["machine"] + kwargs["job_queue"] = f_vals["job_queue"] + kwargs["job_queue_lock"] = f_vals["job_queue_lock"] + kwargs["num_procs"] = f_vals["num_procs"] + kwargs["bar_lock"] = f_vals["bar_lock"] + kwargs["end_jobs"] = f_vals["end_jobs"] + kwargs["job_queue"] = f_vals["job_queue"] + kwargs["job_queue_lock"] = f_vals["job_queue_lock"] + + return kwargs + + def get_job_list(self, session, find_state, claim_num): + """Get list of jobs""" + raise NotImplementedError("Not implemented") + + def get_jobs( + self, + session: DbSession, + find_state: List[str], + set_state: str, + session_id: int, + claim_num: int = None, + no_update=False, + ): + """Interface function to get jobs based on session and find_state""" + # job_rows: List[SimpleDict] + ids: list + row: SimpleDict + + self.logger.info("Fetching DB rows...") + job_list = self.get_job_list(session, find_state, claim_num) + + if not self.check_jobs_found(job_list, find_state, session_id): + return [] + + if no_update: + return job_list + + ids = [row.id for row in job_list] + self.logger.info("%s jobs %s", find_state, ids) + self.logger.info("Updating job state to %s", set_state) + for job in job_list: + job.state = set_state + if self.dbt is not None: + query: str = gen_update_query( + job, ["state"], self.dbt.job_table.__tablename__ + ) + else: + raise CustomError("DBTable must be set") + session.execute(query) + + session.commit() + + return job_list + + def shutdown_workers(self): + """Shutdown all active celery workers regardless of queue""" + return stop_active_workers() + + def cancel_consumer(self, queue): + """Cancel consumers for queue""" + try: + cmd = ( + f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" + ) + subp = subprocess.Popen( # pylint: disable=consider-using-with + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + universal_newlines=True, + ) + + # filter the workers by session id + sess_str = "sess_" + queue.split("_")[-1] + stdout, _ = subp.stdout, subp.stderr + while True: + line = stdout.readline() + if not line: + break + # stop workers that were feeding from this queue + if "->" in line and sess_str in line: + hostname = line.split("->")[1].split()[0].split(":")[0] + stop_named_worker(hostname) + + except Exception as exp: # pylint: disable=broad-exception-caught + self.logger.warning( + "Error occurred trying to cancel consumer for queue: %s ", queue + ) + self.logger.warning(exp) + return False + + self.logger.info("Sucessfully cancelled consumer for queue: %s", queue) + + return True + + def celery_enqueue_call(self, context, q_name, task_id=False): + """Wrapper function for celery enqueue func""" + raise NotImplementedError("Not implemented") + + def enqueue_jobs(self, job_counter, job_batch_size, q_name): + """Enqueue celery jobs""" + self.logger.info("Starting enqueue") + with DbSession() as session: + while True: + job_list = [] + # get all the jobs from mySQL + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # pylint: disable=no-member + 
self.args.session_id, # pylint: disable=no-member + job_batch_size, + ) + + with job_counter_lock: + job_counter.value = job_counter.value + len(job_list) + + for i in range(0, len(job_list), job_batch_size): + batch_jobs = job_list[i : min(i + job_batch_size, len(job_list))] + context_list = self.get_context_list(session, batch_jobs) + for context in context_list: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + + self.logger.info("Job counter: %s", job_counter.value) + if not job_list: + self.logger.info("All tasks added to queue") + break + + async def cleanup_redis_results(self, prefix): + """Remove stale redis results by key""" + backend_port, backend_host = get_backend_env() + redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") + + keys = [] + cursor = "0" if prefix: - #a prefix is necessary when the need to different results in redis based on operation - #withough a prefix the redis key defaults to: "celery-task-meta-" - #with a prefix the key will look like: "celery-task-meta--" - #the prefix can be applied when filtering the redis keys as bellow - cursor, results = await redis.scan(cursor, match=f"*{prefix}*") + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow + cursor, results = await redis.scan(cursor, match=f"*{prefix}*") else: - #no prefix, match any key - cursor, results = await redis.scan(cursor, match="*") + # no prefix, match any key + cursor, results = await redis.scan(cursor, match="*") keys.extend(results) - self.logger.info('Found %s results', len(results)) - for key in keys: + self.logger.info("Found %s old results", len(results)) + for key in keys: + try: + await redis.delete(key) + except aioredis.exceptions.ResponseError as red_err: + self.logger.error(red_err) + self.logger.info(key.decode("utf-8")) + continue + + self.logger.info("Done removing old redis results for prefix: %s", prefix) + + return True + + async def consume(self, job_counter, prefix): + """Retrieve celery results from redis db""" + + backend_port, backend_host = get_backend_env() + redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") + + while job_counter.value > 0: + cursor = "0" + keys = [] + while cursor != 0: + if prefix: + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow + cursor, results = await redis.scan(cursor, match=f"*{prefix}*") + else: + # no prefix, match any key + cursor, results = await redis.scan(cursor, match="*") + keys.extend(results) + self.logger.info("Found %s results", len(results)) + for key in keys: + try: + data = await redis.get(key) + if data: + _ = await self.parse_result(data.decode("utf-8")) + await redis.delete(key) + with job_counter_lock: + job_counter.value = job_counter.value - 1 + except aioredis.exceptions.ResponseError as red_err: + self.logger.error(red_err) + self.logger.info(key.decode("utf-8")) + + await asyncio.sleep(1) + self.logger.info("Job counter reached 0") + await redis.close() + + return True + + def prep_tuning(self): + """Prep env for tuning start""" + cmd = None + 
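        # The two commands built below differ by operation: COMPILE sessions start
        # one celery worker per host with default concurrency, while EVAL sessions
        # start a single-process worker (-c 1) whose name encodes a GPU id,
        # suggesting one worker per GPU.  HOSTNAME and GPUID appear to be
        # placeholder tokens substituted later by launch_celery_worker, and
        # get_q_name(op_compile=True / op_eval=True) yields session-specific queue
        # names so compile and eval tasks are routed to separate celery queues.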
subp_list = [] + q_name = None + if self.operation == Operation.COMPILE: + q_name = get_q_name(self, op_compile=True) + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" # pylint: disable=line-too-long + else: + q_name = get_q_name(self, op_eval=True) + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" # pylint: disable=line-too-long + + self.logger.info("celery Q name: %s", q_name) + if not self.args.enqueue_only: + try: + self.logger.info("Launching celery workers for queue %s", q_name) + subp_list = launch_celery_worker(self.operation, cmd, self.args, True) + self.logger.info("Done launching celery workers") + if not subp_list: + raise CustomError("Could not launch celery worker") + except kombu.exceptions.OperationalError as k_err: + self.logger.error("Redis error ocurred: %s", k_err) + return False + else: + purge_queue([q_name]) + + return q_name, subp_list + + # pylint: disable=too-many-locals + def tune(self, job_batch_size=1000): + """tuning loop to spin out celery tasks""" + + if self.args.shutdown_workers: + self.logger.info("Shutting down all celery workers") + stop_active_workers() + return True + + try: + q_name, subp_list = self.prep_tuning() + except CustomError as verr: + self.logger.error(verr) + return False + try: - data = await redis.get(key) - if data: - _ = await self.parse_result(data.decode('utf-8')) - await redis.delete(key) + # if enqueue_only is False, we launch the celery workers + if not self.args.enqueue_only: + for subp in subp_list: + subp.wait() + return True + except KeyboardInterrupt: + for subp in subp_list: + subp.kill() + return False + + start = time.time() + + # set job count to 1 until first job fetch is finished + job_counter = Value("i", 1) + try: + enqueue_proc = Process( + target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name] + ) + # Start enqueue proc + enqueue_proc.start() + + # cleanup old results + cleanup_proc = Process( + target=self.async_wrap, args=(self.cleanup_redis_results, self.prefix) + ) + cleanup_proc.start() + cleanup_proc.join() + + # start async consume thread, blocking + consume_proc = Process( + target=self.async_wrap, args=(self.consume, job_counter, self.prefix) + ) + self.logger.info("Starting consume thread") + consume_proc.start() + + enqueue_proc.join() + # enqueue finished first fetch, remove hold on job_counter with job_counter_lock: - job_counter.value = job_counter.value - 1 - except aioredis.exceptions.ResponseError as red_err: - self.logger.error(red_err) - self.logger.info(key.decode('utf-8')) - - await asyncio.sleep(1) - self.logger.info('Job counter reached 0') - await redis.close() - - return True - - def prep_tuning(self): - """Prep env for tuning start""" - cmd = None - subp_list = [] - q_name = None - if self.operation == Operation.COMPILE: - q_name = get_q_name(self, op_compile=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" #pylint: disable=line-too-long - else: - q_name = get_q_name(self, op_eval=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" #pylint: disable=line-too-long - - self.logger.info('celery Q name: %s', q_name) - if not self.args.enqueue_only: - try: - self.logger.info('Launching celery workers for queue %s', q_name) - subp_list = 
launch_celery_worker(self.operation, cmd, self.args, True) - self.logger.info('Done launching celery workers') - if not subp_list: - raise CustomError('Could not launch celery worker') - except kombu.exceptions.OperationalError as k_err: - self.logger.error('Redis error ocurred: %s', k_err) - return False - else: - purge_queue([q_name]) - - return q_name, subp_list - - #pylint: disable=too-many-locals - def tune(self, job_batch_size=1000): - """tuning loop to spin out celery tasks""" - - if self.args.shutdown_workers: - self.logger.info('Shutting down all celery workers') - stop_active_workers() - return True - - try: - q_name, subp_list = self.prep_tuning() - except CustomError as verr: - self.logger.error(verr) - return False - - try: - #if enqueue_only is False, we launch the celery workers - if not self.args.enqueue_only: - for subp in subp_list: - subp.wait() - return True - except KeyboardInterrupt: - for subp in subp_list: - subp.kill() - return False - - start = time.time() - - #set job count to 1 until first job fetch is finished - job_counter = Value('i', 1) - try: - enqueue_proc = Process(target=self.enqueue_jobs, - args=[job_counter, job_batch_size, q_name]) - #Start enqueue proc - enqueue_proc.start() - - #cleanup old results - cleanup_proc = Process(target=self.async_wrap, - args=(self.cleanup_redis_results, self.prefix)) - cleanup_proc.start() - cleanup_proc.join() - - #start async consume thread, blocking - consume_proc = Process(target=self.async_wrap, - args=(self.consume, job_counter, self.prefix)) - self.logger.info('Starting consume thread') - consume_proc.start() - - enqueue_proc.join() - #enqueue finished first fetch, remove hold on job_counter - with job_counter_lock: - job_counter.value = job_counter.value - 1 - - #check for new jobs - while consume_proc.is_alive(): - enqueue_proc = Process(target=self.enqueue_jobs, - args=[job_counter, job_batch_size, q_name]) - enqueue_proc.start() - enqueue_proc.join() - time.sleep(10) - - consume_proc.join() - - except (KeyboardInterrupt, Exception) as exp: #pylint: disable=broad-exception-caught - self.logger.error('Error ocurred %s', exp) - purge_queue([q_name]) - self.cancel_consumer(q_name) - self.reset_job_state_on_ctrl_c() - with job_counter_lock: - job_counter.value = 0 - - self.cancel_consumer(q_name) - end = time.time() - self.logger.info("Took {:0>8} to tune".format( #pylint: disable=consider-using-f-string - str(timedelta(seconds=end - start)))) - - return True - - async def async_callback(self, async_func, *args): - """Wrapper function to await on async function""" - await async_func(*args) - - def async_wrap(self, async_func, *args): - """Run async function""" - try: - asyncio.run(self.async_callback(async_func, *args)) - except KeyboardInterrupt: - self.logger.warning('Keyboard interrupt caught, terminating') - - def reset_job_state_on_ctrl_c(self): - """Reset job state for jobs in flight""" - temp_obj = SimpleDict() - temp_obj.session_id = self.args.session_id #pylint: disable=invalid-name - attribs = ['state'] - temp_obj.state = 1 - - self.logger.info('Resetting job state in DB for in flight jobs') - - if self.operation == Operation.COMPILE: - state = 16 - elif self.operation == Operation.EVAL: - state = 12 - - query = gen_update_query(temp_obj, attribs, - self.dbt.job_table.__tablename__, - [('session', self.args.session_id), - ('state', state)]) - with DbSession() as session: - - #pylint: disable=duplicate-code - def callback() -> bool: - session.execute(query) - session.commit() + job_counter.value = 
job_counter.value - 1 + + # check for new jobs + while consume_proc.is_alive(): + enqueue_proc = Process( + target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name] + ) + enqueue_proc.start() + enqueue_proc.join() + time.sleep(10) + + consume_proc.join() + + except ( + KeyboardInterrupt, + Exception, + ) as exp: # pylint: disable=broad-exception-caught + self.logger.error("Error ocurred %s", exp) + purge_queue([q_name]) + self.cancel_consumer(q_name) + self.reset_job_state_on_ctrl_c() + with job_counter_lock: + job_counter.value = 0 + + self.cancel_consumer(q_name) + end = time.time() + self.logger.info( + "Took {:0>8} to tune".format( # pylint: disable=consider-using-f-string + str(timedelta(seconds=end - start)) + ) + ) + return True - #pylint: enable=duplicate-code - - assert session_retry(session, callback, lambda x: x(), self.logger) - self.logger.info('Sucessfully reset job state') - return True - - return False - - def has_tunable_operation(self): - """Check if current operation is a tuning operation""" - raise NotImplementedError("Not implemented") - - def get_job_attr(self): - """Get job attr for row selection""" - job_attr: List[str] = None - try: - job_attr = [column.name for column in inspect(self.dbt.job_table).c] - job_attr.remove("insert_ts") - job_attr.remove("update_ts") - except NoInspectionAvailable as error: - self.logger.warning("Ignoring error for init_session: %s", error) - return job_attr - - def check_jobs_found(self, job_rows: List[SimpleDict], find_state: List[Any], - session_id: int) -> bool: - """check for end of jobs""" - if not job_rows: - # we are done - self.logger.warning('No %s jobs found, session %s', find_state, - session_id) - return False - return True - - @lru_cache(1) - def get_context_items(self): - """Helper function to get items for celery job context""" - kwargs = None - f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True) - kwargs = self.get_kwargs(0, f_vals, tuning=True) - return kwargs - - def serialize_jobs(self, session, batch_jobs): - """Return list of serialize jobs""" - raise NotImplementedError("Not implemented") - - def build_context(self, serialized_jobs): - """Build context list for enqueue job""" - raise NotImplementedError("Not implemented") - - def get_context_list(self, session, batch_jobs): - """Return list of jobs (context) for celery queue""" - - context_list: List[dict] = None - serialized_jobs = self.serialize_jobs(session, batch_jobs) - #build context for each celery task - context_list = self.build_context(serialized_jobs) - - return context_list - - async def parse_result(self, data): - """Function callback for celery async jobs to store results""" - data = json.loads(data) - - with DbSession() as session: - try: - fin_json = data['result']['ret'] - context = data['result']['context'] - except KeyError as kerr: - self.logger.error(kerr) - return False + async def async_callback(self, async_func, *args): + """Wrapper function to await on async function""" + await async_func(*args) - self.logger.info('Parsing: %s', fin_json) - if self.operation == Operation.COMPILE: - self.process_compile_results(session, fin_json, context) - elif self.operation == Operation.EVAL: - self.process_eval_results(session, fin_json, context) - else: - raise CustomError('Unsupported tuning operation') + def async_wrap(self, async_func, *args): + """Run async function""" + try: + asyncio.run(self.async_callback(async_func, *args)) + except KeyboardInterrupt: + self.logger.warning("Keyboard interrupt caught, 
terminating") + + def reset_job_state_on_ctrl_c(self): + """Reset job state for jobs in flight""" + temp_obj = SimpleDict() + temp_obj.session_id = self.args.session_id # pylint: disable=invalid-name + attribs = ["state"] + temp_obj.state = 1 + + self.logger.info("Resetting job state in DB for in flight jobs") + + if self.operation == Operation.COMPILE: + state = 16 + elif self.operation == Operation.EVAL: + state = 12 + + query = gen_update_query( + temp_obj, + attribs, + self.dbt.job_table.__tablename__, + [("session", self.args.session_id), ("state", state)], + ) + with DbSession() as session: + + # pylint: disable=duplicate-code + def callback() -> bool: + session.execute(query) + session.commit() + return True + + # pylint: enable=duplicate-code + + assert session_retry(session, callback, lambda x: x(), self.logger) + self.logger.info("Sucessfully reset job state") + return True + + return False - return True + def has_tunable_operation(self): + """Check if current operation is a tuning operation""" + raise NotImplementedError("Not implemented") - def process_compile_results(self, session, fin_json, context): - """Process result from fin_build worker""" - raise NotImplementedError("Not implemented") + def get_job_attr(self): + """Get job attr for row selection""" + job_attr: List[str] = None + try: + job_attr = [column.name for column in inspect(self.dbt.job_table).c] + job_attr.remove("insert_ts") + job_attr.remove("update_ts") + except NoInspectionAvailable as error: + self.logger.warning("Ignoring error for init_session: %s", error) + return job_attr + + def check_jobs_found( + self, job_rows: List[SimpleDict], find_state: List[Any], session_id: int + ) -> bool: + """check for end of jobs""" + if not job_rows: + # we are done + self.logger.warning("No %s jobs found, session %s", find_state, session_id) + return False + return True - def process_eval_results(self, session, fin_json, context): - """Process fin_json result""" - raise NotImplementedError("Not implemented") + @lru_cache(1) + def get_context_items(self): + """Helper function to get items for celery job context""" + kwargs = None + f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True) + kwargs = self.get_kwargs(0, f_vals, tuning=True) + return kwargs + + def serialize_jobs(self, session, batch_jobs): + """Return list of serialize jobs""" + raise NotImplementedError("Not implemented") + + def build_context(self, serialized_jobs): + """Build context list for enqueue job""" + raise NotImplementedError("Not implemented") + + def get_context_list(self, session, batch_jobs): + """Return list of jobs (context) for celery queue""" + + context_list: List[dict] = None + serialized_jobs = self.serialize_jobs(session, batch_jobs) + # build context for each celery task + context_list = self.build_context(serialized_jobs) + + return context_list + + async def parse_result(self, data): + """Function callback for celery async jobs to store results""" + data = json.loads(data) + + with DbSession() as session: + try: + fin_json = data["result"]["ret"] + context = data["result"]["context"] + except KeyError as kerr: + self.logger.error(kerr) + return False + + self.logger.info("Parsing: %s", fin_json) + if self.operation == Operation.COMPILE: + self.process_compile_results(session, fin_json, context) + elif self.operation == Operation.EVAL: + self.process_eval_results(session, fin_json, context) + else: + raise CustomError("Unsupported tuning operation") + + return True + + def process_compile_results(self, session, 
fin_json, context): + """Process result from fin_build worker""" + raise NotImplementedError("Not implemented") + + def process_eval_results(self, session, fin_json, context): + """Process fin_json result""" + raise NotImplementedError("Not implemented") From 06fbe3cc6f7cbbef064ecf8d7b2fe7feaf7ed88b Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Thu, 16 Oct 2025 11:09:49 +0000 Subject: [PATCH 03/10] auto format --- tuna/miopen/miopen_lib.py | 1631 +++++++++++++++++++------------------ 1 file changed, 860 insertions(+), 771 deletions(-) diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index a20e1a23..3450be34 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -60,7 +60,8 @@ from tuna.miopen.db.triggers import drop_miopen_triggers, get_miopen_triggers from tuna.miopen.utils.config_type import ConfigType from tuna.miopen.db.tables import MIOpenDBTables -#from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue + +# from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue from tuna.miopen.utils.json_to_sql import process_fdb_w_kernels, process_tuning_data from tuna.miopen.utils.json_to_sql import process_pdb_compile from tuna.miopen.utils.json_to_sql import clean_cache_table @@ -75,782 +76,870 @@ class MIOpen(MITunaInterface): - """Class to support MIOpen specific tuning functionality""" - - # pylint: disable=too-many-public-methods - - def __init__(self): - super().__init__(library=Library.MIOPEN) - self.args = None - self.set_state = None - - def parse_args(self): - # pylint: disable=too-many-statements - """Function to parse arguments""" - parser = setup_arg_parser( - 'Run Performance Tuning on a certain architecture', [ - TunaArgs.ARCH, TunaArgs.NUM_CU, TunaArgs.VERSION, - TunaArgs.CONFIG_TYPE, TunaArgs.SESSION_ID, TunaArgs.MACHINES, - TunaArgs.REMOTE_MACHINE, TunaArgs.LABEL, TunaArgs.RESTART_MACHINE, - TunaArgs.DOCKER_NAME, TunaArgs.SHUTDOWN_WORKERS, - TunaArgs.ENQUEUE_ONLY - ]) - parser.add_argument( - '--find_mode', - dest='find_mode', - type=int, - default=1, - help='Set the MIOPEN_FIND_MODE environment variable for MIOpen', - choices=['1', '3']) - parser.add_argument('--ticket', - dest='ticket', - type=str, - default=None, - help='Specify tuning ticket number') - parser.add_argument( - '--solver_id', - type=int, - dest='solver_id', - default=None, - help='Specify solver_id. 
Use --list_solvers to see options') - parser.add_argument('--dynamic_solvers_only', - dest='dynamic_solvers_only', - action='store_true', - default=False, - help='Only tune dynamic solvers.') - parser.add_argument( - '-B', - '--blacklist', - dest='blacklist', - type=str, - default=None, - help='MIOpen blacklist algorithm, if multiple then comma separate') - parser.add_argument('-i', - '--reset_interval', - type=int, - dest='reset_interval', - required=False, - help='Restart interval for job in hours.') - parser.add_argument( - '--gpu_lim', - dest='gpu_lim', - type=int, - default=None, - help='Limit the number of gpu workers created by Tuna, index from 0') - - parser.add_argument( - '-R', - '--rich_data', - dest='rich_data', - action='store_true', - default=False, - help='record intermediate parameter results from perf tuning') - - subcommands = parser.add_subcommands(required=False) - subcommands.add_subcommand('import_configs', - get_import_cfg_parser(), - required=False) - - subcommands.add_subcommand('load_job', - get_load_job_parser(), - required=False) - - subcommands.add_subcommand('export_db', - get_export_db_parser(), - required=False) - - subcommands.add_subcommand('update_golden', - get_update_golden_parser(), - required=False) - - group = parser.add_mutually_exclusive_group() - group.add_argument('--add_tables', - dest='add_tables', - action='store_true', - help='Add MIOpen library specific tables') - - group.add_argument('--init_session', - action='store_true', - dest='init_session', - help='Set up a new tuning session.') - group.add_argument( - '--fin_steps', - type=str, - dest='fin_steps', - help='Specify fin steps. Multiple steps should be comma separated.') - group.add_argument('--list_solvers', - action='store_true', - dest='list_solvers', - help='List of solvers from the solver table') - - # JD: implement the following two using fin_steps - group.add_argument('--update_solvers', - dest='update_solvers', - action='store_true', - help='Update the list of solvers in the database') - group.add_argument('--update_applicability', - dest='update_applicability', - action='store_true', - help='Update the applicability table in the database') - group.add_argument('-s', - '--status', - dest='check_status', - action='store_true', - default=False, - help='Check the status of machines') - - group.add_argument('-e', - '--exec', - dest='execute_cmd', - type=str, - default=None, - help='execute on each machine') - - self.args = parser.parse_args() - - if self.args.config_type is None: - self.args.config_type = ConfigType.convolution - - #overwritte common lib args with subcommand args value - if self.args.subcommand is not None: - self.overwrite_common_args() - - if len(sys.argv) == 1: - parser.print_help() - sys.exit(-1) - - if self.args.list_solvers: - print_solvers() - raise CustomError('Printing solvers...') - - if self.args.fin_steps and self.args.subcommand != 'load_job': - self.check_fin_args(parser) - self.set_prefix() - - if self.args.find_mode is None and not (self.args.check_status or - self.args.restart_machine or - self.args.execute_cmd): - parser.error('find_mode must be specified for a tuning run') - - if self.args.blacklist: - self.check_blacklist(parser) - - args_check(self.args, parser) - - fin_session_steps = [ - 'miopen_find_compile', 'miopen_find_eval', 'miopen_perf_compile', - 'miopen_perf_eval', 'get_applicability', 'find_compile', 'find_eval' - ] - has_fin = False - if self.args.fin_steps: - has_fin = all(x in fin_session_steps for x in self.args.fin_steps) - - if 
(self.args.update_applicability or has_fin) and not self.args.session_id: - parser.error("session_id must be specified with this operation") - - self.dbt = MIOpenDBTables(session_id=self.args.session_id, - config_type=self.args.config_type) - self.update_operation() - - def set_prefix(self): - """Set redis key prefix""" - if isinstance(self.args.fin_steps, Iterable): - steps_str = ('-').join(x for x in self.args.fin_steps) - self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_"\ - f"{steps_str}" - else: - steps_str = self.args.fin_steps[0] - self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" - - self.logger.info('redis prefix: %s', self.prefix) - - def overwrite_common_args(self): - """Overwrite common MIOpen_lib args with subcommand args""" - if self.args.subcommand is not None: - subc_dict = vars(self.args.get(self.args.subcommand)) - for sub_key in subc_dict: - if sub_key in vars(self.args): - self.args[sub_key] = subc_dict.get(sub_key) - - def check_fin_args(self, parser): - """! Helper function for fin args - @param parser The command line argument parser + """Class to support MIOpen specific tuning functionality""" + + # pylint: disable=too-many-public-methods + + def __init__(self): + super().__init__(library=Library.MIOPEN) + self.args = None + self.set_state = None + + def parse_args(self): + # pylint: disable=too-many-statements + """Function to parse arguments""" + parser = setup_arg_parser( + "Run Performance Tuning on a certain architecture", + [ + TunaArgs.ARCH, + TunaArgs.NUM_CU, + TunaArgs.VERSION, + TunaArgs.CONFIG_TYPE, + TunaArgs.SESSION_ID, + TunaArgs.MACHINES, + TunaArgs.REMOTE_MACHINE, + TunaArgs.LABEL, + TunaArgs.RESTART_MACHINE, + TunaArgs.DOCKER_NAME, + TunaArgs.SHUTDOWN_WORKERS, + TunaArgs.ENQUEUE_ONLY, + ], + ) + parser.add_argument( + "--find_mode", + dest="find_mode", + type=int, + default=1, + help="Set the MIOPEN_FIND_MODE environment variable for MIOpen", + choices=["1", "3"], + ) + parser.add_argument( + "--ticket", + dest="ticket", + type=str, + default=None, + help="Specify tuning ticket number", + ) + parser.add_argument( + "--solver_id", + type=int, + dest="solver_id", + default=None, + help="Specify solver_id. 
Use --list_solvers to see options", + ) + parser.add_argument( + "--dynamic_solvers_only", + dest="dynamic_solvers_only", + action="store_true", + default=False, + help="Only tune dynamic solvers.", + ) + parser.add_argument( + "-B", + "--blacklist", + dest="blacklist", + type=str, + default=None, + help="MIOpen blacklist algorithm, if multiple then comma separate", + ) + parser.add_argument( + "-i", + "--reset_interval", + type=int, + dest="reset_interval", + required=False, + help="Restart interval for job in hours.", + ) + parser.add_argument( + "--gpu_lim", + dest="gpu_lim", + type=int, + default=None, + help="Limit the number of gpu workers created by Tuna, index from 0", + ) + + parser.add_argument( + "-R", + "--rich_data", + dest="rich_data", + action="store_true", + default=False, + help="record intermediate parameter results from perf tuning", + ) + + subcommands = parser.add_subcommands(required=False) + subcommands.add_subcommand( + "import_configs", get_import_cfg_parser(), required=False + ) + + subcommands.add_subcommand("load_job", get_load_job_parser(), required=False) + + subcommands.add_subcommand("export_db", get_export_db_parser(), required=False) + + subcommands.add_subcommand( + "update_golden", get_update_golden_parser(), required=False + ) + + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--add_tables", + dest="add_tables", + action="store_true", + help="Add MIOpen library specific tables", + ) + + group.add_argument( + "--init_session", + action="store_true", + dest="init_session", + help="Set up a new tuning session.", + ) + group.add_argument( + "--fin_steps", + type=str, + dest="fin_steps", + help="Specify fin steps. Multiple steps should be comma separated.", + ) + group.add_argument( + "--list_solvers", + action="store_true", + dest="list_solvers", + help="List of solvers from the solver table", + ) + + # JD: implement the following two using fin_steps + group.add_argument( + "--update_solvers", + dest="update_solvers", + action="store_true", + help="Update the list of solvers in the database", + ) + group.add_argument( + "--update_applicability", + dest="update_applicability", + action="store_true", + help="Update the applicability table in the database", + ) + group.add_argument( + "-s", + "--status", + dest="check_status", + action="store_true", + default=False, + help="Check the status of machines", + ) + + group.add_argument( + "-e", + "--exec", + dest="execute_cmd", + type=str, + default=None, + help="execute on each machine", + ) + + self.args = parser.parse_args() + + if self.args.config_type is None: + self.args.config_type = ConfigType.convolution + + # overwritte common lib args with subcommand args value + if self.args.subcommand is not None: + self.overwrite_common_args() + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(-1) + + if self.args.list_solvers: + print_solvers() + raise CustomError("Printing solvers...") + + if self.args.fin_steps and self.args.subcommand != "load_job": + self.check_fin_args(parser) + self.set_prefix() + + if self.args.find_mode is None and not ( + self.args.check_status or self.args.restart_machine or self.args.execute_cmd + ): + parser.error("find_mode must be specified for a tuning run") + + if self.args.blacklist: + self.check_blacklist(parser) + + args_check(self.args, parser) + + fin_session_steps = [ + "miopen_find_compile", + "miopen_find_eval", + "miopen_perf_compile", + "miopen_perf_eval", + "get_applicability", + "find_compile", + "find_eval", + ] + has_fin = False + if 
self.args.fin_steps: + has_fin = all(x in fin_session_steps for x in self.args.fin_steps) + + if (self.args.update_applicability or has_fin) and not self.args.session_id: + parser.error("session_id must be specified with this operation") + + self.dbt = MIOpenDBTables( + session_id=self.args.session_id, config_type=self.args.config_type + ) + self.update_operation() + + def set_prefix(self): + """Set redis key prefix""" + if isinstance(self.args.fin_steps, Iterable): + steps_str = ("-").join(x for x in self.args.fin_steps) + self.prefix = ( + f"d_{self.db_name}_sess_{self.args.session_id}_" f"{steps_str}" + ) + else: + steps_str = self.args.fin_steps[0] + self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" + + self.logger.info("redis prefix: %s", self.prefix) + + def overwrite_common_args(self): + """Overwrite common MIOpen_lib args with subcommand args""" + if self.args.subcommand is not None: + subc_dict = vars(self.args.get(self.args.subcommand)) + for sub_key in subc_dict: + if sub_key in vars(self.args): + self.args[sub_key] = subc_dict.get(sub_key) + + def check_fin_args(self, parser): + """! Helper function for fin args + @param parser The command line argument parser + """ + valid_fin_steps = list(k for k in FinStep.__members__) + if "," in self.args.fin_steps: + parser.error("Multiple fin_steps currently not supported") + f_steps = self.args.fin_steps.split(",") + self.args.fin_steps = f_steps + for step in self.args.fin_steps: + if step not in valid_fin_steps: + parser.error(f"Supported fin steps are: {valid_fin_steps}") + assert len(self.args.fin_steps) == 1 + + def check_blacklist(self, parser): + """! Helper function + @param parser The command line argument parser + @return ret Boolean value + """ + self.args.blacklist = self.args.blacklist.split(",") + for sol in self.args.blacklist: + if sol not in MIOPEN_ALG_LIST: + parser.error("Incorrect blacklist value") + + def do_fin_work(self, gpu, f_vals): + """! Helper function to execute job independendent fin work + @param gpu Unique ID of the GPU + @param f_vals Dict containing runtime information + """ + kwargs = self.get_kwargs(gpu, f_vals) + fin_worker = FinClass(**kwargs) + + if self.args.update_solvers: + if not fin_worker.get_solvers(): + self.logger.error("No solvers returned from Fin class") + + return True + + def launch_worker(self, gpu_idx, f_vals, worker_lst): + """! Function to launch worker + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + @param worker_lst List containing worker instances + @return ret Boolean value """ - valid_fin_steps = list(k for k in FinStep.__members__) - if ',' in self.args.fin_steps: - parser.error('Multiple fin_steps currently not supported') - f_steps = self.args.fin_steps.split(',') - self.args.fin_steps = f_steps - for step in self.args.fin_steps: - if step not in valid_fin_steps: - parser.error(f"Supported fin steps are: {valid_fin_steps}") - assert len(self.args.fin_steps) == 1 - - def check_blacklist(self, parser): - """! Helper function - @param parser The command line argument parser - @return ret Boolean value - """ - self.args.blacklist = self.args.blacklist.split(',') - for sol in self.args.blacklist: - if sol not in MIOPEN_ALG_LIST: - parser.error("Incorrect blacklist value") - - def do_fin_work(self, gpu, f_vals): - """! 
Helper function to execute job independendent fin work - @param gpu Unique ID of the GPU - @param f_vals Dict containing runtime information - """ - kwargs = self.get_kwargs(gpu, f_vals) - fin_worker = FinClass(**kwargs) - - if self.args.update_solvers: - if not fin_worker.get_solvers(): - self.logger.error('No solvers returned from Fin class') - - return True - - def launch_worker(self, gpu_idx, f_vals, worker_lst): - """! Function to launch worker - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - @param worker_lst List containing worker instances - @return ret Boolean value - """ - # pylint: disable=too-many-branches - worker = None - kwargs = self.get_kwargs(gpu_idx, f_vals) - if self.args.update_applicability: - kwargs['fin_steps'] = ['applicability'] - worker = FinClass(**kwargs) - worker.start() - worker_lst.append(worker) - return True - - worker = FinClass(**kwargs) - ret = False - if self.args.check_status: - if not super().check_status(worker, f_vals["b_first"], gpu_idx, - f_vals["machine"], self.args.docker_name): - ret = True - elif self.args.init_session: - Session().add_new_session(self.args, worker) - elif self.args.execute_cmd: - # JD: Move the worker.exec_command to machine - self.logger.info(self.args.execute_cmd) - _, _, _ = worker.exec_command(self.args.execute_cmd + " 2>&1 ") - - return ret - - def compose_worker_list(self, machines): - # pylint: disable=too-many-branches - """! Helper function to compose worker_list - @param machines List of machines to execute on - """ - worker_lst = [] - fin_work_done = False - for machine in machines: - if self.args.restart_machine: - machine.restart_server(wait=False) - continue - - #fin_steps should only contain one step - worker_ids = None - if self.args.fin_steps and 'eval' in self.args.fin_steps[0]: - worker_ids = machine.get_avail_gpus() - if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): - worker_ids = range(self.args.gpu_lim) - else: - worker_ids = super().get_num_procs(machine) - - if self.args.update_applicability: - f_vals = super().get_f_vals(machine, [1]) - kwargs = self.get_kwargs(0, f_vals) - kwargs['fin_steps'] = ['applicability'] + # pylint: disable=too-many-branches + worker = None + kwargs = self.get_kwargs(gpu_idx, f_vals) + if self.args.update_applicability: + kwargs["fin_steps"] = ["applicability"] + worker = FinClass(**kwargs) + worker.start() + worker_lst.append(worker) + return True + worker = FinClass(**kwargs) - query = worker.query_cfgs(self.args.label) - cfg_rows = query.all() - len_rows = len(cfg_rows) - proc_lim = (len_rows + 99) / 100 - if 32 < proc_lim: - proc_lim = 32 - while len(worker_ids) > proc_lim: - worker_ids.pop() - - if len(worker_ids) == 0: - return None - - f_vals = super().get_f_vals(machine, worker_ids) - - if (self.args.update_solvers) and not fin_work_done: - self.do_fin_work(0, f_vals) - fin_work_done = True - break - - for gpu_idx in worker_ids: - self.logger.info('launch mid %u, proc %u', machine.id, gpu_idx) - if not self.launch_worker(gpu_idx, f_vals, worker_lst): - break - - return worker_lst - - def add_tables(self): - """! Function to create new DB tables - @return Bool - """ - ret_t = create_tables(get_miopen_tables()) - self.logger.info('DB creation successful: %s', ret_t) - recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) - return True - - def run(self): - # pylint: disable=duplicate-code - """! 
Main function to launch library""" - res = None - if self.args is None: - self.parse_args() - - if self.args.add_tables: - self.add_tables() - return None - - if self.args.subcommand is not None and self.args.subcommand == 'import_configs': - run_import_configs(self.args.import_configs, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == 'load_job': - run_load_job(self.args.load_job, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == 'export_db': - run_export_db(self.args.export_db, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == 'update_golden': - run_update_golden(self.args.update_golden, self.logger) - return None - - machines = load_machines(self.args) - res = self.compose_worker_list(machines) - return res - - def get_envmt(self): - """! Function to construct environment var - """ - envmt = ["MIOPEN_LOG_LEVEL=4"] - - envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") - envmt.append("MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS=1") - - if self.args.find_mode: - envmt.append(f"MIOPEN_FIND_MODE={self.args.find_mode}") - - if self.args.blacklist: - bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) - for bk_var in bk_str.split(','): - envmt.append(bk_var) - - return envmt - - def get_kwargs(self, gpu_idx, f_vals, tuning=False): - """! Helper function to set up kwargs for worker instances - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - @param tuning Boolean that indicates if kwargs are for a tuning step - @return kwargs Dictionary - """ - kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) - kwargs['fin_steps'] = self.args.fin_steps - kwargs['dynamic_solvers_only'] = self.args.dynamic_solvers_only - kwargs['config_type'] = self.args.config_type - kwargs['reset_interval'] = self.args.reset_interval - - return kwargs - - def get_job_list(self, session, find_state, claim_num): - """! Get list of jobs - @param session DB session - @param find_state DB job state - @param claim_num Number of DB jobs to pick up - @return List of DB jobs - - """ - job_list = self.get_job_objs(session, find_state, self.args.label, self.dbt, - self.get_job_attr(), claim_num, - self.args.fin_steps) - - return job_list - - def get_job_objs(self, - session: DbSession, - find_state: list, - label: str, - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None) -> List[SimpleDict]: - """! Get list of job objects - @param session DB session - @param find_state DB job state - @param label DB job reason - @param dbt Class representing all DB tables associated with this class - @param job_attr List of DB job columns - @param claim_num Number of DB jobs to pick up - @param fin_steps List of MIFin steps - @return List of DB jobs - """ - entries: List[Tuple[SimpleDict, ...]] - conds: List[str] = [f"session={dbt.session.id}", "valid=1"] - - if label: - conds.append(f"reason='{label}'") - - conds.append(f"retries<{self.max_job_retries}") - conds.append("state in (" + str(find_state).strip('{').strip('}') + ")") - - entries = self.compose_work_objs(session, conds, dbt, job_attr, claim_num, - fin_steps) - return entries - - def compose_work_objs(self, - session: DbSession, - conds: List[str], - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None) -> List[SimpleDict]: - """! 
Query a job list for update - @param session DB session - @param conds List of conditions for DB job WHERE clause - @param dbt Class representing all DB tables associated with this class - @param job_attr List of DB job columns - @param fin_steps List of MIFin steps - @return List of MIFin work objects - """ - job_entries = [] - if fin_steps: - conds.append(f"fin_step like '%{fin_steps[0]}%'") - else: - conds.append("fin_step='not_fin'") - - cond_str = ' AND '.join(conds) - if cond_str: - cond_str = f"WHERE {cond_str}" - if claim_num: - cond_str += f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" - else: - cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" - - job_entries = gen_select_objs(session, job_attr, - dbt.job_table.__tablename__, cond_str) - - return job_entries - - def compose_work_objs_fin(self, session, job_entries, - dbt) -> List[Tuple[SimpleDict, SimpleDict]]: - """! Return jobs for fin work - @param session DB session - @param job_entries List of DB jobs - @param dbt Class representing all DB tables associated with this class - @return ret Job tuple - """ - ret = [] - - cfg_rel = { - key: { - 'key': list(val.local_columns)[0].name, - 'ftble': str(list(val.remote_side)[0]).split('.', maxsplit=1)[0], - 'fkey': str(list(val.remote_side)[0]).split('.')[1] - } for key, val in inspect(dbt.config_table).relationships.items() - } - - if job_entries: - id_str = ','.join({str(job.config) for job in job_entries}) - cfg_cond_str = f"where valid=1 and id in ({id_str})" - cfg_attr = [column.name for column in inspect(dbt.config_table).c] - cfg_entries = gen_select_objs(session, cfg_attr, - dbt.config_table.__tablename__, - cfg_cond_str) - - cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries) - - cfg_map = {cfg.id: cfg for cfg in cfg_entries} - - for job in job_entries: - ret.append((job, cfg_map[job.config])) - - return ret - - def attach_tensors(self, session, cfg_rel, cfg_entries): - """! Attach tensor relationship information to config entries - @param session DB session - @param cfg_rel DB Config col value - @param cfg_entries List of DB Config entries - @return cfg_entries List of DB Config entries with attached tensors (foreign keys) - - """ - for key, val in cfg_rel.items(): - rel_attr = [ - column.name - for column in inspect(get_class_by_tablename(val['ftble'])).c - ] - val['fattr'] = rel_attr - - for cfg in cfg_entries: - for key, val in cfg_rel.items(): - rel_val = getattr(cfg, val['key']) - rel_cond_str = f"where {val['fkey']}={rel_val}" - setattr( - cfg, key, - gen_select_objs(session, val['fattr'], val['ftble'], - rel_cond_str)[0]) - return cfg_entries - - #deprecated - def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]], - job_attr: List[str]) -> List[SimpleDict]: - """Find job tables in query results""" - if has_attr_set(job_rows[0], job_attr): - job_tables: List[SimpleDict] = job_rows - else: - job_i: int = 0 - tble: SimpleDict - for i, tble in enumerate(job_rows[0]): - if has_attr_set(tble, job_attr): - job_i = i - break - job_tables = [row[job_i] for row in job_rows] - - return job_tables - - def update_operation(self): - """! 
Update the workers type that this library needs""" - if self.args.fin_steps: - if 'miopen_find_compile' in self.args.fin_steps \ - or 'miopen_perf_compile' in self.args.fin_steps: - self.fetch_state.add('new') - self.set_state = 'compile_start' - self.operation = Operation.COMPILE - elif 'miopen_find_eval' in self.args.fin_steps or 'miopen_perf_eval' in self.args.fin_steps: - self.fetch_state.add('new') - self.fetch_state.add('compiled') - self.set_state = 'eval_start' - self.operation = Operation.EVAL - - if self.args.update_applicability: - self.fetch_state.add("new") - - def has_tunable_operation(self): - """! Check if its a tuning loop operation - @return Bool value that represents if operation is tuning - """ - if self.args is None: - self.parse_args() - if self.args.subcommand and "load_job" in self.args.subcommand: - return False - if self.args.shutdown_workers: - return True - - return self.args.fin_steps and any( - s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS) - - @lru_cache(1) - def get_fdb_attr(self): - """! Get find_db table attrs - @return fdb_attr find_db table attributes without timestamps - """ - fdb_attr = None - fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] - fdb_attr.remove("insert_ts") - fdb_attr.remove("update_ts") - return fdb_attr - - @lru_cache(1) - def get_tuning_data_attr(self): - """! Get tuning_data table attrs - @return tuning_data_attr tuning_data table attributes without timestamps - """ - tuning_data_attr = None - tuning_data_attr = [ - column.name for column in inspect(self.dbt.tuning_data_table).c - ] - tuning_data_attr.remove("insert_ts") - tuning_data_attr.remove("update_ts") - return tuning_data_attr - - def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): - """! Return list of serialize jobs - @param session DB session - @param batch_jobs List of DB jobs - @return DB jobs, serialized - """ - entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) - return serialize_chunk(entries) - - def build_context( - self, serialized_jobs: Tuple[SimpleDict, SimpleDict]) -> List[dict]: - """Build context list for enqueue job""" - context_list = [] - kwargs = self.get_context_items() - fdb_attr = self.get_fdb_attr() - tuning_data_attr = self.get_tuning_data_attr() - for job, config in serialized_jobs: - context = { - 'job': job, - 'config': config, - 'operation': self.operation, - 'arch': self.dbt.session.arch, - 'num_cu': self.dbt.session.num_cu, - 'kwargs': kwargs, - 'rich_data': self.args.rich_data, - 'fdb_attr': fdb_attr, - 'tuning_data_attr': tuning_data_attr - } - context_list.append(context) - - return context_list - - def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): - """! Enqueue job (context) for queue:q_name - @param context Context for Celery job - @param q_name Custom Celery queue name - @param task_id Custom Redis Key - """ - - #hacky way to get the Q_NAME to the task decorator for interpreter to decorate the - #function with correct q_name arg - #if import is moved to top it will result in circular imports - Q_NAME = q_name #pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name - from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue #pylint: disable=import-outside-toplevel - - return celery_enqueue.apply_async((context,), - task_id=('-').join([self.prefix, - uuid()]), - queue=q_name, - reply_to=q_name) - - def process_compile_results(self, session, fin_json, context): - """! 
Process result from fin_build worker - @param session DB session - @param fin_json MIFin results for job - @param context Context for Celery job - @return Boolean value - """ - job = SimpleDict(**context['job']) - pending = [] - solver_id_map = get_solver_ids() - - failed_job = False - result_str = '' - status = None - try: - if fin_json: - if 'success' in fin_json and fin_json["success"] is False: - status = [fin_json] + ret = False + if self.args.check_status: + if not super().check_status( + worker, + f_vals["b_first"], + gpu_idx, + f_vals["machine"], + self.args.docker_name, + ): + ret = True + elif self.args.init_session: + Session().add_new_session(self.args, worker) + elif self.args.execute_cmd: + # JD: Move the worker.exec_command to machine + self.logger.info(self.args.execute_cmd) + _, _, _ = worker.exec_command(self.args.execute_cmd + " 2>&1 ") + + return ret + + def compose_worker_list(self, machines): + # pylint: disable=too-many-branches + """! Helper function to compose worker_list + @param machines List of machines to execute on + """ + worker_lst = [] + fin_work_done = False + for machine in machines: + if self.args.restart_machine: + machine.restart_server(wait=False) + continue + + # fin_steps should only contain one step + worker_ids = None + if self.args.fin_steps and "eval" in self.args.fin_steps[0]: + worker_ids = machine.get_avail_gpus() + if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): + worker_ids = range(self.args.gpu_lim) + else: + worker_ids = super().get_num_procs(machine) + + if self.args.update_applicability: + f_vals = super().get_f_vals(machine, [1]) + kwargs = self.get_kwargs(0, f_vals) + kwargs["fin_steps"] = ["applicability"] + worker = FinClass(**kwargs) + query = worker.query_cfgs(self.args.label) + cfg_rows = query.all() + len_rows = len(cfg_rows) + proc_lim = (len_rows + 99) / 100 + if 32 < proc_lim: + proc_lim = 32 + while len(worker_ids) > proc_lim: + worker_ids.pop() + + if len(worker_ids) == 0: + return None + + f_vals = super().get_f_vals(machine, worker_ids) + + if (self.args.update_solvers) and not fin_work_done: + self.do_fin_work(0, f_vals) + fin_work_done = True + break + + for gpu_idx in worker_ids: + self.logger.info("launch mid %u, proc %u", machine.id, gpu_idx) + if not self.launch_worker(gpu_idx, f_vals, worker_lst): + break + + return worker_lst + + def add_tables(self): + """! Function to create new DB tables + @return Bool + """ + ret_t = create_tables(get_miopen_tables()) + self.logger.info("DB creation successful: %s", ret_t) + recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) + return True + + def run(self): + # pylint: disable=duplicate-code + """! 
Main function to launch library""" + res = None + if self.args is None: + self.parse_args() + + if self.args.add_tables: + self.add_tables() + return None + + if ( + self.args.subcommand is not None + and self.args.subcommand == "import_configs" + ): + run_import_configs(self.args.import_configs, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "load_job": + run_load_job(self.args.load_job, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "export_db": + run_export_db(self.args.export_db, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "update_golden": + run_update_golden(self.args.update_golden, self.logger) + return None + + machines = load_machines(self.args) + res = self.compose_worker_list(machines) + return res + + def get_envmt(self): + """! Function to construct environment var""" + envmt = ["MIOPEN_LOG_LEVEL=4"] + + envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") + envmt.append("MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS=1") + + if self.args.find_mode: + envmt.append(f"MIOPEN_FIND_MODE={self.args.find_mode}") + + if self.args.blacklist: + bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) + for bk_var in bk_str.split(","): + envmt.append(bk_var) + + return envmt + + def get_kwargs(self, gpu_idx, f_vals, tuning=False): + """! Helper function to set up kwargs for worker instances + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + @param tuning Boolean that indicates if kwargs are for a tuning step + @return kwargs Dictionary + """ + kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) + kwargs["fin_steps"] = self.args.fin_steps + kwargs["dynamic_solvers_only"] = self.args.dynamic_solvers_only + kwargs["config_type"] = self.args.config_type + kwargs["reset_interval"] = self.args.reset_interval + + return kwargs + + def get_job_list(self, session, find_state, claim_num): + """! Get list of jobs + @param session DB session + @param find_state DB job state + @param claim_num Number of DB jobs to pick up + @return List of DB jobs + + """ + job_list = self.get_job_objs( + session, + find_state, + self.args.label, + self.dbt, + self.get_job_attr(), + claim_num, + self.args.fin_steps, + ) + + return job_list + + def get_job_objs( + self, + session: DbSession, + find_state: list, + label: str, + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: + """! Get list of job objects + @param session DB session + @param find_state DB job state + @param label DB job reason + @param dbt Class representing all DB tables associated with this class + @param job_attr List of DB job columns + @param claim_num Number of DB jobs to pick up + @param fin_steps List of MIFin steps + @return List of DB jobs + """ + entries: List[Tuple[SimpleDict, ...]] + conds: List[str] = [f"session={dbt.session.id}", "valid=1"] + + if label: + conds.append(f"reason='{label}'") + + conds.append(f"retries<{self.max_job_retries}") + conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") + + entries = self.compose_work_objs( + session, conds, dbt, job_attr, claim_num, fin_steps + ) + return entries + + def compose_work_objs( + self, + session: DbSession, + conds: List[str], + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: + """! 
Query a job list for update + @param session DB session + @param conds List of conditions for DB job WHERE clause + @param dbt Class representing all DB tables associated with this class + @param job_attr List of DB job columns + @param fin_steps List of MIFin steps + @return List of MIFin work objects + """ + job_entries = [] + if fin_steps: + conds.append(f"fin_step like '%{fin_steps[0]}%'") else: - if 'miopen_find_compile_result' in fin_json: - status = process_fdb_w_kernels(session, fin_json, - copy.deepcopy(context), self.dbt, - context['fdb_attr'], pending) - - elif 'miopen_perf_compile_result' in fin_json: - status = process_pdb_compile(session, fin_json, job, self.dbt, - solver_id_map) - - success, result_str = get_fin_result(status) - failed_job = not success - - except (OperationalError, IntegrityError) as err: - self.logger.warning('FinBuild: Unable to update Database %s', err) - session.rollback() - failed_job = True - except DataError as err: - self.logger.warning( - 'FinBuild: Invalid data, likely large workspace. DB Error: %s', err) - session.rollback() - failed_job = True - - if failed_job: - set_job_state(session, job, self.dbt, 'errored', False, result=result_str) - else: - set_job_state(session, - job, - self.dbt, - 'compiled', - False, - result=result_str) - - return True - - def process_eval_results(self, session, fin_json, context): - """! Process fin_json result - @param session DB session - @param fin_json MIFin results for job - @param context Context for Celery job - @return Boolean value - """ - job = SimpleDict(**context['job']) - failed_job = True - result_str = '' - pending = [] - orig_state = 'compiled' - - try: - if fin_json: - if 'success' in fin_json and fin_json["success"] is False: - status = [fin_json] + conds.append("fin_step='not_fin'") + + cond_str = " AND ".join(conds) + if cond_str: + cond_str = f"WHERE {cond_str}" + if claim_num: + cond_str += ( + f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" + ) else: - if 'miopen_find_eval_result' in fin_json: - status = process_fdb_w_kernels(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['fdb_attr'], - pending, - result_str='miopen_find_eval_result', - check_str='evaluated') - elif 'miopen_perf_eval_result' in fin_json: - status = process_fdb_w_kernels(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['fdb_attr'], - pending, - result_str='miopen_perf_eval_result', - check_str='evaluated') - if context["rich_data"]: - status = process_tuning_data(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['tuning_data_attr'], - pending, - result_str='miopen_perf_eval_result', - check_str='evaluated') - - success, result_str = get_fin_result(status) - failed_job = not success - - if failed_job: - if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): #pylint: disable=no-member - self.logger.warning('max job retries exhausted, setting to errored') - set_job_state(session, job, self.dbt, 'errored', result=result_str) + cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" + + job_entries = gen_select_objs( + session, job_attr, dbt.job_table.__tablename__, cond_str + ) + + return job_entries + + def compose_work_objs_fin( + self, session, job_entries, dbt + ) -> List[Tuple[SimpleDict, SimpleDict]]: + """! 
Return jobs for fin work + @param session DB session + @param job_entries List of DB jobs + @param dbt Class representing all DB tables associated with this class + @return ret Job tuple + """ + ret = [] + + cfg_rel = { + key: { + "key": list(val.local_columns)[0].name, + "ftble": str(list(val.remote_side)[0]).split(".", maxsplit=1)[0], + "fkey": str(list(val.remote_side)[0]).split(".")[1], + } + for key, val in inspect(dbt.config_table).relationships.items() + } + + if job_entries: + id_str = ",".join({str(job.config) for job in job_entries}) + cfg_cond_str = f"where valid=1 and id in ({id_str})" + cfg_attr = [column.name for column in inspect(dbt.config_table).c] + cfg_entries = gen_select_objs( + session, cfg_attr, dbt.config_table.__tablename__, cfg_cond_str + ) + + cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries) + + cfg_map = {cfg.id: cfg for cfg in cfg_entries} + + for job in job_entries: + ret.append((job, cfg_map[job.config])) + + return ret + + def attach_tensors(self, session, cfg_rel, cfg_entries): + """! Attach tensor relationship information to config entries + @param session DB session + @param cfg_rel DB Config col value + @param cfg_entries List of DB Config entries + @return cfg_entries List of DB Config entries with attached tensors (foreign keys) + + """ + for key, val in cfg_rel.items(): + rel_attr = [ + column.name + for column in inspect(get_class_by_tablename(val["ftble"])).c + ] + val["fattr"] = rel_attr + + for cfg in cfg_entries: + for key, val in cfg_rel.items(): + rel_val = getattr(cfg, val["key"]) + rel_cond_str = f"where {val['fkey']}={rel_val}" + setattr( + cfg, + key, + gen_select_objs(session, val["fattr"], val["ftble"], rel_cond_str)[ + 0 + ], + ) + return cfg_entries + + # deprecated + def get_job_tables( + self, job_rows: List[Tuple[SimpleDict, ...]], job_attr: List[str] + ) -> List[SimpleDict]: + """Find job tables in query results""" + if has_attr_set(job_rows[0], job_attr): + job_tables: List[SimpleDict] = job_rows else: - self.logger.warning('resetting job state to %s, incrementing retries', - orig_state) - set_job_state(session, + job_i: int = 0 + tble: SimpleDict + for i, tble in enumerate(job_rows[0]): + if has_attr_set(tble, job_attr): + job_i = i + break + job_tables = [row[job_i] for row in job_rows] + + return job_tables + + def update_operation(self): + """! Update the workers type that this library needs""" + if self.args.fin_steps: + if ( + "miopen_find_compile" in self.args.fin_steps + or "miopen_perf_compile" in self.args.fin_steps + ): + self.fetch_state.add("new") + self.set_state = "compile_start" + self.operation = Operation.COMPILE + elif ( + "miopen_find_eval" in self.args.fin_steps + or "miopen_perf_eval" in self.args.fin_steps + ): + self.fetch_state.add("new") + self.fetch_state.add("compiled") + self.set_state = "eval_start" + self.operation = Operation.EVAL + + if self.args.update_applicability: + self.fetch_state.add("new") + + def has_tunable_operation(self): + """! Check if its a tuning loop operation + @return Bool value that represents if operation is tuning + """ + if self.args is None: + self.parse_args() + if self.args.subcommand and "load_job" in self.args.subcommand: + return False + if self.args.shutdown_workers: + return True + + return self.args.fin_steps and any( + s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS + ) + + @lru_cache(1) + def get_fdb_attr(self): + """! 
Get find_db table attrs + @return fdb_attr find_db table attributes without timestamps + """ + fdb_attr = None + fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] + fdb_attr.remove("insert_ts") + fdb_attr.remove("update_ts") + return fdb_attr + + @lru_cache(1) + def get_tuning_data_attr(self): + """! Get tuning_data table attrs + @return tuning_data_attr tuning_data table attributes without timestamps + """ + tuning_data_attr = None + tuning_data_attr = [ + column.name for column in inspect(self.dbt.tuning_data_table).c + ] + tuning_data_attr.remove("insert_ts") + tuning_data_attr.remove("update_ts") + return tuning_data_attr + + def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): + """! Return list of serialize jobs + @param session DB session + @param batch_jobs List of DB jobs + @return DB jobs, serialized + """ + entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) + return serialize_chunk(entries) + + def build_context( + self, serialized_jobs: Tuple[SimpleDict, SimpleDict] + ) -> List[dict]: + """Build context list for enqueue job""" + context_list = [] + kwargs = self.get_context_items() + fdb_attr = self.get_fdb_attr() + tuning_data_attr = self.get_tuning_data_attr() + for job, config in serialized_jobs: + context = { + "job": job, + "config": config, + "operation": self.operation, + "arch": self.dbt.session.arch, + "num_cu": self.dbt.session.num_cu, + "kwargs": kwargs, + "rich_data": self.args.rich_data, + "fdb_attr": fdb_attr, + "tuning_data_attr": tuning_data_attr, + } + context_list.append(context) + + return context_list + + def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): + """! Enqueue job (context) for queue:q_name + @param context Context for Celery job + @param q_name Custom Celery queue name + @param task_id Custom Redis Key + """ + + # hacky way to get the Q_NAME to the task decorator for interpreter to decorate the + # function with correct q_name arg + # if import is moved to top it will result in circular imports + Q_NAME = q_name # pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name + from tuna.miopen.celery_tuning.celery_tasks import ( + celery_enqueue, + ) # pylint: disable=import-outside-toplevel + + return celery_enqueue.apply_async( + (context,), + task_id=("-").join([self.prefix, uuid()]), + queue=q_name, + reply_to=q_name, + ) + + def process_compile_results(self, session, fin_json, context): + """! 
Process result from fin_build worker + @param session DB session + @param fin_json MIFin results for job + @param context Context for Celery job + @return Boolean value + """ + job = SimpleDict(**context["job"]) + pending = [] + solver_id_map = get_solver_ids() + + failed_job = False + result_str = "" + status = None + try: + if fin_json: + if "success" in fin_json and fin_json["success"] is False: + status = [fin_json] + else: + if "miopen_find_compile_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + ) + + elif "miopen_perf_compile_result" in fin_json: + status = process_pdb_compile( + session, fin_json, job, self.dbt, solver_id_map + ) + + success, result_str = get_fin_result(status) + failed_job = not success + + except (OperationalError, IntegrityError) as err: + self.logger.warning("FinBuild: Unable to update Database %s", err) + session.rollback() + failed_job = True + except DataError as err: + self.logger.warning( + "FinBuild: Invalid data, likely large workspace. DB Error: %s", err + ) + session.rollback() + failed_job = True + + if failed_job: + set_job_state(session, job, self.dbt, "errored", False, result=result_str) + else: + set_job_state(session, job, self.dbt, "compiled", False, result=result_str) + + return True + + def process_eval_results(self, session, fin_json, context): + """! Process fin_json result + @param session DB session + @param fin_json MIFin results for job + @param context Context for Celery job + @return Boolean value + """ + job = SimpleDict(**context["job"]) + failed_job = True + result_str = "" + pending = [] + orig_state = "compiled" + + try: + if fin_json: + if "success" in fin_json and fin_json["success"] is False: + status = [fin_json] + else: + if "miopen_find_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_find_eval_result", + check_str="evaluated", + ) + elif "miopen_perf_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) + if context["rich_data"]: + status = process_tuning_data( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["tuning_data_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) + + success, result_str = get_fin_result(status) + failed_job = not success + + if failed_job: + if job.retries >= ( + MAX_ERRORED_JOB_RETRIES - 1 + ): # pylint: disable=no-member + self.logger.warning("max job retries exhausted, setting to errored") + set_job_state(session, job, self.dbt, "errored", result=result_str) + else: + self.logger.warning( + "resetting job state to %s, incrementing retries", orig_state + ) + set_job_state( + session, job, self.dbt, orig_state, increment_retries=True, - result=result_str) - else: - self.logger.info("\n\n Setting job state to evaluated") - set_job_state(session, job, self.dbt, 'evaluated', result=result_str) - clean_cache_table(self.dbt, job) - except (OperationalError, IntegrityError) as err: - self.logger.warning('FinBuild: Unable to update Database %s', err) - session.rollback() - set_job_state(session, job, self.dbt, 'errored', result=result_str) - - return True + result=result_str, + ) + else: + self.logger.info("\n\n Setting job state to evaluated") + 
set_job_state(session, job, self.dbt, "evaluated", result=result_str)
                clean_cache_table(self.dbt, job)
        except (OperationalError, IntegrityError) as err:
            self.logger.warning("FinBuild: Unable to update Database %s", err)
            session.rollback()
            set_job_state(session, job, self.dbt, "errored", result=result_str)

        return True


From ebc50d874ee6bf511382c6cfdaaee3c5caeda4ee Mon Sep 17 00:00:00 2001
From: amd-bartgips
Date: Fri, 17 Oct 2025 08:34:43 +0000
Subject: [PATCH 04/10] WIP: parallel functionality

---
 tuna/miopen/miopen_lib.py |   8 +++
 tuna/mituna_interface.py  | 102 ++++++++++++++++++++++++++++++------
 2 files changed, 96 insertions(+), 14 deletions(-)

diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py
index 3450be34..f8be7c31 100644
--- a/tuna/miopen/miopen_lib.py
+++ b/tuna/miopen/miopen_lib.py
@@ -943,3 +943,11 @@ def process_eval_results(self, session, fin_json, context):
             set_job_state(session, job, self.dbt, "errored", result=result_str)
 
         return True
+
+    def extract_job_id_from_context(self, context):
+        """Extract job ID from MIOpen celery task context"""
+        try:
+            # Extract job ID from the job context
+            return context.get("job", {}).get("id")
+        except (AttributeError, KeyError):
+            return None
diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py
index 345a915b..51a11474 100644
--- a/tuna/mituna_interface.py
+++ b/tuna/mituna_interface.py
@@ -81,6 +81,12 @@ def __init__(self, library=Library.MIOPEN) -> None:
         self.db_name = os.environ["TUNA_DB_NAME"]
         self.prefix = None
 
+        # Track jobs claimed by this specific instance when in distributor mode
+        self.claimed_job_ids = set()
+        self.completed_job_ids = set()
+        # if less than 25% of the jobs are remaining, we can grab more jobs
+        self.progress_factor = 0.25
+
     def check_docker(self, worker: WorkerInterface, dockername="miopentuna") -> bool:
         """! 
Checking for docker @param worker The worker interface instance @@ -343,12 +349,21 @@ def celery_enqueue_call(self, context, q_name, task_id=False): raise NotImplementedError("Not implemented") def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs""" + """Enqueue celery jobs with machine-specific progress tracking""" self.logger.info("Starting enqueue") + current_batch_size = 0 + with DbSession() as session: while True: - job_list = [] - # get all the jobs from mySQL + # Check if we should enqueue more jobs based on OUR progress + if current_batch_size > 0: + if not self.should_enqueue_more_jobs(session, current_batch_size): + self.logger.info( + "Waiting for our current batch to progress before enqueuing more" + ) + break + + # Get jobs from database job_list = self.get_jobs( session, self.fetch_state, @@ -357,20 +372,63 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): job_batch_size, ) + if not job_list: + self.logger.info("No more jobs available to enqueue") + break + + # Track the jobs we just claimed + new_job_ids = {job.id for job in job_list} + self.claimed_job_ids.update(new_job_ids) + + self.logger.info("Claimed jobs: %s", list(new_job_ids)) + with job_counter_lock: job_counter.value = job_counter.value + len(job_list) - for i in range(0, len(job_list), job_batch_size): - batch_jobs = job_list[i : min(i + job_batch_size, len(job_list))] - context_list = self.get_context_list(session, batch_jobs) - for context in context_list: - # calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) + # Process all jobs in this batch (remove the inner for loop) + context_list = self.get_context_list(session, job_list) + for context in context_list: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + + current_batch_size = len(job_list) + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + current_batch_size, + ) - self.logger.info("Job counter: %s", job_counter.value) - if not job_list: - self.logger.info("All tasks added to queue") - break + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + + def should_enqueue_more_jobs(self, session, current_batch_size): + """Check if we should enqueue more jobs based on THIS instance's progress""" + # Count only jobs claimed by this machine instance + our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) + + # Allow enqueuing when less than 25% of our claimed jobs are still in progress + progress_threshold = current_batch_size * self.progress_factor + + self.logger.info( + "Our jobs in progress: %d, completed: %d, threshold: %d", + our_in_progress_count, + len(self.completed_job_ids), + progress_threshold, + ) + + return our_in_progress_count < progress_threshold + + def cleanup_completed_jobs(self): + """Periodically clean up old job tracking data""" + # Keep sets from growing indefinitely + max_tracking_size = 10000 + if len(self.completed_job_ids) > max_tracking_size: + # Keep only the most recent completions + recent_completions = list(self.completed_job_ids)[-5000:] + self.completed_job_ids = set(recent_completions) + + # Remove old claimed jobs that are completed + self.claimed_job_ids -= set(recent_completions[:-1000]) async def cleanup_redis_results(self, prefix): """Remove stale redis results by key""" @@ -525,6 +583,9 @@ def tune(self, job_batch_size=1000): with job_counter_lock: job_counter.value = job_counter.value - 1 + # 
Progress-aware polling - shorter intervals, smarter enqueuing + poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 5)) + # check for new jobs while consume_proc.is_alive(): enqueue_proc = Process( @@ -532,7 +593,7 @@ def tune(self, job_batch_size=1000): ) enqueue_proc.start() enqueue_proc.join() - time.sleep(10) + time.sleep(poll_interval) # Shorter, configurable polling consume_proc.join() @@ -663,6 +724,13 @@ async def parse_result(self, data): try: fin_json = data["result"]["ret"] context = data["result"]["context"] + + # Extract job ID from context to track completion + job_id = self.extract_job_id_from_context(context) + if job_id and job_id in self.claimed_job_ids: + self.completed_job_ids.add(job_id) + self.logger.info("Marked job %s as completed", job_id) + except KeyError as kerr: self.logger.error(kerr) return False @@ -677,6 +745,12 @@ async def parse_result(self, data): return True + def extract_job_id_from_context(self, context): + """Extract job ID from celery task context""" + # This needs to be implemented in the MIOpen subclass + # based on how job IDs are stored in the context + raise NotImplementedError("Subclass must implement job ID extraction") + def process_compile_results(self, session, fin_json, context): """Process result from fin_build worker""" raise NotImplementedError("Not implemented") From b30ba371777b098682c38481db8bec0e7d3dc29d Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Fri, 17 Oct 2025 10:35:18 +0000 Subject: [PATCH 05/10] perf(mituna_interface): optimize job state updates and improve enqueue reliability - Replace individual UPDATE queries with bulk UPDATE for job state changes - Add retry logic with configurable max attempts for database operations - Implement consecutive empty fetch tracking to prevent infinite loops - Add proper error handling and recovery for database session failures - Track enqueued jobs to prevent duplicate processing - Add configurable TUNA_MAX_EMPTY_FETCHES environment variable - Improve logging for better observability of enqueue process This optimization significantly reduces database round-trips when updating multiple job states and makes the enqueue process more resilient to transient failures. 
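For reference, the bulk state change collapses the per-row UPDATE loop into one statement
over the claimed job IDs. A minimal sketch of that idea follows; the helper name is
illustrative and not part of this patch, and it assumes the session accepts a raw SQL
string (as the existing gen_update_query path does) and that job IDs are integers
selected by this process:

    def bulk_set_state(session, table_name, job_ids, new_state):
        """Set state for all claimed jobs in a single UPDATE round-trip (sketch)."""
        if not job_ids:
            return
        # IDs come from our own SELECT ... FOR UPDATE, so an integer join is safe here
        id_str = ','.join(str(int(jid)) for jid in job_ids)
        session.execute(
            f"UPDATE {table_name} SET state = '{new_state}' WHERE id IN ({id_str})")
        session.commit()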
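The retry path below is a bounded retry whose sleep grows linearly with the attempt
number (retry_delay * (attempt + 1)). The same pattern, pulled into a generic helper
purely for illustration (name and defaults are assumptions, not code from this patch):

    import time

    def with_retries(op, max_retries=3, base_delay=5, logger=None):
        """Run op() up to max_retries times, sleeping longer after each failure (sketch)."""
        for attempt in range(max_retries):
            try:
                return op()
            except Exception as err:  # pylint: disable=broad-exception-caught
                if logger:
                    logger.warning('attempt %d/%d failed: %s', attempt + 1,
                                   max_retries, err)
                if attempt == max_retries - 1:
                    raise
                time.sleep(base_delay * (attempt + 1))

A caller would wrap the DbSession block, e.g. with_retries(lambda: enqueue_batch(...),
logger=self.logger) for a hypothetical enqueue_batch helper, so transient DB errors are
absorbed without duplicating the claim logic.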
--- tuna/mituna_interface.py | 146 +++++++++++++++++++++++++-------------- 1 file changed, 96 insertions(+), 50 deletions(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 51a11474..da17340b 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -289,15 +289,22 @@ def get_jobs( ids = [row.id for row in job_list] self.logger.info("%s jobs %s", find_state, ids) self.logger.info("Updating job state to %s", set_state) - for job in job_list: - job.state = set_state - if self.dbt is not None: - query: str = gen_update_query( - job, ["state"], self.dbt.job_table.__tablename__ - ) - else: - raise CustomError("DBTable must be set") + + # OPTIMIZATION: Use bulk UPDATE instead of individual updates + if self.dbt is not None: + id_str = ','.join(map(str, ids)) + query = f""" + UPDATE {self.dbt.job_table.__tablename__} + SET state = '{set_state}' + WHERE id IN ({id_str}) + """ session.execute(query) + + # Update local objects to reflect new state + for job in job_list: + job.state = set_state + else: + raise CustomError("DBTable must be set") session.commit() @@ -349,57 +356,96 @@ def celery_enqueue_call(self, context, q_name, task_id=False): raise NotImplementedError("Not implemented") def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs with machine-specific progress tracking""" + """Enqueue celery jobs with machine-specific progress tracking and error handling""" self.logger.info("Starting enqueue") current_batch_size = 0 - - with DbSession() as session: - while True: - # Check if we should enqueue more jobs based on OUR progress - if current_batch_size > 0: - if not self.should_enqueue_more_jobs(session, current_batch_size): - self.logger.info( - "Waiting for our current batch to progress before enqueuing more" + + max_retries = 3 + retry_delay = 5 # seconds + consecutive_empty_fetches = 0 + max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) + + while True: + # Retry loop for database operations + for attempt in range(max_retries): + try: + with DbSession() as session: + # Check if we should enqueue more jobs based on OUR progress + if current_batch_size > 0: + if not self.should_enqueue_more_jobs(session, current_batch_size): + self.logger.info( + "Waiting for our current batch to progress before enqueuing more" + ) + return # Exit gracefully + + # Get jobs from database + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # pylint: disable=no-member + self.args.session_id, # pylint: disable=no-member + job_batch_size, ) - break - # Get jobs from database - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, # pylint: disable=no-member - self.args.session_id, # pylint: disable=no-member - job_batch_size, - ) - - if not job_list: - self.logger.info("No more jobs available to enqueue") - break + if not job_list: + consecutive_empty_fetches += 1 + self.logger.info('No jobs found (attempt %d/%d)', + consecutive_empty_fetches, max_empty_fetches) + + if consecutive_empty_fetches >= max_empty_fetches: + self.logger.info('No new jobs after %d attempts. 
Exiting enqueue loop.', + max_empty_fetches) + return # Exit gracefully + + time.sleep(60) # Wait before next check + break # Break retry loop, continue main loop - # Track the jobs we just claimed - new_job_ids = {job.id for job in job_list} - self.claimed_job_ids.update(new_job_ids) + # Reset counter when jobs are found + consecutive_empty_fetches = 0 - self.logger.info("Claimed jobs: %s", list(new_job_ids)) + # Track the jobs we just claimed + new_job_ids = {job.id for job in job_list} + self.claimed_job_ids.update(new_job_ids) - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) + self.logger.info("Claimed jobs: %s", list(new_job_ids)) - # Process all jobs in this batch (remove the inner for loop) - context_list = self.get_context_list(session, job_list) - for context in context_list: - # calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - - current_batch_size = len(job_list) - self.logger.info( - "Job counter: %s, enqueued batch size: %s", - job_counter.value, - current_batch_size, - ) + with job_counter_lock: + job_counter.value = job_counter.value + len(job_list) + + # Process all jobs in this batch + context_list = self.get_context_list(session, job_list) + for context in context_list: + try: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + except Exception as enqueue_err: # pylint: disable=broad-exception-caught + self.logger.error('Failed to enqueue job: %s', enqueue_err) + # Continue with other jobs rather than failing completely + continue + + current_batch_size = len(job_list) + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + current_batch_size, + ) - # Cleanup old tracking data periodically - self.cleanup_completed_jobs() + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + break # Success, break retry loop + + except Exception as db_err: # pylint: disable=broad-exception-caught + self.logger.warning('Database error on attempt %d/%d: %s', + attempt + 1, max_retries, db_err) + if attempt < max_retries - 1: + time.sleep(retry_delay * (attempt + 1)) # Exponential backoff + else: + self.logger.error('Max retries exceeded for database operation. 
Exiting.') + raise + + # If we got here with no jobs, the consecutive_empty_fetches logic handled it + if not job_list: + continue def should_enqueue_more_jobs(self, session, current_batch_size): """Check if we should enqueue more jobs based on THIS instance's progress""" From 3001f9e21d20c4df1da2dfeb40e687be453a44e4 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 21 Oct 2025 07:44:09 +0000 Subject: [PATCH 06/10] used yapf formatter --- tests/test_celery.py | 1 + tuna/miopen/miopen_lib.py | 1551 ++++++++++++++++++------------------- tuna/mituna_interface.py | 1390 +++++++++++++++++---------------- 3 files changed, 1464 insertions(+), 1478 deletions(-) diff --git a/tests/test_celery.py b/tests/test_celery.py index 1e4aa01d..4e2a20e8 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -51,6 +51,7 @@ from tuna.miopen.worker.fin_utils import compose_config_obj, fin_job from tuna.miopen.utils.lib_helper import get_worker + @pytest.mark.asyncio async def test_celery_workers(): miopen = MIOpen() diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index f8be7c31..ab55d23c 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -76,474 +76,474 @@ class MIOpen(MITunaInterface): - """Class to support MIOpen specific tuning functionality""" - - # pylint: disable=too-many-public-methods - - def __init__(self): - super().__init__(library=Library.MIOPEN) - self.args = None - self.set_state = None - - def parse_args(self): - # pylint: disable=too-many-statements - """Function to parse arguments""" - parser = setup_arg_parser( - "Run Performance Tuning on a certain architecture", - [ - TunaArgs.ARCH, - TunaArgs.NUM_CU, - TunaArgs.VERSION, - TunaArgs.CONFIG_TYPE, - TunaArgs.SESSION_ID, - TunaArgs.MACHINES, - TunaArgs.REMOTE_MACHINE, - TunaArgs.LABEL, - TunaArgs.RESTART_MACHINE, - TunaArgs.DOCKER_NAME, - TunaArgs.SHUTDOWN_WORKERS, - TunaArgs.ENQUEUE_ONLY, - ], - ) - parser.add_argument( - "--find_mode", - dest="find_mode", - type=int, - default=1, - help="Set the MIOPEN_FIND_MODE environment variable for MIOpen", - choices=["1", "3"], - ) - parser.add_argument( - "--ticket", - dest="ticket", - type=str, - default=None, - help="Specify tuning ticket number", - ) - parser.add_argument( - "--solver_id", - type=int, - dest="solver_id", - default=None, - help="Specify solver_id. 
Use --list_solvers to see options", - ) - parser.add_argument( - "--dynamic_solvers_only", - dest="dynamic_solvers_only", - action="store_true", - default=False, - help="Only tune dynamic solvers.", - ) - parser.add_argument( - "-B", - "--blacklist", - dest="blacklist", - type=str, - default=None, - help="MIOpen blacklist algorithm, if multiple then comma separate", - ) - parser.add_argument( - "-i", - "--reset_interval", - type=int, - dest="reset_interval", - required=False, - help="Restart interval for job in hours.", - ) - parser.add_argument( - "--gpu_lim", - dest="gpu_lim", - type=int, - default=None, - help="Limit the number of gpu workers created by Tuna, index from 0", - ) - - parser.add_argument( - "-R", - "--rich_data", - dest="rich_data", - action="store_true", - default=False, - help="record intermediate parameter results from perf tuning", - ) - - subcommands = parser.add_subcommands(required=False) - subcommands.add_subcommand( - "import_configs", get_import_cfg_parser(), required=False - ) - - subcommands.add_subcommand("load_job", get_load_job_parser(), required=False) - - subcommands.add_subcommand("export_db", get_export_db_parser(), required=False) - - subcommands.add_subcommand( - "update_golden", get_update_golden_parser(), required=False - ) - - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--add_tables", - dest="add_tables", - action="store_true", - help="Add MIOpen library specific tables", - ) - - group.add_argument( - "--init_session", - action="store_true", - dest="init_session", - help="Set up a new tuning session.", - ) - group.add_argument( - "--fin_steps", - type=str, - dest="fin_steps", - help="Specify fin steps. Multiple steps should be comma separated.", - ) - group.add_argument( - "--list_solvers", - action="store_true", - dest="list_solvers", - help="List of solvers from the solver table", - ) - - # JD: implement the following two using fin_steps - group.add_argument( - "--update_solvers", - dest="update_solvers", - action="store_true", - help="Update the list of solvers in the database", - ) - group.add_argument( - "--update_applicability", - dest="update_applicability", - action="store_true", - help="Update the applicability table in the database", - ) - group.add_argument( - "-s", - "--status", - dest="check_status", - action="store_true", - default=False, - help="Check the status of machines", - ) - - group.add_argument( - "-e", - "--exec", - dest="execute_cmd", - type=str, - default=None, - help="execute on each machine", - ) - - self.args = parser.parse_args() - - if self.args.config_type is None: - self.args.config_type = ConfigType.convolution - - # overwritte common lib args with subcommand args value - if self.args.subcommand is not None: - self.overwrite_common_args() - - if len(sys.argv) == 1: - parser.print_help() - sys.exit(-1) - - if self.args.list_solvers: - print_solvers() - raise CustomError("Printing solvers...") - - if self.args.fin_steps and self.args.subcommand != "load_job": - self.check_fin_args(parser) - self.set_prefix() - - if self.args.find_mode is None and not ( - self.args.check_status or self.args.restart_machine or self.args.execute_cmd - ): - parser.error("find_mode must be specified for a tuning run") - - if self.args.blacklist: - self.check_blacklist(parser) - - args_check(self.args, parser) - - fin_session_steps = [ - "miopen_find_compile", - "miopen_find_eval", - "miopen_perf_compile", - "miopen_perf_eval", - "get_applicability", - "find_compile", - "find_eval", - ] - has_fin = False - if 
self.args.fin_steps: - has_fin = all(x in fin_session_steps for x in self.args.fin_steps) - - if (self.args.update_applicability or has_fin) and not self.args.session_id: - parser.error("session_id must be specified with this operation") - - self.dbt = MIOpenDBTables( - session_id=self.args.session_id, config_type=self.args.config_type - ) - self.update_operation() - - def set_prefix(self): - """Set redis key prefix""" - if isinstance(self.args.fin_steps, Iterable): - steps_str = ("-").join(x for x in self.args.fin_steps) - self.prefix = ( - f"d_{self.db_name}_sess_{self.args.session_id}_" f"{steps_str}" - ) - else: - steps_str = self.args.fin_steps[0] - self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" - - self.logger.info("redis prefix: %s", self.prefix) - - def overwrite_common_args(self): - """Overwrite common MIOpen_lib args with subcommand args""" - if self.args.subcommand is not None: - subc_dict = vars(self.args.get(self.args.subcommand)) - for sub_key in subc_dict: - if sub_key in vars(self.args): - self.args[sub_key] = subc_dict.get(sub_key) - - def check_fin_args(self, parser): - """! Helper function for fin args + """Class to support MIOpen specific tuning functionality""" + + # pylint: disable=too-many-public-methods + + def __init__(self): + super().__init__(library=Library.MIOPEN) + self.args = None + self.set_state = None + + def parse_args(self): + # pylint: disable=too-many-statements + """Function to parse arguments""" + parser = setup_arg_parser( + "Run Performance Tuning on a certain architecture", + [ + TunaArgs.ARCH, + TunaArgs.NUM_CU, + TunaArgs.VERSION, + TunaArgs.CONFIG_TYPE, + TunaArgs.SESSION_ID, + TunaArgs.MACHINES, + TunaArgs.REMOTE_MACHINE, + TunaArgs.LABEL, + TunaArgs.RESTART_MACHINE, + TunaArgs.DOCKER_NAME, + TunaArgs.SHUTDOWN_WORKERS, + TunaArgs.ENQUEUE_ONLY, + ], + ) + parser.add_argument( + "--find_mode", + dest="find_mode", + type=int, + default=1, + help="Set the MIOPEN_FIND_MODE environment variable for MIOpen", + choices=["1", "3"], + ) + parser.add_argument( + "--ticket", + dest="ticket", + type=str, + default=None, + help="Specify tuning ticket number", + ) + parser.add_argument( + "--solver_id", + type=int, + dest="solver_id", + default=None, + help="Specify solver_id. 
Use --list_solvers to see options", + ) + parser.add_argument( + "--dynamic_solvers_only", + dest="dynamic_solvers_only", + action="store_true", + default=False, + help="Only tune dynamic solvers.", + ) + parser.add_argument( + "-B", + "--blacklist", + dest="blacklist", + type=str, + default=None, + help="MIOpen blacklist algorithm, if multiple then comma separate", + ) + parser.add_argument( + "-i", + "--reset_interval", + type=int, + dest="reset_interval", + required=False, + help="Restart interval for job in hours.", + ) + parser.add_argument( + "--gpu_lim", + dest="gpu_lim", + type=int, + default=None, + help="Limit the number of gpu workers created by Tuna, index from 0", + ) + + parser.add_argument( + "-R", + "--rich_data", + dest="rich_data", + action="store_true", + default=False, + help="record intermediate parameter results from perf tuning", + ) + + subcommands = parser.add_subcommands(required=False) + subcommands.add_subcommand("import_configs", + get_import_cfg_parser(), + required=False) + + subcommands.add_subcommand("load_job", + get_load_job_parser(), + required=False) + + subcommands.add_subcommand("export_db", + get_export_db_parser(), + required=False) + + subcommands.add_subcommand("update_golden", + get_update_golden_parser(), + required=False) + + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--add_tables", + dest="add_tables", + action="store_true", + help="Add MIOpen library specific tables", + ) + + group.add_argument( + "--init_session", + action="store_true", + dest="init_session", + help="Set up a new tuning session.", + ) + group.add_argument( + "--fin_steps", + type=str, + dest="fin_steps", + help="Specify fin steps. Multiple steps should be comma separated.", + ) + group.add_argument( + "--list_solvers", + action="store_true", + dest="list_solvers", + help="List of solvers from the solver table", + ) + + # JD: implement the following two using fin_steps + group.add_argument( + "--update_solvers", + dest="update_solvers", + action="store_true", + help="Update the list of solvers in the database", + ) + group.add_argument( + "--update_applicability", + dest="update_applicability", + action="store_true", + help="Update the applicability table in the database", + ) + group.add_argument( + "-s", + "--status", + dest="check_status", + action="store_true", + default=False, + help="Check the status of machines", + ) + + group.add_argument( + "-e", + "--exec", + dest="execute_cmd", + type=str, + default=None, + help="execute on each machine", + ) + + self.args = parser.parse_args() + + if self.args.config_type is None: + self.args.config_type = ConfigType.convolution + + # overwritte common lib args with subcommand args value + if self.args.subcommand is not None: + self.overwrite_common_args() + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(-1) + + if self.args.list_solvers: + print_solvers() + raise CustomError("Printing solvers...") + + if self.args.fin_steps and self.args.subcommand != "load_job": + self.check_fin_args(parser) + self.set_prefix() + + if self.args.find_mode is None and not (self.args.check_status or + self.args.restart_machine or + self.args.execute_cmd): + parser.error("find_mode must be specified for a tuning run") + + if self.args.blacklist: + self.check_blacklist(parser) + + args_check(self.args, parser) + + fin_session_steps = [ + "miopen_find_compile", + "miopen_find_eval", + "miopen_perf_compile", + "miopen_perf_eval", + "get_applicability", + "find_compile", + "find_eval", + ] + has_fin = False + if 
self.args.fin_steps: + has_fin = all(x in fin_session_steps for x in self.args.fin_steps) + + if (self.args.update_applicability or has_fin) and not self.args.session_id: + parser.error("session_id must be specified with this operation") + + self.dbt = MIOpenDBTables(session_id=self.args.session_id, + config_type=self.args.config_type) + self.update_operation() + + def set_prefix(self): + """Set redis key prefix""" + if isinstance(self.args.fin_steps, Iterable): + steps_str = ("-").join(x for x in self.args.fin_steps) + self.prefix = (f"d_{self.db_name}_sess_{self.args.session_id}_" + f"{steps_str}") + else: + steps_str = self.args.fin_steps[0] + self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" + + self.logger.info("redis prefix: %s", self.prefix) + + def overwrite_common_args(self): + """Overwrite common MIOpen_lib args with subcommand args""" + if self.args.subcommand is not None: + subc_dict = vars(self.args.get(self.args.subcommand)) + for sub_key in subc_dict: + if sub_key in vars(self.args): + self.args[sub_key] = subc_dict.get(sub_key) + + def check_fin_args(self, parser): + """! Helper function for fin args @param parser The command line argument parser """ - valid_fin_steps = list(k for k in FinStep.__members__) - if "," in self.args.fin_steps: - parser.error("Multiple fin_steps currently not supported") - f_steps = self.args.fin_steps.split(",") - self.args.fin_steps = f_steps - for step in self.args.fin_steps: - if step not in valid_fin_steps: - parser.error(f"Supported fin steps are: {valid_fin_steps}") - assert len(self.args.fin_steps) == 1 - - def check_blacklist(self, parser): - """! Helper function + valid_fin_steps = list(k for k in FinStep.__members__) + if "," in self.args.fin_steps: + parser.error("Multiple fin_steps currently not supported") + f_steps = self.args.fin_steps.split(",") + self.args.fin_steps = f_steps + for step in self.args.fin_steps: + if step not in valid_fin_steps: + parser.error(f"Supported fin steps are: {valid_fin_steps}") + assert len(self.args.fin_steps) == 1 + + def check_blacklist(self, parser): + """! Helper function @param parser The command line argument parser @return ret Boolean value """ - self.args.blacklist = self.args.blacklist.split(",") - for sol in self.args.blacklist: - if sol not in MIOPEN_ALG_LIST: - parser.error("Incorrect blacklist value") + self.args.blacklist = self.args.blacklist.split(",") + for sol in self.args.blacklist: + if sol not in MIOPEN_ALG_LIST: + parser.error("Incorrect blacklist value") - def do_fin_work(self, gpu, f_vals): - """! Helper function to execute job independendent fin work + def do_fin_work(self, gpu, f_vals): + """! Helper function to execute job independendent fin work @param gpu Unique ID of the GPU @param f_vals Dict containing runtime information """ - kwargs = self.get_kwargs(gpu, f_vals) - fin_worker = FinClass(**kwargs) + kwargs = self.get_kwargs(gpu, f_vals) + fin_worker = FinClass(**kwargs) - if self.args.update_solvers: - if not fin_worker.get_solvers(): - self.logger.error("No solvers returned from Fin class") + if self.args.update_solvers: + if not fin_worker.get_solvers(): + self.logger.error("No solvers returned from Fin class") - return True + return True - def launch_worker(self, gpu_idx, f_vals, worker_lst): - """! Function to launch worker + def launch_worker(self, gpu_idx, f_vals, worker_lst): + """! 
Function to launch worker @param gpu_idx Unique ID of the GPU @param f_vals Dict containing runtime information @param worker_lst List containing worker instances @return ret Boolean value """ - # pylint: disable=too-many-branches - worker = None - kwargs = self.get_kwargs(gpu_idx, f_vals) - if self.args.update_applicability: - kwargs["fin_steps"] = ["applicability"] - worker = FinClass(**kwargs) - worker.start() - worker_lst.append(worker) - return True - - worker = FinClass(**kwargs) - ret = False - if self.args.check_status: - if not super().check_status( - worker, - f_vals["b_first"], - gpu_idx, - f_vals["machine"], - self.args.docker_name, - ): - ret = True - elif self.args.init_session: - Session().add_new_session(self.args, worker) - elif self.args.execute_cmd: - # JD: Move the worker.exec_command to machine - self.logger.info(self.args.execute_cmd) - _, _, _ = worker.exec_command(self.args.execute_cmd + " 2>&1 ") - - return ret - - def compose_worker_list(self, machines): - # pylint: disable=too-many-branches - """! Helper function to compose worker_list + # pylint: disable=too-many-branches + worker = None + kwargs = self.get_kwargs(gpu_idx, f_vals) + if self.args.update_applicability: + kwargs["fin_steps"] = ["applicability"] + worker = FinClass(**kwargs) + worker.start() + worker_lst.append(worker) + return True + + worker = FinClass(**kwargs) + ret = False + if self.args.check_status: + if not super().check_status( + worker, + f_vals["b_first"], + gpu_idx, + f_vals["machine"], + self.args.docker_name, + ): + ret = True + elif self.args.init_session: + Session().add_new_session(self.args, worker) + elif self.args.execute_cmd: + # JD: Move the worker.exec_command to machine + self.logger.info(self.args.execute_cmd) + _, _, _ = worker.exec_command(self.args.execute_cmd + " 2>&1 ") + + return ret + + def compose_worker_list(self, machines): + # pylint: disable=too-many-branches + """! Helper function to compose worker_list @param machines List of machines to execute on """ - worker_lst = [] - fin_work_done = False - for machine in machines: - if self.args.restart_machine: - machine.restart_server(wait=False) - continue - - # fin_steps should only contain one step - worker_ids = None - if self.args.fin_steps and "eval" in self.args.fin_steps[0]: - worker_ids = machine.get_avail_gpus() - if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): - worker_ids = range(self.args.gpu_lim) - else: - worker_ids = super().get_num_procs(machine) - - if self.args.update_applicability: - f_vals = super().get_f_vals(machine, [1]) - kwargs = self.get_kwargs(0, f_vals) - kwargs["fin_steps"] = ["applicability"] - worker = FinClass(**kwargs) - query = worker.query_cfgs(self.args.label) - cfg_rows = query.all() - len_rows = len(cfg_rows) - proc_lim = (len_rows + 99) / 100 - if 32 < proc_lim: - proc_lim = 32 - while len(worker_ids) > proc_lim: - worker_ids.pop() - - if len(worker_ids) == 0: - return None - - f_vals = super().get_f_vals(machine, worker_ids) - - if (self.args.update_solvers) and not fin_work_done: - self.do_fin_work(0, f_vals) - fin_work_done = True - break - - for gpu_idx in worker_ids: - self.logger.info("launch mid %u, proc %u", machine.id, gpu_idx) - if not self.launch_worker(gpu_idx, f_vals, worker_lst): - break - - return worker_lst - - def add_tables(self): - """! 
Function to create new DB tables + worker_lst = [] + fin_work_done = False + for machine in machines: + if self.args.restart_machine: + machine.restart_server(wait=False) + continue + + # fin_steps should only contain one step + worker_ids = None + if self.args.fin_steps and "eval" in self.args.fin_steps[0]: + worker_ids = machine.get_avail_gpus() + if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): + worker_ids = range(self.args.gpu_lim) + else: + worker_ids = super().get_num_procs(machine) + + if self.args.update_applicability: + f_vals = super().get_f_vals(machine, [1]) + kwargs = self.get_kwargs(0, f_vals) + kwargs["fin_steps"] = ["applicability"] + worker = FinClass(**kwargs) + query = worker.query_cfgs(self.args.label) + cfg_rows = query.all() + len_rows = len(cfg_rows) + proc_lim = (len_rows + 99) / 100 + if 32 < proc_lim: + proc_lim = 32 + while len(worker_ids) > proc_lim: + worker_ids.pop() + + if len(worker_ids) == 0: + return None + + f_vals = super().get_f_vals(machine, worker_ids) + + if (self.args.update_solvers) and not fin_work_done: + self.do_fin_work(0, f_vals) + fin_work_done = True + break + + for gpu_idx in worker_ids: + self.logger.info("launch mid %u, proc %u", machine.id, gpu_idx) + if not self.launch_worker(gpu_idx, f_vals, worker_lst): + break + + return worker_lst + + def add_tables(self): + """! Function to create new DB tables @return Bool """ - ret_t = create_tables(get_miopen_tables()) - self.logger.info("DB creation successful: %s", ret_t) - recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) - return True - - def run(self): - # pylint: disable=duplicate-code - """! Main function to launch library""" - res = None - if self.args is None: - self.parse_args() - - if self.args.add_tables: - self.add_tables() - return None - - if ( - self.args.subcommand is not None - and self.args.subcommand == "import_configs" - ): - run_import_configs(self.args.import_configs, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == "load_job": - run_load_job(self.args.load_job, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == "export_db": - run_export_db(self.args.export_db, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == "update_golden": - run_update_golden(self.args.update_golden, self.logger) - return None - - machines = load_machines(self.args) - res = self.compose_worker_list(machines) - return res - - def get_envmt(self): - """! Function to construct environment var""" - envmt = ["MIOPEN_LOG_LEVEL=4"] - - envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") - envmt.append("MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS=1") - - if self.args.find_mode: - envmt.append(f"MIOPEN_FIND_MODE={self.args.find_mode}") - - if self.args.blacklist: - bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) - for bk_var in bk_str.split(","): - envmt.append(bk_var) - - return envmt - - def get_kwargs(self, gpu_idx, f_vals, tuning=False): - """! Helper function to set up kwargs for worker instances + ret_t = create_tables(get_miopen_tables()) + self.logger.info("DB creation successful: %s", ret_t) + recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) + return True + + def run(self): + # pylint: disable=duplicate-code + """! 
Main function to launch library""" + res = None + if self.args is None: + self.parse_args() + + if self.args.add_tables: + self.add_tables() + return None + + if (self.args.subcommand is not None and + self.args.subcommand == "import_configs"): + run_import_configs(self.args.import_configs, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "load_job": + run_load_job(self.args.load_job, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "export_db": + run_export_db(self.args.export_db, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "update_golden": + run_update_golden(self.args.update_golden, self.logger) + return None + + machines = load_machines(self.args) + res = self.compose_worker_list(machines) + return res + + def get_envmt(self): + """! Function to construct environment var""" + envmt = ["MIOPEN_LOG_LEVEL=4"] + + envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") + envmt.append("MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS=1") + + if self.args.find_mode: + envmt.append(f"MIOPEN_FIND_MODE={self.args.find_mode}") + + if self.args.blacklist: + bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) + for bk_var in bk_str.split(","): + envmt.append(bk_var) + + return envmt + + def get_kwargs(self, gpu_idx, f_vals, tuning=False): + """! Helper function to set up kwargs for worker instances @param gpu_idx Unique ID of the GPU @param f_vals Dict containing runtime information @param tuning Boolean that indicates if kwargs are for a tuning step @return kwargs Dictionary """ - kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) - kwargs["fin_steps"] = self.args.fin_steps - kwargs["dynamic_solvers_only"] = self.args.dynamic_solvers_only - kwargs["config_type"] = self.args.config_type - kwargs["reset_interval"] = self.args.reset_interval + kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) + kwargs["fin_steps"] = self.args.fin_steps + kwargs["dynamic_solvers_only"] = self.args.dynamic_solvers_only + kwargs["config_type"] = self.args.config_type + kwargs["reset_interval"] = self.args.reset_interval - return kwargs + return kwargs - def get_job_list(self, session, find_state, claim_num): - """! Get list of jobs + def get_job_list(self, session, find_state, claim_num): + """! Get list of jobs @param session DB session @param find_state DB job state @param claim_num Number of DB jobs to pick up @return List of DB jobs """ - job_list = self.get_job_objs( - session, - find_state, - self.args.label, - self.dbt, - self.get_job_attr(), - claim_num, - self.args.fin_steps, - ) - - return job_list - - def get_job_objs( - self, - session: DbSession, - find_state: list, - label: str, - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None, - ) -> List[SimpleDict]: - """! Get list of job objects + job_list = self.get_job_objs( + session, + find_state, + self.args.label, + self.dbt, + self.get_job_attr(), + claim_num, + self.args.fin_steps, + ) + + return job_list + + def get_job_objs( + self, + session: DbSession, + find_state: list, + label: str, + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: + """! 
Get list of job objects @param session DB session @param find_state DB job state @param label DB job reason @@ -553,30 +553,29 @@ def get_job_objs( @param fin_steps List of MIFin steps @return List of DB jobs """ - entries: List[Tuple[SimpleDict, ...]] - conds: List[str] = [f"session={dbt.session.id}", "valid=1"] - - if label: - conds.append(f"reason='{label}'") - - conds.append(f"retries<{self.max_job_retries}") - conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") - - entries = self.compose_work_objs( - session, conds, dbt, job_attr, claim_num, fin_steps - ) - return entries - - def compose_work_objs( - self, - session: DbSession, - conds: List[str], - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None, - ) -> List[SimpleDict]: - """! Query a job list for update + entries: List[Tuple[SimpleDict, ...]] + conds: List[str] = [f"session={dbt.session.id}", "valid=1"] + + if label: + conds.append(f"reason='{label}'") + + conds.append(f"retries<{self.max_job_retries}") + conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") + + entries = self.compose_work_objs(session, conds, dbt, job_attr, claim_num, + fin_steps) + return entries + + def compose_work_objs( + self, + session: DbSession, + conds: List[str], + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: + """! Query a job list for update @param session DB session @param conds List of conditions for DB job WHERE clause @param dbt Class representing all DB tables associated with this class @@ -584,370 +583,358 @@ def compose_work_objs( @param fin_steps List of MIFin steps @return List of MIFin work objects """ - job_entries = [] - if fin_steps: - conds.append(f"fin_step like '%{fin_steps[0]}%'") - else: - conds.append("fin_step='not_fin'") - - cond_str = " AND ".join(conds) - if cond_str: - cond_str = f"WHERE {cond_str}" - if claim_num: - cond_str += ( - f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" - ) - else: - cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" - - job_entries = gen_select_objs( - session, job_attr, dbt.job_table.__tablename__, cond_str - ) - - return job_entries - - def compose_work_objs_fin( - self, session, job_entries, dbt - ) -> List[Tuple[SimpleDict, SimpleDict]]: - """! Return jobs for fin work + job_entries = [] + if fin_steps: + conds.append(f"fin_step like '%{fin_steps[0]}%'") + else: + conds.append("fin_step='not_fin'") + + cond_str = " AND ".join(conds) + if cond_str: + cond_str = f"WHERE {cond_str}" + if claim_num: + cond_str += ( + f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" + ) + else: + cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" + + job_entries = gen_select_objs(session, job_attr, + dbt.job_table.__tablename__, cond_str) + + return job_entries + + def compose_work_objs_fin(self, session, job_entries, + dbt) -> List[Tuple[SimpleDict, SimpleDict]]: + """! 
Return jobs for fin work @param session DB session @param job_entries List of DB jobs @param dbt Class representing all DB tables associated with this class @return ret Job tuple """ - ret = [] - - cfg_rel = { - key: { - "key": list(val.local_columns)[0].name, - "ftble": str(list(val.remote_side)[0]).split(".", maxsplit=1)[0], - "fkey": str(list(val.remote_side)[0]).split(".")[1], - } - for key, val in inspect(dbt.config_table).relationships.items() - } - - if job_entries: - id_str = ",".join({str(job.config) for job in job_entries}) - cfg_cond_str = f"where valid=1 and id in ({id_str})" - cfg_attr = [column.name for column in inspect(dbt.config_table).c] - cfg_entries = gen_select_objs( - session, cfg_attr, dbt.config_table.__tablename__, cfg_cond_str - ) + ret = [] + + cfg_rel = { + key: { + "key": list(val.local_columns)[0].name, + "ftble": str(list(val.remote_side)[0]).split(".", maxsplit=1)[0], + "fkey": str(list(val.remote_side)[0]).split(".")[1], + } for key, val in inspect(dbt.config_table).relationships.items() + } + + if job_entries: + id_str = ",".join({str(job.config) for job in job_entries}) + cfg_cond_str = f"where valid=1 and id in ({id_str})" + cfg_attr = [column.name for column in inspect(dbt.config_table).c] + cfg_entries = gen_select_objs(session, cfg_attr, + dbt.config_table.__tablename__, + cfg_cond_str) - cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries) + cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries) - cfg_map = {cfg.id: cfg for cfg in cfg_entries} + cfg_map = {cfg.id: cfg for cfg in cfg_entries} - for job in job_entries: - ret.append((job, cfg_map[job.config])) + for job in job_entries: + ret.append((job, cfg_map[job.config])) - return ret + return ret - def attach_tensors(self, session, cfg_rel, cfg_entries): - """! Attach tensor relationship information to config entries + def attach_tensors(self, session, cfg_rel, cfg_entries): + """! Attach tensor relationship information to config entries @param session DB session @param cfg_rel DB Config col value @param cfg_entries List of DB Config entries @return cfg_entries List of DB Config entries with attached tensors (foreign keys) """ - for key, val in cfg_rel.items(): - rel_attr = [ - column.name - for column in inspect(get_class_by_tablename(val["ftble"])).c - ] - val["fattr"] = rel_attr - - for cfg in cfg_entries: - for key, val in cfg_rel.items(): - rel_val = getattr(cfg, val["key"]) - rel_cond_str = f"where {val['fkey']}={rel_val}" - setattr( - cfg, - key, - gen_select_objs(session, val["fattr"], val["ftble"], rel_cond_str)[ - 0 - ], - ) - return cfg_entries - - # deprecated - def get_job_tables( - self, job_rows: List[Tuple[SimpleDict, ...]], job_attr: List[str] - ) -> List[SimpleDict]: - """Find job tables in query results""" - if has_attr_set(job_rows[0], job_attr): - job_tables: List[SimpleDict] = job_rows - else: - job_i: int = 0 - tble: SimpleDict - for i, tble in enumerate(job_rows[0]): - if has_attr_set(tble, job_attr): - job_i = i - break - job_tables = [row[job_i] for row in job_rows] - - return job_tables - - def update_operation(self): - """! 
Update the workers type that this library needs""" - if self.args.fin_steps: - if ( - "miopen_find_compile" in self.args.fin_steps - or "miopen_perf_compile" in self.args.fin_steps - ): - self.fetch_state.add("new") - self.set_state = "compile_start" - self.operation = Operation.COMPILE - elif ( - "miopen_find_eval" in self.args.fin_steps - or "miopen_perf_eval" in self.args.fin_steps - ): - self.fetch_state.add("new") - self.fetch_state.add("compiled") - self.set_state = "eval_start" - self.operation = Operation.EVAL - - if self.args.update_applicability: - self.fetch_state.add("new") - - def has_tunable_operation(self): - """! Check if its a tuning loop operation + for key, val in cfg_rel.items(): + rel_attr = [ + column.name + for column in inspect(get_class_by_tablename(val["ftble"])).c + ] + val["fattr"] = rel_attr + + for cfg in cfg_entries: + for key, val in cfg_rel.items(): + rel_val = getattr(cfg, val["key"]) + rel_cond_str = f"where {val['fkey']}={rel_val}" + setattr( + cfg, + key, + gen_select_objs(session, val["fattr"], val["ftble"], + rel_cond_str)[0], + ) + return cfg_entries + + # deprecated + def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]], + job_attr: List[str]) -> List[SimpleDict]: + """Find job tables in query results""" + if has_attr_set(job_rows[0], job_attr): + job_tables: List[SimpleDict] = job_rows + else: + job_i: int = 0 + tble: SimpleDict + for i, tble in enumerate(job_rows[0]): + if has_attr_set(tble, job_attr): + job_i = i + break + job_tables = [row[job_i] for row in job_rows] + + return job_tables + + def update_operation(self): + """! Update the workers type that this library needs""" + if self.args.fin_steps: + if ("miopen_find_compile" in self.args.fin_steps or + "miopen_perf_compile" in self.args.fin_steps): + self.fetch_state.add("new") + self.set_state = "compile_start" + self.operation = Operation.COMPILE + elif ("miopen_find_eval" in self.args.fin_steps or + "miopen_perf_eval" in self.args.fin_steps): + self.fetch_state.add("new") + self.fetch_state.add("compiled") + self.set_state = "eval_start" + self.operation = Operation.EVAL + + if self.args.update_applicability: + self.fetch_state.add("new") + + def has_tunable_operation(self): + """! Check if its a tuning loop operation @return Bool value that represents if operation is tuning """ - if self.args is None: - self.parse_args() - if self.args.subcommand and "load_job" in self.args.subcommand: - return False - if self.args.shutdown_workers: - return True - - return self.args.fin_steps and any( - s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS - ) - - @lru_cache(1) - def get_fdb_attr(self): - """! Get find_db table attrs + if self.args is None: + self.parse_args() + if self.args.subcommand and "load_job" in self.args.subcommand: + return False + if self.args.shutdown_workers: + return True + + return self.args.fin_steps and any( + s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS) + + @lru_cache(1) + def get_fdb_attr(self): + """! Get find_db table attrs @return fdb_attr find_db table attributes without timestamps """ - fdb_attr = None - fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] - fdb_attr.remove("insert_ts") - fdb_attr.remove("update_ts") - return fdb_attr - - @lru_cache(1) - def get_tuning_data_attr(self): - """! 
Get tuning_data table attrs + fdb_attr = None + fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] + fdb_attr.remove("insert_ts") + fdb_attr.remove("update_ts") + return fdb_attr + + @lru_cache(1) + def get_tuning_data_attr(self): + """! Get tuning_data table attrs @return tuning_data_attr tuning_data table attributes without timestamps """ - tuning_data_attr = None - tuning_data_attr = [ - column.name for column in inspect(self.dbt.tuning_data_table).c - ] - tuning_data_attr.remove("insert_ts") - tuning_data_attr.remove("update_ts") - return tuning_data_attr - - def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): - """! Return list of serialize jobs + tuning_data_attr = None + tuning_data_attr = [ + column.name for column in inspect(self.dbt.tuning_data_table).c + ] + tuning_data_attr.remove("insert_ts") + tuning_data_attr.remove("update_ts") + return tuning_data_attr + + def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): + """! Return list of serialize jobs @param session DB session @param batch_jobs List of DB jobs @return DB jobs, serialized """ - entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) - return serialize_chunk(entries) - - def build_context( - self, serialized_jobs: Tuple[SimpleDict, SimpleDict] - ) -> List[dict]: - """Build context list for enqueue job""" - context_list = [] - kwargs = self.get_context_items() - fdb_attr = self.get_fdb_attr() - tuning_data_attr = self.get_tuning_data_attr() - for job, config in serialized_jobs: - context = { - "job": job, - "config": config, - "operation": self.operation, - "arch": self.dbt.session.arch, - "num_cu": self.dbt.session.num_cu, - "kwargs": kwargs, - "rich_data": self.args.rich_data, - "fdb_attr": fdb_attr, - "tuning_data_attr": tuning_data_attr, - } - context_list.append(context) - - return context_list - - def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): - """! Enqueue job (context) for queue:q_name + entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) + return serialize_chunk(entries) + + def build_context( + self, serialized_jobs: Tuple[SimpleDict, SimpleDict]) -> List[dict]: + """Build context list for enqueue job""" + context_list = [] + kwargs = self.get_context_items() + fdb_attr = self.get_fdb_attr() + tuning_data_attr = self.get_tuning_data_attr() + for job, config in serialized_jobs: + context = { + "job": job, + "config": config, + "operation": self.operation, + "arch": self.dbt.session.arch, + "num_cu": self.dbt.session.num_cu, + "kwargs": kwargs, + "rich_data": self.args.rich_data, + "fdb_attr": fdb_attr, + "tuning_data_attr": tuning_data_attr, + } + context_list.append(context) + + return context_list + + def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): + """! 
Enqueue job (context) for queue:q_name @param context Context for Celery job @param q_name Custom Celery queue name @param task_id Custom Redis Key """ - # hacky way to get the Q_NAME to the task decorator for interpreter to decorate the - # function with correct q_name arg - # if import is moved to top it will result in circular imports - Q_NAME = q_name # pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name - from tuna.miopen.celery_tuning.celery_tasks import ( - celery_enqueue, - ) # pylint: disable=import-outside-toplevel - - return celery_enqueue.apply_async( - (context,), - task_id=("-").join([self.prefix, uuid()]), - queue=q_name, - reply_to=q_name, - ) - - def process_compile_results(self, session, fin_json, context): - """! Process result from fin_build worker + # hacky way to get the Q_NAME to the task decorator for interpreter to decorate the + # function with correct q_name arg + # if import is moved to top it will result in circular imports + Q_NAME = q_name # pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name + from tuna.miopen.celery_tuning.celery_tasks import ( + celery_enqueue,) # pylint: disable=import-outside-toplevel + + return celery_enqueue.apply_async( + (context,), + task_id=("-").join([self.prefix, uuid()]), + queue=q_name, + reply_to=q_name, + ) + + def process_compile_results(self, session, fin_json, context): + """! Process result from fin_build worker @param session DB session @param fin_json MIFin results for job @param context Context for Celery job @return Boolean value """ - job = SimpleDict(**context["job"]) - pending = [] - solver_id_map = get_solver_ids() - - failed_job = False - result_str = "" - status = None - try: - if fin_json: - if "success" in fin_json and fin_json["success"] is False: - status = [fin_json] - else: - if "miopen_find_compile_result" in fin_json: - status = process_fdb_w_kernels( - session, - fin_json, - copy.deepcopy(context), - self.dbt, - context["fdb_attr"], - pending, - ) - - elif "miopen_perf_compile_result" in fin_json: - status = process_pdb_compile( - session, fin_json, job, self.dbt, solver_id_map - ) - - success, result_str = get_fin_result(status) - failed_job = not success - - except (OperationalError, IntegrityError) as err: - self.logger.warning("FinBuild: Unable to update Database %s", err) - session.rollback() - failed_job = True - except DataError as err: - self.logger.warning( - "FinBuild: Invalid data, likely large workspace. DB Error: %s", err - ) - session.rollback() - failed_job = True - - if failed_job: - set_job_state(session, job, self.dbt, "errored", False, result=result_str) + job = SimpleDict(**context["job"]) + pending = [] + solver_id_map = get_solver_ids() + + failed_job = False + result_str = "" + status = None + try: + if fin_json: + if "success" in fin_json and fin_json["success"] is False: + status = [fin_json] else: - set_job_state(session, job, self.dbt, "compiled", False, result=result_str) - - return True + if "miopen_find_compile_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + ) - def process_eval_results(self, session, fin_json, context): - """! 
Process fin_json result + elif "miopen_perf_compile_result" in fin_json: + status = process_pdb_compile(session, fin_json, job, self.dbt, + solver_id_map) + + success, result_str = get_fin_result(status) + failed_job = not success + + except (OperationalError, IntegrityError) as err: + self.logger.warning("FinBuild: Unable to update Database %s", err) + session.rollback() + failed_job = True + except DataError as err: + self.logger.warning( + "FinBuild: Invalid data, likely large workspace. DB Error: %s", err) + session.rollback() + failed_job = True + + if failed_job: + set_job_state(session, job, self.dbt, "errored", False, result=result_str) + else: + set_job_state(session, + job, + self.dbt, + "compiled", + False, + result=result_str) + + return True + + def process_eval_results(self, session, fin_json, context): + """! Process fin_json result @param session DB session @param fin_json MIFin results for job @param context Context for Celery job @return Boolean value """ - job = SimpleDict(**context["job"]) - failed_job = True - result_str = "" - pending = [] - orig_state = "compiled" - - try: - if fin_json: - if "success" in fin_json and fin_json["success"] is False: - status = [fin_json] - else: - if "miopen_find_eval_result" in fin_json: - status = process_fdb_w_kernels( - session, - fin_json, - copy.deepcopy(context), - self.dbt, - context["fdb_attr"], - pending, - result_str="miopen_find_eval_result", - check_str="evaluated", - ) - elif "miopen_perf_eval_result" in fin_json: - status = process_fdb_w_kernels( - session, - fin_json, - copy.deepcopy(context), - self.dbt, - context["fdb_attr"], - pending, - result_str="miopen_perf_eval_result", - check_str="evaluated", - ) - if context["rich_data"]: - status = process_tuning_data( - session, - fin_json, - copy.deepcopy(context), - self.dbt, - context["tuning_data_attr"], - pending, - result_str="miopen_perf_eval_result", - check_str="evaluated", - ) - - success, result_str = get_fin_result(status) - failed_job = not success - - if failed_job: - if job.retries >= ( - MAX_ERRORED_JOB_RETRIES - 1 - ): # pylint: disable=no-member - self.logger.warning("max job retries exhausted, setting to errored") - set_job_state(session, job, self.dbt, "errored", result=result_str) - else: - self.logger.warning( - "resetting job state to %s, incrementing retries", orig_state - ) - set_job_state( - session, - job, - self.dbt, - orig_state, - increment_retries=True, - result=result_str, - ) - else: - self.logger.info("\n\n Setting job state to evaluated") - set_job_state(session, job, self.dbt, "evaluated", result=result_str) - clean_cache_table(self.dbt, job) - except (OperationalError, IntegrityError) as err: - self.logger.warning("FinBuild: Unable to update Database %s", err) - session.rollback() - set_job_state(session, job, self.dbt, "errored", result=result_str) - - return True - - def extract_job_id_from_context(self, context): - """Extract job ID from MIOpen celery task context""" - try: - # Extract job ID from the job context - return context.get("job", {}).get("id") - except (AttributeError, KeyError): - return None + job = SimpleDict(**context["job"]) + failed_job = True + result_str = "" + pending = [] + orig_state = "compiled" + + try: + if fin_json: + if "success" in fin_json and fin_json["success"] is False: + status = [fin_json] + else: + if "miopen_find_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + 
result_str="miopen_find_eval_result", + check_str="evaluated", + ) + elif "miopen_perf_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) + if context["rich_data"]: + status = process_tuning_data( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["tuning_data_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) + + success, result_str = get_fin_result(status) + failed_job = not success + + if failed_job: + if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): # pylint: disable=no-member + self.logger.warning("max job retries exhausted, setting to errored") + set_job_state(session, job, self.dbt, "errored", result=result_str) + else: + self.logger.warning("resetting job state to %s, incrementing retries", + orig_state) + set_job_state( + session, + job, + self.dbt, + orig_state, + increment_retries=True, + result=result_str, + ) + else: + self.logger.info("\n\n Setting job state to evaluated") + set_job_state(session, job, self.dbt, "evaluated", result=result_str) + clean_cache_table(self.dbt, job) + except (OperationalError, IntegrityError) as err: + self.logger.warning("FinBuild: Unable to update Database %s", err) + session.rollback() + set_job_state(session, job, self.dbt, "errored", result=result_str) + + return True + + def extract_job_id_from_context(self, context): + """Extract job ID from MIOpen celery task context""" + try: + # Extract job ID from the job context + return context.get("job", {}).get("id") + except (AttributeError, KeyError): + return None diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index da17340b..4c03ed0b 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -62,67 +62,69 @@ class MITunaInterface: # pylint:disable=too-many-instance-attributes,too-many-public-methods - """Interface class extended by libraries. The purpose of this class is to define + """Interface class extended by libraries. The purpose of this class is to define common functionalities.""" - def __init__(self, library=Library.MIOPEN) -> None: + def __init__(self, library=Library.MIOPEN) -> None: - self.self: Library = self + self.self: Library = self - self.logger: logging.Logger = setup_logger( - logger_name=library.value, add_streamhandler=True - ) - self.args: argparse.Namespace + self.logger: logging.Logger = setup_logger(logger_name=library.value, + add_streamhandler=True) + self.args: argparse.Namespace - self.fetch_state: set = set() - self.max_job_retries = 10 - self.dbt = None - self.operation = None - self.db_name = os.environ["TUNA_DB_NAME"] - self.prefix = None + self.fetch_state: set = set() + self.max_job_retries = 10 + self.dbt = None + self.operation = None + self.db_name = os.environ["TUNA_DB_NAME"] + self.prefix = None - # Track jobs claimed by this specific instance when in distributor mode - self.claimed_job_ids = set() - self.completed_job_ids = set() - # if less than 25% of the jobs are remaining, we can grab more jobs - self.progress_factor = 0.25 + # Track jobs claimed by this specific instance when in distributor mode + self.claimed_job_ids = set() + self.completed_job_ids = set() + # if less than 25% of the jobs are remaining, we can grab more jobs + self.progress_factor = 0.25 - def check_docker(self, worker: WorkerInterface, dockername="miopentuna") -> bool: - """! 
Checking for docker + def check_docker(self, + worker: WorkerInterface, + dockername="miopentuna") -> bool: + """! Checking for docker @param worker The worker interface instance @param dockername The name of the docker """ - out2: ChannelFile - _, out2, _ = worker.exec_command("sudo docker info") - while not out2.channel.exit_status_ready(): - self.logger.warning(out2.readline()) - if out2.channel.exit_status > 0: - self.logger.warning("docker not installed or failed to run with sudo .... ") - return False - - out: StringIO = StringIO() - line: Optional[str] = None - _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}") - for line in out.readlines(): - if line is not None: - if line.find(dockername) != -1: - self.logger.warning("%s docker image exists", dockername) - return True - if line is None: - self.logger.warning("%s docker image does not exist", dockername) - return False - - return False - - def check_status( - self, - worker: WorkerInterface, - b_first: int, - gpu_idx: int, - machine: Machine, - dockername: str = "miopentuna", - ) -> bool: - """! Function to check gpu_status + out2: ChannelFile + _, out2, _ = worker.exec_command("sudo docker info") + while not out2.channel.exit_status_ready(): + self.logger.warning(out2.readline()) + if out2.channel.exit_status > 0: + self.logger.warning( + "docker not installed or failed to run with sudo .... ") + return False + + out: StringIO = StringIO() + line: Optional[str] = None + _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}") + for line in out.readlines(): + if line is not None: + if line.find(dockername) != -1: + self.logger.warning("%s docker image exists", dockername) + return True + if line is None: + self.logger.warning("%s docker image does not exist", dockername) + return False + + return False + + def check_status( + self, + worker: WorkerInterface, + b_first: int, + gpu_idx: int, + machine: Machine, + dockername: str = "miopentuna", + ) -> bool: + """! 
Function to check gpu_status @param worker The worker interface instance @param b_first Flag to keep track of visited GPU @param gpu_idx Unique ID of the GPU @@ -130,677 +132,673 @@ def check_status( @param dockername The name of the docker """ - if machine.chk_gpu_status(worker.gpu_id): - self.logger.info( - "Machine: (%s, %u) GPU_ID: %u OK", - machine.hostname, - machine.port, - gpu_idx, - ) - else: - self.logger.info( - "Machine: (%s, %u) GPU_ID: %u ERROR", - machine.hostname, - machine.port, - gpu_idx, - ) - - if not b_first: - return False - b_first = False - _, out, _ = worker.exec_command("docker info") - while not out.channel.exit_status_ready(): - pass - - if out.channel.exit_status > 0: - self.check_docker(worker, dockername) + if machine.chk_gpu_status(worker.gpu_id): + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u OK", + machine.hostname, + machine.port, + gpu_idx, + ) + else: + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u ERROR", + machine.hostname, + machine.port, + gpu_idx, + ) + + if not b_first: + return False + b_first = False + _, out, _ = worker.exec_command("docker info") + while not out.channel.exit_status_ready(): + pass + + if out.channel.exit_status > 0: + self.check_docker(worker, dockername) + else: + _, out, _ = worker.exec_command(f"docker images | grep {dockername}") + line: Optional[str] = None + for line in out.readlines(): + if line is not None: + if line.find(dockername) != -1: + self.logger.warning("%s docker image exists", dockername) + break else: - _, out, _ = worker.exec_command(f"docker images | grep {dockername}") - line: Optional[str] = None - for line in out.readlines(): - if line is not None: - if line.find(dockername) != -1: - self.logger.warning("%s docker image exists", dockername) - break - else: - self.logger.warning("%s docker image does not exist", dockername) - - return True - - def add_tables(self) -> bool: - """Add self specific tables""" - return self.add_tables() - - def get_num_procs(self, machine: Machine) -> List: - """Determine number of processes by compute capacity""" - worker_ids: List = [] - num_procs: int - env: Dict[str, Any] - env = get_env_vars() - if env["slurm_cpus"] > 0: - num_procs = int(env["slurm_cpus"]) - else: - num_procs = int(machine.get_num_cpus() * 0.6) - - worker_ids = list(range(num_procs)) - - if len(worker_ids) == 0: - self.logger.error("num_procs must be bigger than zero to launch worker") - self.logger.error("Cannot launch worker on machine: %s", machine.id) - worker_ids = [] - - return worker_ids - - def get_f_vals( - self, machine: Machine, worker_ids: range, tuning=False - ) -> Dict[str, Any]: - # pylint:disable=unused-argument - """Determine kwargs for worker_interface""" - f_vals: Dict[str, Any] - f_vals = self.compose_f_vals(machine) - f_vals["envmt"] = self.get_envmt() - - if not tuning: - f_vals["num_procs"] = Value("i", len(worker_ids)) - - return f_vals - - def get_envmt(self): - """Get runtime envmt""" - raise NotImplementedError("Not implemented") - - def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: - """! 
Compose dict for WorkerInterface constructor + self.logger.warning("%s docker image does not exist", dockername) + + return True + + def add_tables(self) -> bool: + """Add self specific tables""" + return self.add_tables() + + def get_num_procs(self, machine: Machine) -> List: + """Determine number of processes by compute capacity""" + worker_ids: List = [] + num_procs: int + env: Dict[str, Any] + env = get_env_vars() + if env["slurm_cpus"] > 0: + num_procs = int(env["slurm_cpus"]) + else: + num_procs = int(machine.get_num_cpus() * 0.6) + + worker_ids = list(range(num_procs)) + + if len(worker_ids) == 0: + self.logger.error("num_procs must be bigger than zero to launch worker") + self.logger.error("Cannot launch worker on machine: %s", machine.id) + worker_ids = [] + + return worker_ids + + def get_f_vals(self, + machine: Machine, + worker_ids: range, + tuning=False) -> Dict[str, Any]: + # pylint:disable=unused-argument + """Determine kwargs for worker_interface""" + f_vals: Dict[str, Any] + f_vals = self.compose_f_vals(machine) + f_vals["envmt"] = self.get_envmt() + + if not tuning: + f_vals["num_procs"] = Value("i", len(worker_ids)) + + return f_vals + + def get_envmt(self): + """Get runtime envmt""" + raise NotImplementedError("Not implemented") + + def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: + """! Compose dict for WorkerInterface constructor @param args The command line arguments @param machine Machine instance """ - f_vals: Dict[str, Any] = {} - f_vals["b_first"] = True - - # adding non-serializable obj when not running through celery - if not tuning: - f_vals["machine"] = machine - f_vals["bar_lock"] = Lock() - # multiprocess queue for jobs, shared on machine - f_vals["job_queue"] = mpQueue() - f_vals["job_queue_lock"] = Lock() - f_vals["end_jobs"] = Value("i", 0) - - return f_vals - - def get_kwargs( - self, gpu_idx: int, f_vals: Dict[str, Any], tuning=False - ) -> Dict[str, Any]: - """! Helper function to set up kwargs for worker instances + f_vals: Dict[str, Any] = {} + f_vals["b_first"] = True + + # adding non-serializable obj when not running through celery + if not tuning: + f_vals["machine"] = machine + f_vals["bar_lock"] = Lock() + # multiprocess queue for jobs, shared on machine + f_vals["job_queue"] = mpQueue() + f_vals["job_queue_lock"] = Lock() + f_vals["end_jobs"] = Value("i", 0) + + return f_vals + + def get_kwargs(self, + gpu_idx: int, + f_vals: Dict[str, Any], + tuning=False) -> Dict[str, Any]: + """! 
Helper function to set up kwargs for worker instances @param gpu_idx Unique ID of the GPU @param f_vals Dict containing runtime information """ - envmt: Dict[str, Any] = f_vals["envmt"].copy() - kwargs: Dict[str, Any] = {} - - kwargs = { - "gpu_id": gpu_idx, - "envmt": envmt, - "label": self.args.label, - "docker_name": self.args.docker_name, - "session_id": self.args.session_id, - } - - # adding non-serializable obj when not running through celery - if not tuning: - kwargs["machine"] = f_vals["machine"] - kwargs["job_queue"] = f_vals["job_queue"] - kwargs["job_queue_lock"] = f_vals["job_queue_lock"] - kwargs["num_procs"] = f_vals["num_procs"] - kwargs["bar_lock"] = f_vals["bar_lock"] - kwargs["end_jobs"] = f_vals["end_jobs"] - kwargs["job_queue"] = f_vals["job_queue"] - kwargs["job_queue_lock"] = f_vals["job_queue_lock"] - - return kwargs - - def get_job_list(self, session, find_state, claim_num): - """Get list of jobs""" - raise NotImplementedError("Not implemented") - - def get_jobs( - self, - session: DbSession, - find_state: List[str], - set_state: str, - session_id: int, - claim_num: int = None, - no_update=False, - ): - """Interface function to get jobs based on session and find_state""" - # job_rows: List[SimpleDict] - ids: list - row: SimpleDict - - self.logger.info("Fetching DB rows...") - job_list = self.get_job_list(session, find_state, claim_num) - - if not self.check_jobs_found(job_list, find_state, session_id): - return [] - - if no_update: - return job_list - - ids = [row.id for row in job_list] - self.logger.info("%s jobs %s", find_state, ids) - self.logger.info("Updating job state to %s", set_state) - - # OPTIMIZATION: Use bulk UPDATE instead of individual updates - if self.dbt is not None: - id_str = ','.join(map(str, ids)) - query = f""" + envmt: Dict[str, Any] = f_vals["envmt"].copy() + kwargs: Dict[str, Any] = {} + + kwargs = { + "gpu_id": gpu_idx, + "envmt": envmt, + "label": self.args.label, + "docker_name": self.args.docker_name, + "session_id": self.args.session_id, + } + + # adding non-serializable obj when not running through celery + if not tuning: + kwargs["machine"] = f_vals["machine"] + kwargs["job_queue"] = f_vals["job_queue"] + kwargs["job_queue_lock"] = f_vals["job_queue_lock"] + kwargs["num_procs"] = f_vals["num_procs"] + kwargs["bar_lock"] = f_vals["bar_lock"] + kwargs["end_jobs"] = f_vals["end_jobs"] + kwargs["job_queue"] = f_vals["job_queue"] + kwargs["job_queue_lock"] = f_vals["job_queue_lock"] + + return kwargs + + def get_job_list(self, session, find_state, claim_num): + """Get list of jobs""" + raise NotImplementedError("Not implemented") + + def get_jobs( + self, + session: DbSession, + find_state: List[str], + set_state: str, + session_id: int, + claim_num: int = None, + no_update=False, + ): + """Interface function to get jobs based on session and find_state""" + # job_rows: List[SimpleDict] + ids: list + row: SimpleDict + + self.logger.info("Fetching DB rows...") + job_list = self.get_job_list(session, find_state, claim_num) + + if not self.check_jobs_found(job_list, find_state, session_id): + return [] + + if no_update: + return job_list + + ids = [row.id for row in job_list] + self.logger.info("%s jobs %s", find_state, ids) + self.logger.info("Updating job state to %s", set_state) + + # OPTIMIZATION: Use bulk UPDATE instead of individual updates + if self.dbt is not None: + id_str = ','.join(map(str, ids)) + query = f""" UPDATE {self.dbt.job_table.__tablename__} SET state = '{set_state}' WHERE id IN ({id_str}) """ - session.execute(query) - 
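      # The bulk UPDATE above changes all claimed rows in a single statement; those rows
      # were selected earlier with FOR UPDATE SKIP LOCKED, so other tuning instances skip
      # them and the state change only becomes visible at the commit() further below.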
- # Update local objects to reflect new state - for job in job_list: - job.state = set_state - else: - raise CustomError("DBTable must be set") - - session.commit() - - return job_list - - def shutdown_workers(self): - """Shutdown all active celery workers regardless of queue""" - return stop_active_workers() - - def cancel_consumer(self, queue): - """Cancel consumers for queue""" + session.execute(query) + + # Update local objects to reflect new state + for job in job_list: + job.state = set_state + else: + raise CustomError("DBTable must be set") + + session.commit() + + return job_list + + def shutdown_workers(self): + """Shutdown all active celery workers regardless of queue""" + return stop_active_workers() + + def cancel_consumer(self, queue): + """Cancel consumers for queue""" + try: + cmd = ( + f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" + ) + subp = subprocess.Popen( # pylint: disable=consider-using-with + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + universal_newlines=True, + ) + + # filter the workers by session id + sess_str = "sess_" + queue.split("_")[-1] + stdout, _ = subp.stdout, subp.stderr + while True: + line = stdout.readline() + if not line: + break + # stop workers that were feeding from this queue + if "->" in line and sess_str in line: + hostname = line.split("->")[1].split()[0].split(":")[0] + stop_named_worker(hostname) + + except Exception as exp: # pylint: disable=broad-exception-caught + self.logger.warning( + "Error occurred trying to cancel consumer for queue: %s ", queue) + self.logger.warning(exp) + return False + + self.logger.info("Sucessfully cancelled consumer for queue: %s", queue) + + return True + + def celery_enqueue_call(self, context, q_name, task_id=False): + """Wrapper function for celery enqueue func""" + raise NotImplementedError("Not implemented") + + def enqueue_jobs(self, job_counter, job_batch_size, q_name): + """Enqueue celery jobs with machine-specific progress tracking and error handling""" + self.logger.info("Starting enqueue") + current_batch_size = 0 + + max_retries = 3 + retry_delay = 5 # seconds + consecutive_empty_fetches = 0 + max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) + + while True: + # Retry loop for database operations + for attempt in range(max_retries): try: - cmd = ( - f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" - ) - subp = subprocess.Popen( # pylint: disable=consider-using-with - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - shell=True, - universal_newlines=True, - ) - - # filter the workers by session id - sess_str = "sess_" + queue.split("_")[-1] - stdout, _ = subp.stdout, subp.stderr - while True: - line = stdout.readline() - if not line: - break - # stop workers that were feeding from this queue - if "->" in line and sess_str in line: - hostname = line.split("->")[1].split()[0].split(":")[0] - stop_named_worker(hostname) - - except Exception as exp: # pylint: disable=broad-exception-caught - self.logger.warning( - "Error occurred trying to cancel consumer for queue: %s ", queue + with DbSession() as session: + # Check if we should enqueue more jobs based on OUR progress + if current_batch_size > 0: + if not self.should_enqueue_more_jobs(session, current_batch_size): + self.logger.info( + "Waiting for our current batch to progress before enqueuing more" + ) + return # Exit gracefully + + # Get jobs from database + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # 
pylint: disable=no-member + self.args.session_id, # pylint: disable=no-member + job_batch_size, ) - self.logger.warning(exp) - return False - - self.logger.info("Sucessfully cancelled consumer for queue: %s", queue) - - return True - def celery_enqueue_call(self, context, q_name, task_id=False): - """Wrapper function for celery enqueue func""" - raise NotImplementedError("Not implemented") - - def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs with machine-specific progress tracking and error handling""" - self.logger.info("Starting enqueue") - current_batch_size = 0 - - max_retries = 3 - retry_delay = 5 # seconds - consecutive_empty_fetches = 0 - max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) - - while True: - # Retry loop for database operations - for attempt in range(max_retries): - try: - with DbSession() as session: - # Check if we should enqueue more jobs based on OUR progress - if current_batch_size > 0: - if not self.should_enqueue_more_jobs(session, current_batch_size): - self.logger.info( - "Waiting for our current batch to progress before enqueuing more" - ) - return # Exit gracefully - - # Get jobs from database - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, # pylint: disable=no-member - self.args.session_id, # pylint: disable=no-member - job_batch_size, - ) - - if not job_list: - consecutive_empty_fetches += 1 - self.logger.info('No jobs found (attempt %d/%d)', - consecutive_empty_fetches, max_empty_fetches) - - if consecutive_empty_fetches >= max_empty_fetches: - self.logger.info('No new jobs after %d attempts. Exiting enqueue loop.', - max_empty_fetches) - return # Exit gracefully - - time.sleep(60) # Wait before next check - break # Break retry loop, continue main loop - - # Reset counter when jobs are found - consecutive_empty_fetches = 0 - - # Track the jobs we just claimed - new_job_ids = {job.id for job in job_list} - self.claimed_job_ids.update(new_job_ids) - - self.logger.info("Claimed jobs: %s", list(new_job_ids)) - - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) - - # Process all jobs in this batch - context_list = self.get_context_list(session, job_list) - for context in context_list: - try: - # calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - except Exception as enqueue_err: # pylint: disable=broad-exception-caught - self.logger.error('Failed to enqueue job: %s', enqueue_err) - # Continue with other jobs rather than failing completely - continue - - current_batch_size = len(job_list) - self.logger.info( - "Job counter: %s, enqueued batch size: %s", - job_counter.value, - current_batch_size, - ) - - # Cleanup old tracking data periodically - self.cleanup_completed_jobs() - break # Success, break retry loop - - except Exception as db_err: # pylint: disable=broad-exception-caught - self.logger.warning('Database error on attempt %d/%d: %s', - attempt + 1, max_retries, db_err) - if attempt < max_retries - 1: - time.sleep(retry_delay * (attempt + 1)) # Exponential backoff - else: - self.logger.error('Max retries exceeded for database operation. 
Exiting.') - raise - - # If we got here with no jobs, the consecutive_empty_fetches logic handled it if not job_list: - continue + consecutive_empty_fetches += 1 + self.logger.info('No jobs found (attempt %d/%d)', + consecutive_empty_fetches, max_empty_fetches) - def should_enqueue_more_jobs(self, session, current_batch_size): - """Check if we should enqueue more jobs based on THIS instance's progress""" - # Count only jobs claimed by this machine instance - our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) - - # Allow enqueuing when less than 25% of our claimed jobs are still in progress - progress_threshold = current_batch_size * self.progress_factor - - self.logger.info( - "Our jobs in progress: %d, completed: %d, threshold: %d", - our_in_progress_count, - len(self.completed_job_ids), - progress_threshold, - ) - - return our_in_progress_count < progress_threshold - - def cleanup_completed_jobs(self): - """Periodically clean up old job tracking data""" - # Keep sets from growing indefinitely - max_tracking_size = 10000 - if len(self.completed_job_ids) > max_tracking_size: - # Keep only the most recent completions - recent_completions = list(self.completed_job_ids)[-5000:] - self.completed_job_ids = set(recent_completions) - - # Remove old claimed jobs that are completed - self.claimed_job_ids -= set(recent_completions[:-1000]) - - async def cleanup_redis_results(self, prefix): - """Remove stale redis results by key""" - backend_port, backend_host = get_backend_env() - redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") - - keys = [] - cursor = "0" - if prefix: - # a prefix is necessary when the need to different results in redis based on operation - # withough a prefix the redis key defaults to: "celery-task-meta-" - # with a prefix the key will look like: "celery-task-meta--" - # the prefix can be applied when filtering the redis keys as bellow - cursor, results = await redis.scan(cursor, match=f"*{prefix}*") - else: - # no prefix, match any key - cursor, results = await redis.scan(cursor, match="*") - keys.extend(results) - self.logger.info("Found %s old results", len(results)) - for key in keys: - try: - await redis.delete(key) - except aioredis.exceptions.ResponseError as red_err: - self.logger.error(red_err) - self.logger.info(key.decode("utf-8")) - continue + if consecutive_empty_fetches >= max_empty_fetches: + self.logger.info( + 'No new jobs after %d attempts. 
Exiting enqueue loop.', + max_empty_fetches) + return # Exit gracefully - self.logger.info("Done removing old redis results for prefix: %s", prefix) + time.sleep(60) # Wait before next check + break # Break retry loop, continue main loop - return True - - async def consume(self, job_counter, prefix): - """Retrieve celery results from redis db""" - - backend_port, backend_host = get_backend_env() - redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") - - while job_counter.value > 0: - cursor = "0" - keys = [] - while cursor != 0: - if prefix: - # a prefix is necessary when the need to different results in redis based on operation - # withough a prefix the redis key defaults to: "celery-task-meta-" - # with a prefix the key will look like: "celery-task-meta--" - # the prefix can be applied when filtering the redis keys as bellow - cursor, results = await redis.scan(cursor, match=f"*{prefix}*") - else: - # no prefix, match any key - cursor, results = await redis.scan(cursor, match="*") - keys.extend(results) - self.logger.info("Found %s results", len(results)) - for key in keys: - try: - data = await redis.get(key) - if data: - _ = await self.parse_result(data.decode("utf-8")) - await redis.delete(key) - with job_counter_lock: - job_counter.value = job_counter.value - 1 - except aioredis.exceptions.ResponseError as red_err: - self.logger.error(red_err) - self.logger.info(key.decode("utf-8")) - - await asyncio.sleep(1) - self.logger.info("Job counter reached 0") - await redis.close() - - return True - - def prep_tuning(self): - """Prep env for tuning start""" - cmd = None - subp_list = [] - q_name = None - if self.operation == Operation.COMPILE: - q_name = get_q_name(self, op_compile=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" # pylint: disable=line-too-long - else: - q_name = get_q_name(self, op_eval=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" # pylint: disable=line-too-long - - self.logger.info("celery Q name: %s", q_name) - if not self.args.enqueue_only: - try: - self.logger.info("Launching celery workers for queue %s", q_name) - subp_list = launch_celery_worker(self.operation, cmd, self.args, True) - self.logger.info("Done launching celery workers") - if not subp_list: - raise CustomError("Could not launch celery worker") - except kombu.exceptions.OperationalError as k_err: - self.logger.error("Redis error ocurred: %s", k_err) - return False - else: - purge_queue([q_name]) - - return q_name, subp_list - - # pylint: disable=too-many-locals - def tune(self, job_batch_size=1000): - """tuning loop to spin out celery tasks""" - - if self.args.shutdown_workers: - self.logger.info("Shutting down all celery workers") - stop_active_workers() - return True + # Reset counter when jobs are found + consecutive_empty_fetches = 0 - try: - q_name, subp_list = self.prep_tuning() - except CustomError as verr: - self.logger.error(verr) - return False - - try: - # if enqueue_only is False, we launch the celery workers - if not self.args.enqueue_only: - for subp in subp_list: - subp.wait() - return True - except KeyboardInterrupt: - for subp in subp_list: - subp.kill() - return False - - start = time.time() - - # set job count to 1 until first job fetch is finished - job_counter = Value("i", 1) - try: - enqueue_proc = Process( - target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name] - ) - 
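            # The tuning loop is split across processes: this enqueue process claims DB
            # jobs and pushes Celery tasks, a cleanup process removes stale Redis results
            # for the session prefix, and a consume process collects task results.
            # job_counter was seeded with 1 above so the consumer keeps polling until the
            # first fetch finishes and that hold is released.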
# Start enqueue proc - enqueue_proc.start() + # Track the jobs we just claimed + new_job_ids = {job.id for job in job_list} + self.claimed_job_ids.update(new_job_ids) - # cleanup old results - cleanup_proc = Process( - target=self.async_wrap, args=(self.cleanup_redis_results, self.prefix) - ) - cleanup_proc.start() - cleanup_proc.join() - - # start async consume thread, blocking - consume_proc = Process( - target=self.async_wrap, args=(self.consume, job_counter, self.prefix) - ) - self.logger.info("Starting consume thread") - consume_proc.start() + self.logger.info("Claimed jobs: %s", list(new_job_ids)) - enqueue_proc.join() - # enqueue finished first fetch, remove hold on job_counter with job_counter_lock: - job_counter.value = job_counter.value - 1 - - # Progress-aware polling - shorter intervals, smarter enqueuing - poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 5)) - - # check for new jobs - while consume_proc.is_alive(): - enqueue_proc = Process( - target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name] - ) - enqueue_proc.start() - enqueue_proc.join() - time.sleep(poll_interval) # Shorter, configurable polling - - consume_proc.join() - - except ( - KeyboardInterrupt, - Exception, - ) as exp: # pylint: disable=broad-exception-caught - self.logger.error("Error ocurred %s", exp) - purge_queue([q_name]) - self.cancel_consumer(q_name) - self.reset_job_state_on_ctrl_c() - with job_counter_lock: - job_counter.value = 0 + job_counter.value = job_counter.value + len(job_list) + + # Process all jobs in this batch + context_list = self.get_context_list(session, job_list) + for context in context_list: + try: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + except Exception as enqueue_err: # pylint: disable=broad-exception-caught + self.logger.error('Failed to enqueue job: %s', enqueue_err) + # Continue with other jobs rather than failing completely + continue - self.cancel_consumer(q_name) - end = time.time() - self.logger.info( - "Took {:0>8} to tune".format( # pylint: disable=consider-using-f-string - str(timedelta(seconds=end - start)) + current_batch_size = len(job_list) + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + current_batch_size, ) - ) - - return True - - async def async_callback(self, async_func, *args): - """Wrapper function to await on async function""" - await async_func(*args) - def async_wrap(self, async_func, *args): - """Run async function""" + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + break # Success, break retry loop + + except Exception as db_err: # pylint: disable=broad-exception-caught + self.logger.warning('Database error on attempt %d/%d: %s', + attempt + 1, max_retries, db_err) + if attempt < max_retries - 1: + time.sleep(retry_delay * (attempt + 1)) # Exponential backoff + else: + self.logger.error( + 'Max retries exceeded for database operation. 
Exiting.') + raise + + # If we got here with no jobs, the consecutive_empty_fetches logic handled it + if not job_list: + continue + + def should_enqueue_more_jobs(self, session, current_batch_size): + """Check if we should enqueue more jobs based on THIS instance's progress""" + # Count only jobs claimed by this machine instance + our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) + + # Allow enqueuing when less than 25% of our claimed jobs are still in progress + progress_threshold = current_batch_size * self.progress_factor + + self.logger.info( + "Our jobs in progress: %d, completed: %d, threshold: %d", + our_in_progress_count, + len(self.completed_job_ids), + progress_threshold, + ) + + return our_in_progress_count < progress_threshold + + def cleanup_completed_jobs(self): + """Periodically clean up old job tracking data""" + # Keep sets from growing indefinitely + max_tracking_size = 10000 + if len(self.completed_job_ids) > max_tracking_size: + # Keep only the most recent completions + recent_completions = list(self.completed_job_ids)[-5000:] + self.completed_job_ids = set(recent_completions) + + # Remove old claimed jobs that are completed + self.claimed_job_ids -= set(recent_completions[:-1000]) + + async def cleanup_redis_results(self, prefix): + """Remove stale redis results by key""" + backend_port, backend_host = get_backend_env() + redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") + + keys = [] + cursor = "0" + if prefix: + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow + cursor, results = await redis.scan(cursor, match=f"*{prefix}*") + else: + # no prefix, match any key + cursor, results = await redis.scan(cursor, match="*") + keys.extend(results) + self.logger.info("Found %s old results", len(results)) + for key in keys: + try: + await redis.delete(key) + except aioredis.exceptions.ResponseError as red_err: + self.logger.error(red_err) + self.logger.info(key.decode("utf-8")) + continue + + self.logger.info("Done removing old redis results for prefix: %s", prefix) + + return True + + async def consume(self, job_counter, prefix): + """Retrieve celery results from redis db""" + + backend_port, backend_host = get_backend_env() + redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") + + while job_counter.value > 0: + cursor = "0" + keys = [] + while cursor != 0: + if prefix: + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow + cursor, results = await redis.scan(cursor, match=f"*{prefix}*") + else: + # no prefix, match any key + cursor, results = await redis.scan(cursor, match="*") + keys.extend(results) + self.logger.info("Found %s results", len(results)) + for key in keys: try: - asyncio.run(self.async_callback(async_func, *args)) - except KeyboardInterrupt: - self.logger.warning("Keyboard interrupt caught, terminating") - - def reset_job_state_on_ctrl_c(self): - """Reset job state for jobs in flight""" - temp_obj = SimpleDict() - temp_obj.session_id = self.args.session_id # pylint: disable=invalid-name - attribs = ["state"] - 
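        # State 1 is assumed here to be the initial "new" job state; the in-flight jobs
        # matched below (state 16 for compile, 12 for eval) are flipped back to it so
        # they can be re-claimed on the next run.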
temp_obj.state = 1 - - self.logger.info("Resetting job state in DB for in flight jobs") - - if self.operation == Operation.COMPILE: - state = 16 - elif self.operation == Operation.EVAL: - state = 12 - - query = gen_update_query( - temp_obj, - attribs, - self.dbt.job_table.__tablename__, - [("session", self.args.session_id), ("state", state)], - ) - with DbSession() as session: - - # pylint: disable=duplicate-code - def callback() -> bool: - session.execute(query) - session.commit() - return True - - # pylint: enable=duplicate-code - - assert session_retry(session, callback, lambda x: x(), self.logger) - self.logger.info("Sucessfully reset job state") - return True - + data = await redis.get(key) + if data: + _ = await self.parse_result(data.decode("utf-8")) + await redis.delete(key) + with job_counter_lock: + job_counter.value = job_counter.value - 1 + except aioredis.exceptions.ResponseError as red_err: + self.logger.error(red_err) + self.logger.info(key.decode("utf-8")) + + await asyncio.sleep(1) + self.logger.info("Job counter reached 0") + await redis.close() + + return True + + def prep_tuning(self): + """Prep env for tuning start""" + cmd = None + subp_list = [] + q_name = None + if self.operation == Operation.COMPILE: + q_name = get_q_name(self, op_compile=True) + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" # pylint: disable=line-too-long + else: + q_name = get_q_name(self, op_eval=True) + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" # pylint: disable=line-too-long + + self.logger.info("celery Q name: %s", q_name) + if not self.args.enqueue_only: + try: + self.logger.info("Launching celery workers for queue %s", q_name) + subp_list = launch_celery_worker(self.operation, cmd, self.args, True) + self.logger.info("Done launching celery workers") + if not subp_list: + raise CustomError("Could not launch celery worker") + except kombu.exceptions.OperationalError as k_err: + self.logger.error("Redis error ocurred: %s", k_err) return False - - def has_tunable_operation(self): - """Check if current operation is a tuning operation""" - raise NotImplementedError("Not implemented") - - def get_job_attr(self): - """Get job attr for row selection""" - job_attr: List[str] = None - try: - job_attr = [column.name for column in inspect(self.dbt.job_table).c] - job_attr.remove("insert_ts") - job_attr.remove("update_ts") - except NoInspectionAvailable as error: - self.logger.warning("Ignoring error for init_session: %s", error) - return job_attr - - def check_jobs_found( - self, job_rows: List[SimpleDict], find_state: List[Any], session_id: int - ) -> bool: - """check for end of jobs""" - if not job_rows: - # we are done - self.logger.warning("No %s jobs found, session %s", find_state, session_id) - return False + else: + purge_queue([q_name]) + + return q_name, subp_list + + # pylint: disable=too-many-locals + def tune(self, job_batch_size=1000): + """tuning loop to spin out celery tasks""" + + if self.args.shutdown_workers: + self.logger.info("Shutting down all celery workers") + stop_active_workers() + return True + + try: + q_name, subp_list = self.prep_tuning() + except CustomError as verr: + self.logger.error(verr) + return False + + try: + # if enqueue_only is False, we launch the celery workers + if not self.args.enqueue_only: + for subp in subp_list: + subp.wait() return True + except KeyboardInterrupt: + for subp in 
subp_list: + subp.kill() + return False + + start = time.time() + + # set job count to 1 until first job fetch is finished + job_counter = Value("i", 1) + try: + enqueue_proc = Process(target=self.enqueue_jobs, + args=[job_counter, job_batch_size, q_name]) + # Start enqueue proc + enqueue_proc.start() + + # cleanup old results + cleanup_proc = Process(target=self.async_wrap, + args=(self.cleanup_redis_results, self.prefix)) + cleanup_proc.start() + cleanup_proc.join() + + # start async consume thread, blocking + consume_proc = Process(target=self.async_wrap, + args=(self.consume, job_counter, self.prefix)) + self.logger.info("Starting consume thread") + consume_proc.start() + + enqueue_proc.join() + # enqueue finished first fetch, remove hold on job_counter + with job_counter_lock: + job_counter.value = job_counter.value - 1 + + # Progress-aware polling - shorter intervals, smarter enqueuing + poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 5)) + + # check for new jobs + while consume_proc.is_alive(): + enqueue_proc = Process(target=self.enqueue_jobs, + args=[job_counter, job_batch_size, q_name]) + enqueue_proc.start() + enqueue_proc.join() + time.sleep(poll_interval) # Shorter, configurable polling + + consume_proc.join() + + except ( + KeyboardInterrupt, + Exception, + ) as exp: # pylint: disable=broad-exception-caught + self.logger.error("Error ocurred %s", exp) + purge_queue([q_name]) + self.cancel_consumer(q_name) + self.reset_job_state_on_ctrl_c() + with job_counter_lock: + job_counter.value = 0 + + self.cancel_consumer(q_name) + end = time.time() + self.logger.info("Took {:0>8} to tune".format( # pylint: disable=consider-using-f-string + str(timedelta(seconds=end - start)))) + + return True + + async def async_callback(self, async_func, *args): + """Wrapper function to await on async function""" + await async_func(*args) + + def async_wrap(self, async_func, *args): + """Run async function""" + try: + asyncio.run(self.async_callback(async_func, *args)) + except KeyboardInterrupt: + self.logger.warning("Keyboard interrupt caught, terminating") + + def reset_job_state_on_ctrl_c(self): + """Reset job state for jobs in flight""" + temp_obj = SimpleDict() + temp_obj.session_id = self.args.session_id # pylint: disable=invalid-name + attribs = ["state"] + temp_obj.state = 1 + + self.logger.info("Resetting job state in DB for in flight jobs") + + if self.operation == Operation.COMPILE: + state = 16 + elif self.operation == Operation.EVAL: + state = 12 + + query = gen_update_query( + temp_obj, + attribs, + self.dbt.job_table.__tablename__, + [("session", self.args.session_id), ("state", state)], + ) + with DbSession() as session: + + # pylint: disable=duplicate-code + def callback() -> bool: + session.execute(query) + session.commit() + return True + + # pylint: enable=duplicate-code + + assert session_retry(session, callback, lambda x: x(), self.logger) + self.logger.info("Sucessfully reset job state") + return True + + return False + + def has_tunable_operation(self): + """Check if current operation is a tuning operation""" + raise NotImplementedError("Not implemented") + + def get_job_attr(self): + """Get job attr for row selection""" + job_attr: List[str] = None + try: + job_attr = [column.name for column in inspect(self.dbt.job_table).c] + job_attr.remove("insert_ts") + job_attr.remove("update_ts") + except NoInspectionAvailable as error: + self.logger.warning("Ignoring error for init_session: %s", error) + return job_attr + + def check_jobs_found(self, job_rows: 
List[SimpleDict], find_state: List[Any], + session_id: int) -> bool: + """check for end of jobs""" + if not job_rows: + # we are done + self.logger.warning("No %s jobs found, session %s", find_state, + session_id) + return False + return True + + @lru_cache(1) + def get_context_items(self): + """Helper function to get items for celery job context""" + kwargs = None + f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True) + kwargs = self.get_kwargs(0, f_vals, tuning=True) + return kwargs + + def serialize_jobs(self, session, batch_jobs): + """Return list of serialize jobs""" + raise NotImplementedError("Not implemented") + + def build_context(self, serialized_jobs): + """Build context list for enqueue job""" + raise NotImplementedError("Not implemented") + + def get_context_list(self, session, batch_jobs): + """Return list of jobs (context) for celery queue""" + + context_list: List[dict] = None + serialized_jobs = self.serialize_jobs(session, batch_jobs) + # build context for each celery task + context_list = self.build_context(serialized_jobs) + + return context_list + + async def parse_result(self, data): + """Function callback for celery async jobs to store results""" + data = json.loads(data) + + with DbSession() as session: + try: + fin_json = data["result"]["ret"] + context = data["result"]["context"] + + # Extract job ID from context to track completion + job_id = self.extract_job_id_from_context(context) + if job_id and job_id in self.claimed_job_ids: + self.completed_job_ids.add(job_id) + self.logger.info("Marked job %s as completed", job_id) + + except KeyError as kerr: + self.logger.error(kerr) + return False - @lru_cache(1) - def get_context_items(self): - """Helper function to get items for celery job context""" - kwargs = None - f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True) - kwargs = self.get_kwargs(0, f_vals, tuning=True) - return kwargs - - def serialize_jobs(self, session, batch_jobs): - """Return list of serialize jobs""" - raise NotImplementedError("Not implemented") - - def build_context(self, serialized_jobs): - """Build context list for enqueue job""" - raise NotImplementedError("Not implemented") - - def get_context_list(self, session, batch_jobs): - """Return list of jobs (context) for celery queue""" - - context_list: List[dict] = None - serialized_jobs = self.serialize_jobs(session, batch_jobs) - # build context for each celery task - context_list = self.build_context(serialized_jobs) - - return context_list - - async def parse_result(self, data): - """Function callback for celery async jobs to store results""" - data = json.loads(data) - - with DbSession() as session: - try: - fin_json = data["result"]["ret"] - context = data["result"]["context"] - - # Extract job ID from context to track completion - job_id = self.extract_job_id_from_context(context) - if job_id and job_id in self.claimed_job_ids: - self.completed_job_ids.add(job_id) - self.logger.info("Marked job %s as completed", job_id) - - except KeyError as kerr: - self.logger.error(kerr) - return False - - self.logger.info("Parsing: %s", fin_json) - if self.operation == Operation.COMPILE: - self.process_compile_results(session, fin_json, context) - elif self.operation == Operation.EVAL: - self.process_eval_results(session, fin_json, context) - else: - raise CustomError("Unsupported tuning operation") - - return True - - def extract_job_id_from_context(self, context): - """Extract job ID from celery task context""" - # This needs to be implemented in the 
MIOpen subclass - # based on how job IDs are stored in the context - raise NotImplementedError("Subclass must implement job ID extraction") - - def process_compile_results(self, session, fin_json, context): - """Process result from fin_build worker""" - raise NotImplementedError("Not implemented") - - def process_eval_results(self, session, fin_json, context): - """Process fin_json result""" - raise NotImplementedError("Not implemented") + self.logger.info("Parsing: %s", fin_json) + if self.operation == Operation.COMPILE: + self.process_compile_results(session, fin_json, context) + elif self.operation == Operation.EVAL: + self.process_eval_results(session, fin_json, context) + else: + raise CustomError("Unsupported tuning operation") + + return True + + def extract_job_id_from_context(self, context): + """Extract job ID from celery task context""" + # This needs to be implemented in the MIOpen subclass + # based on how job IDs are stored in the context + raise NotImplementedError("Subclass must implement job ID extraction") + + def process_compile_results(self, session, fin_json, context): + """Process result from fin_build worker""" + raise NotImplementedError("Not implemented") + + def process_eval_results(self, session, fin_json, context): + """Process fin_json result""" + raise NotImplementedError("Not implemented") From a19f1bb642e61446e798cb0191bc848c07860376 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Thu, 6 Nov 2025 04:39:15 -0600 Subject: [PATCH 07/10] changed default base image and properly passed it through to second docker build stage --- Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index cc493e68..0a9ce625 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG OSDB_BKC_VERSION= ARG HASVER=${ROCMVERSION:+$ROCMVERSION} ARG HASVER=${HASVER:-$OSDB_BKC_VERSION} -ARG BASEIMAGE=rocm/miopen:ci_3708da +ARG BASEIMAGE=rocm/miopen:ci_7c45f0 ARG UBUNTU=ubuntu:22.04 #use UBUNTU with rocm version set @@ -18,6 +18,8 @@ FROM $USEIMAGE as dtuna-ver-0 #args before from are wiped ARG ROCMVERSION= ARG OSDB_BKC_VERSION= +# pass through baseimage for later use +ARG BASEIMAGE RUN test -d /opt/rocm*; \ if [ $? 
-eq 0 ] ; then \ From c396afcba4e36b1e54483efc76cb0dc7bb0dd8e3 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Fri, 7 Nov 2025 03:48:42 -0600 Subject: [PATCH 08/10] changed to newer version of clang-format (12 no longer available) --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 0a9ce625..459f1da4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -73,7 +73,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --all apt-utils \ build-essential \ cmake \ - clang-format-12 \ + clang-format \ curl \ doxygen \ gdb \ From 0a8fbbaafab3d1720941e5943b54c782af9f3c19 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Fri, 7 Nov 2025 08:40:17 -0600 Subject: [PATCH 09/10] changed branch of MIOpen --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 459f1da4..9e7dd23e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -127,8 +127,8 @@ RUN git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries WORKDIR $MIOPEN_DIR RUN git sparse-checkout set projects/miopen # not sure what this commit is, using latest develop for now -# ARG MIOPEN_BRANCH=4940cf3ec -ARG MIOPEN_BRANCH=develop +ARG MIOPEN_BRANCH=5564e20238 +# ARG MIOPEN_BRANCH=develop RUN git pull && git checkout $MIOPEN_BRANCH ARG PREFIX=/opt/rocm From 50eace55a15bd8464ec571f0374b2d6d5de7e705 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 11:06:17 -0600 Subject: [PATCH 10/10] Fixed bug where distributor would not grab new batch when old batch runs out. --- tuna/mituna_interface.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 4c03ed0b..f383a106 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -379,7 +379,7 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): self.logger.info( "Waiting for our current batch to progress before enqueuing more" ) - return # Exit gracefully + break # Exit retry loop, will wait and check again # Get jobs from database job_list = self.get_jobs( @@ -448,6 +448,12 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): 'Max retries exceeded for database operation. Exiting.') raise + # If we broke out because we're waiting for progress, sleep before next check + if current_batch_size > 0 and (not job_list or not self.should_enqueue_more_jobs(None, current_batch_size)): + self.logger.info("Sleeping 60s before checking for more jobs...") + time.sleep(60) + continue + # If we got here with no jobs, the consecutive_empty_fetches logic handled it if not job_list: continue
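
For reference, the enqueue loop that PATCH 10/10 converges on can be reduced to the minimal sketch below. The helpers fetch_batch, enqueue_job and completed_ids are hypothetical stand-ins for the database fetch, the celery enqueue call and the consumed-results set -- they are not Tuna APIs -- and the 25% progress factor and 60-second poll interval mirror the values in the diff but are illustrative only.

    import time


    def run_distributor(fetch_batch, enqueue_job, completed_ids,
                        batch_size=1000, progress_factor=0.25, poll_interval=60):
      """Hypothetical distributor loop: keep claiming and enqueuing batches
      until fetch_batch() runs dry.

      fetch_batch(n)   -> list of job ids claimed from the DB (stand-in)
      enqueue_job(jid) -> push one job onto the work queue (stand-in)
      completed_ids()  -> set of job ids whose results were already consumed
      """
      claimed = set()
      last_batch_size = 0

      while True:
        in_flight = len(claimed - completed_ids())

        # If the previous batch is still mostly in flight, sleep and re-check
        # instead of returning; returning here is the bug this patch fixes.
        if last_batch_size and in_flight >= last_batch_size * progress_factor:
          time.sleep(poll_interval)
          continue

        batch = fetch_batch(batch_size)
        if not batch:
          break  # no unclaimed jobs left for this session

        claimed.update(batch)
        last_batch_size = len(batch)
        for job_id in batch:
          enqueue_job(job_id)

The point of the change is that when the previously claimed batch is still in flight the loop now waits and re-checks rather than exiting, so the distributor always comes back to grab a new batch once the old one drains.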