From a8a51bf826b70ea12494651ea7dcaffe47cae8a2 Mon Sep 17 00:00:00 2001 From: Adam Novak Date: Thu, 12 Feb 2026 15:33:47 -0800 Subject: [PATCH 1/5] Implement exponential backoff delay for job retries --- docs/running/cliOptions.rst | 8 ++- src/toil/common.py | 4 ++ src/toil/job.py | 73 +++++++++++++++++--------- src/toil/jobStores/abstractJobStore.py | 2 +- src/toil/leader.py | 48 ++++++++++++++--- src/toil/options/common.py | 20 +++++++ src/toil/test/src/fileStoreTest.py | 2 + src/toil/test/src/jobTest.py | 48 ++++++++++++++++- src/toil/worker.py | 2 +- 9 files changed, 170 insertions(+), 37 deletions(-) diff --git a/docs/running/cliOptions.rst b/docs/running/cliOptions.rst index 06da91a007..847e81d614 100644 --- a/docs/running/cliOptions.rst +++ b/docs/running/cliOptions.rst @@ -475,8 +475,14 @@ systems have issues!). --retryCount INT Number of times to retry a failing job before giving up and labeling job failed. default=1 + --retryBackoffSeconds FLOAT + Number of seconds to wait when first retrying a job. + default=2 + --retryBackoffFactor FLOAT + Factor to increase retry backof time by for each + additional retry. default=3 --stopOnFirstFailure BOOL - Stop the workflow at the first complete job failure. + Stop the workflow at the first complete job failure. --enableUnlimitedPreemptibleRetries If set, preemptible failures (or any failure due to an instance getting unexpectedly terminated) will not count diff --git a/src/toil/common.py b/src/toil/common.py index 07beb40c5c..25639f448b 100644 --- a/src/toil/common.py +++ b/src/toil/common.py @@ -218,6 +218,8 @@ class Config: # Retrying/rescuing jobs retryCount: int + retry_backoff_seconds: float + retry_backoff_factor: float stop_on_first_failure: bool enableUnlimitedPreemptibleRetries: bool doubleMem: bool @@ -397,6 +399,8 @@ def set_option(option_name: str, old_names: list[str] | None = None) -> None: # Retrying/rescuing jobs set_option("retryCount") + set_option("retry_backoff_seconds") + set_option("retry_backoff_factor") set_option("stop_on_first_failure") set_option("enableUnlimitedPreemptibleRetries") set_option("doubleMem") diff --git a/src/toil/job.py b/src/toil/job.py index 43cef05631..fb6ef6432e 100644 --- a/src/toil/job.py +++ b/src/toil/job.py @@ -871,6 +871,10 @@ def makeString(x: str | bytes | None) -> str: # default value for this workflow execution. self._remainingTryCount = None + # The number of seconds to back off before retry. + # Gets increased each time the job fails, and reset at workflow restart. + self._retry_backoff_seconds = None + # Holds FileStore FileIDs of the files that should be seen as deleted, # as part of a transaction with the writing of this version of the job # to the job store. Used to journal deletions of files and recover from @@ -1342,6 +1346,23 @@ def onRegistration(self, jobStore: AbstractJobStore) -> None: :param jobStore: The job store we are being placed into """ + def chargeRetry(self) -> None: + """ + Charge the job one of its remaining retries. + + Manages exponential backoff of retried jobs. + + On completion, self.retry_backoff_seconds will be the time to wait + before the next retry. 
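+
+        For example, with the default --retryBackoffSeconds of 2 and
+        --retryBackoffFactor of 3, successive retries of the same job are
+        delayed by 2, 6, 18, ... seconds.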
+ """ + self.remainingTryCount = max(0, self.remainingTryCount - 1) + if self._retry_backoff_seconds is None: + # This was the first retry + self._retry_backoff_seconds = self._config.retry_backoff_seconds + else: + self._retry_backoff_seconds *= self._config.retry_backoff_factor + + def setupJobAfterFailure( self, exit_status: int | None = None, @@ -1380,7 +1401,7 @@ def setupJobAfterFailure( self.jobStoreID, ) else: - self.remainingTryCount = max(0, self.remainingTryCount - 1) + self.chargeRetry() logger.warning( "Due to failure we are reducing the remaining try count of job %s with ID %s to %s", self, @@ -1442,16 +1463,29 @@ def remainingTryCount(self): def remainingTryCount(self, val): self._remainingTryCount = val - def clearRemainingTryCount(self) -> bool: + @property + def retry_backoff_seconds(self) -> float: + """ + Get the number of seconds to wait before retrying this job. + """ + if self._retry_backoff_seconds is None: + # We need to distinguish the state of not yet having charged a retry. + # The config has the value to use after charging for the first retry. + return 0.0 + else: + return self._retry_backoff_seconds + + def resetRetries(self) -> bool: """ - Clear remainingTryCount and set it back to its default value. + Clear retry system values back to the workflow defaults. :returns: True if a modification to the JobDescription was made, and False otherwise. """ - if self._remainingTryCount is not None: + if self._remainingTryCount is not None or self._retryBackoff is not None: # We had a value stored self._remainingTryCount = None + self._retryBackoff = None return True else: # No change needed @@ -2474,10 +2508,7 @@ class Runner: """Used to setup and run Toil workflow.""" @staticmethod - def getDefaultArgumentParser( - jobstore_as_flag: bool = False, - config_option: str | None = None, - ) -> ArgParser: + def getDefaultArgumentParser(jobstore_as_flag: bool = False) -> ArgParser: """ Get argument parser with added toil workflow options. @@ -2485,7 +2516,6 @@ def getDefaultArgumentParser( workflow. :param jobstore_as_flag: make the job store option a --jobStore flag instead of a required jobStore positional argument. - :param config_option: If set, use this string for the Toil --config option instead of "config". :returns: The argument parser used by a toil workflow with added Toil options. """ parser = ArgParser(formatter_class=ArgumentDefaultsHelpFormatter) @@ -2494,9 +2524,7 @@ def getDefaultArgumentParser( @staticmethod def getDefaultOptions( - jobStore: StrPath | None = None, - jobstore_as_flag: bool = False, - config_option: str | None = None, + jobStore: StrPath | None = None, jobstore_as_flag: bool = False ) -> Namespace: """ Get default options for a toil workflow. @@ -2504,7 +2532,6 @@ def getDefaultOptions( :param jobStore: A string describing the jobStore \ for the workflow. :param jobstore_as_flag: make the job store option a --jobStore flag instead of a required jobStore positional argument. - :param config_option: If set, use this string for the Toil --config option instead of "config". :returns: The options used by a toil workflow. """ # setting jobstore_as_flag to True allows the user to declare the jobstore in the config file instead @@ -2514,8 +2541,7 @@ def getDefaultOptions( "to False!" 
) parser = Job.Runner.getDefaultArgumentParser( - jobstore_as_flag=jobstore_as_flag, - config_option=config_option, + jobstore_as_flag=jobstore_as_flag ) arguments = [] if jobstore_as_flag and jobStore is not None: @@ -2528,7 +2554,6 @@ def getDefaultOptions( def addToilOptions( parser: OptionParser | ArgumentParser, jobstore_as_flag: bool = False, - config_option: str | None = None, ) -> None: """ Adds the default toil options to an :mod:`optparse` or :mod:`argparse` @@ -2543,13 +2568,8 @@ def addToilOptions( :param parser: Options object to add toil options to. :param jobstore_as_flag: make the job store option a --jobStore flag instead of a required jobStore positional argument. - :param config_option: If set, use this string for the Toil --config option instead of "config". """ - addOptions( - parser, - jobstore_as_flag=jobstore_as_flag, - config_option=config_option, - ) + addOptions(parser, jobstore_as_flag=jobstore_as_flag) @staticmethod def startToil(job: Job, options) -> Any: @@ -2738,7 +2758,7 @@ def _fulfillPromises(self, returnValues, jobStore): # File may be gone if the job is a service being re-run and the accessing job is # already complete. if jobStore.file_exists(promiseFileStoreID): - logger.debug( + logger.info( "Resolve promise %s from %s with a %s", promiseFileStoreID, self, @@ -4204,6 +4224,11 @@ class Promise: A set of IDs of files containing promised values when we know we won't need them anymore """ + resolving = True + """ + Set to False to disable promise resolution for debugging. + """ + def __init__(self, job: Job, path: Any): """ Initialize this promise. @@ -4241,7 +4266,7 @@ def __new__(cls, *args) -> Promise: raise RuntimeError( "Cannot instantiate promise. Invalid number of arguments given (Expected 2)." ) - if isinstance(args[0], Job): + if not cls.resolving or isinstance(args[0], Job): # Regular instantiation when promise is created, before it is being pickled return super().__new__(cls) else: diff --git a/src/toil/jobStores/abstractJobStore.py b/src/toil/jobStores/abstractJobStore.py index 4279240770..c01865e071 100644 --- a/src/toil/jobStores/abstractJobStore.py +++ b/src/toil/jobStores/abstractJobStore.py @@ -885,7 +885,7 @@ def replaceFlagsIfNeeded(serviceJobDescription: JobDescription) -> None: changed[0] = True # Reset the try count of the JobDescription so it will use the default. - changed[0] |= jobDescription.clearRemainingTryCount() + changed[0] |= jobDescription.resetRetries() # This cleans the old log file which may # have been left if the job is being retried after a failure. diff --git a/src/toil/leader.py b/src/toil/leader.py index 68fbe8e39e..ba8fa2210c 100644 --- a/src/toil/leader.py +++ b/src/toil/leader.py @@ -16,6 +16,7 @@ import base64 import glob +import heapq import logging import os import pickle @@ -123,6 +124,13 @@ def __init__( # state change information about jobs. self.toilState = ToilState(self.jobStore) + # We keep a min-heap of jobs that shouldn't be reissued until a + # particular time in seconds since epoch, as (time, JobDescription) + # pairs. When jobs fail, we can put them here to let them back off + # before being issued again. These jobs will have already been "ready" + # and just need to be issued at the right time. + self._waiting_jobs: list[tuple[float, JobDescription]] = [] + # Message bus messages need to go to the given file. # Keep a reference to the return value so the listener stays alive. 
self._message_subscription = self.toilState.bus.connect_output_file( @@ -594,7 +602,7 @@ def _processFailedSuccessors(self, predecessor_id: str): ) self.processTotallyFailedJob(predecessor_id) - def _processReadyJob(self, job_id: str, result_status: int): + def _process_ready_job(self, job_id: str, result_status: int): # We operate on the JobDescription mostly. readyJob = self.toilState.get_job(job_id) @@ -638,7 +646,14 @@ def _processReadyJob(self, job_id: str, result_status: int): logger.warning("Job %s is completely failed", readyJob) else: # Otherwise try the job again - self.issueJob(readyJob) + if result_status != 0 and readyJob.retry_backoff_seconds > 0: + # We don't want to reissue the job until some time has passed + target_time = time.time() + readyJob.retry_backoff_seconds + logger.info("Waiting %s seconds to retry job %s", readyJob.retry_backoff_seconds, readyJob) + heapq.heappush(self._waiting_jobs, (target_time, readyJob)) + else: + # We can issue the job now + self.issueJob(readyJob) elif next(readyJob.serviceHostIDsInBatches(), None) is not None: # the job has services to run, which have not been started, start them # Build a map from the service jobs to the job and a map @@ -725,7 +740,7 @@ def _processReadyJob(self, job_id: str, result_status: int): readyJob.jobStoreID, ) - def _processReadyJobs(self): + def _process_ready_jobs(self): """Process jobs that are ready to be scheduled/have successors to schedule.""" logger.debug( "Built the jobs list, currently have %i jobs to update and %i jobs issued", @@ -772,8 +787,21 @@ def _processReadyJobs(self): continue else: # New job for this tick so actually handle that it is updated - self._processReadyJob(message.job_id, message.result_status) + self._process_ready_job(message.job_id, message.result_status) handled_with_status[message.job_id] = message.result_status + + def _process_waiting_jobs(self): + """ + See if any jobs in the waiting-job min-heap are ready to issue. + + If so, reissues them. + """ + + while len(self._waiting_jobs) > 0 and self._waiting_jobs[0][0] <= time.time(): + # The next job is ready! 
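+            # The heap is keyed on target issue time, so the entry at index 0
+            # is always the job whose backoff expires soonest.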
+ ready_time, ready_job = heapq.heappop(self._waiting_jobs) + logger.info("Job %s has finished its backoff period", ready_job) + self.issueJob(ready_job) def _startServiceJobs(self): """Start any service jobs available from the service manager.""" @@ -913,12 +941,16 @@ def innerLoop(self): while ( self._messages.count(JobUpdatedMessage) > 0 + or len(self._waiting_jobs) > 0 or self.getNumberOfJobsIssued() or self.serviceManager.get_job_count() ): if self._messages.count(JobUpdatedMessage) > 0: - self._processReadyJobs() + self._process_ready_jobs() + + if len(self._waiting_jobs) > 0: + self._process_waiting_jobs() # deal with service-related jobs self._startServiceJobs() @@ -988,8 +1020,8 @@ def checkForDeadlocks(self): totalRunningJobs = len(self.batchSystem.getRunningBatchJobIDs()) totalServicesIssued = self.serviceJobsIssued + self.preemptibleServiceJobsIssued - # If there are no updated jobs and at least some jobs running - if totalServicesIssued >= totalRunningJobs and totalRunningJobs > 0: + # If there are no updated jobs and no waiting jobs and at least some jobs running + if totalServicesIssued >= totalRunningJobs and len(self._waiting_jobs) == 0 and totalRunningJobs > 0: # Collect all running service job store IDs into a set to compare with the deadlock set running_service_ids: set[str] = set() for js_id in self.issued_jobs_by_batch_system_id.values(): @@ -1067,7 +1099,7 @@ def checkForDeadlocks(self): # We have observed non-service jobs running, so reset the potential deadlock self.feed_deadlock_watchdog() else: - # We have observed non-service jobs running, so reset the potential deadlock. + # We have observed non-service jobs running, or jobs waiting, so reset the potential deadlock. self.feed_deadlock_watchdog() def feed_deadlock_watchdog(self) -> None: diff --git a/src/toil/options/common.py b/src/toil/options/common.py index 04d35cba81..c39ed5bdbb 100644 --- a/src/toil/options/common.py +++ b/src/toil/options/common.py @@ -864,6 +864,26 @@ def __call__( help=f"Number of times to retry a failing job before giving up and " f"labeling job failed. default={1}", ) + job_options.add_argument( + "--retryBackoffSeconds", + dest="retry_backoff_seconds", + default=2, + type=float, + action=make_open_interval_action(0), + metavar="FLOAT", + help=f"Number of seconds to wait when first retrying a job. " + f"default={10}", + ) + job_options.add_argument( + "--retryBackoffFactor", + dest="retry_backoff_factor", + default=3, + type=float, + action=make_open_interval_action(1), + metavar="FLOAT", + help=f"Factor to increase retry backof time by for each " + f"additional retry. 
default={2}", + ) job_options.add_argument( "--stopOnFirstFailure", dest="stop_on_first_failure", diff --git a/src/toil/test/src/fileStoreTest.py b/src/toil/test/src/fileStoreTest.py index c6379e22cf..92dd6ea20a 100644 --- a/src/toil/test/src/fileStoreTest.py +++ b/src/toil/test/src/fileStoreTest.py @@ -493,6 +493,7 @@ def testExtremeCacheSetup(self): jobs[i].addChild(F) options = self.options() options.retryCount = 10 + options.retry_backoff_seconds = 0 options.badWorker = 0.25 options.badWorkerFailInterval = 0.2 Job.Runner.startToil(E, options) @@ -1091,6 +1092,7 @@ def testReturnFileSizesWithBadWorker(self): """ options = self.options() options.retryCount = 20 + options.retry_backoff_seconds = 0 options.badWorker = 0.5 options.badWorkerFailInterval = 0.1 workdir = self._createTempDir(purpose="nonLocalDir") diff --git a/src/toil/test/src/jobTest.py b/src/toil/test/src/jobTest.py index c77b14e31f..299722c696 100644 --- a/src/toil/test/src/jobTest.py +++ b/src/toil/test/src/jobTest.py @@ -14,6 +14,7 @@ import collections import os import random +import time from collections.abc import Callable from pathlib import Path from typing import Any, Callable, NoReturn, cast @@ -71,9 +72,10 @@ def testStatic(self, tmp_path: Path) -> None: options = Job.Runner.getDefaultOptions(tmp_path / "jobstore") options.logLevel = "INFO" options.retryCount = 100 + options.retry_backoff_seconds = 0 options.badWorker = 0.5 options.badWorkerFailInterval = 0.01 - # Run the workflow, the return value being the number of failed jobs + # Run the workflow, or raise if it does not succeed Job.Runner.startToil(A, options) # Check output @@ -108,14 +110,49 @@ def testStatic2(self, tmp_path: Path) -> None: options = Job.Runner.getDefaultOptions(tmp_path / "jobstore") options.logLevel = "INFO" options.retryCount = 100 + options.retry_backoff_seconds = 0 options.badWorker = 0.5 options.badWorkerFailInterval = 0.01 - # Run the workflow, the return value being the number of failed jobs + # Run the workflow, or raise if it does not succeed Job.Runner.startToil(A, options) # Check output assert open(outFile).readline() == "ABCDE" + def test_retry_backoff(self, tmp_path: Path): + """ + Make sure jobs retry with exponential backoff. + """ + + time_file = tmp_path / "times.txt" + + # Make a job that will log the time when it attempts to run, and that + # always fails. + MainJob = Job.wrapFn(time_recording_fn, time_file) + + options = Job.Runner.getDefaultOptions(tmp_path / "jobstore") + options.logLevel = "INFO" + options.retryCount = 4 # Plus the one original try + options.retry_backoff_seconds = 1.0 + options.retry_backoff_factor = 2.0 + + try: + # Run the workflow, or raise if it does not succeed + Job.Runner.startToil(MainJob, options) + except FailedJobsException as e: + # This is expected + pass + + assert time_file.exists() + + time_values = [float(l) for l in open(time_file) if len(l) > 0] + + assert len(time_values) == 5 + assert time_values[1] - time_values[0] >= 1.0 + assert time_values[2] - time_values[1] >= 2.0 + assert time_values[3] - time_values[2] >= 4.0 + assert time_values[4] - time_values[3] >= 8.0 + @slow @pytest.mark.slow def testTrivialDAGConsistency(self, tmp_path: Path) -> None: @@ -806,6 +843,13 @@ def fn2Test(pStrings: list[str], s: str, outputFile: Path) -> str: fH.write(" ".join(pStrings) + " " + s) return s +def time_recording_fn(path) -> None: + """ + Function to record the times it runs to a file, and fail. 
+ """ + with open(path, "a") as fp: + fp.write(f"{time.time()}\n") + raise RuntimeError("Refuse to succeed!") def trivialParent(job: Job) -> None: strandedJob = JobFunctionWrappingJob(child) diff --git a/src/toil/worker.py b/src/toil/worker.py index fde05bd5dc..1838ba5165 100644 --- a/src/toil/worker.py +++ b/src/toil/worker.py @@ -520,7 +520,7 @@ def blockFn() -> bool: # Reduce the try count if jobDesc.remainingTryCount < 0: raise RuntimeError("The try count of the job cannot be negative.") - jobDesc.remainingTryCount = max(0, jobDesc.remainingTryCount - 1) + jobDesc.chargeRetry() jobDesc.restartCheckpoint(job_store) # Otherwise, the job and successors are done, and we can cleanup stuff we couldn't clean # because of the job being a checkpoint From 09abe841b0b725982a11415bffef29d1fbc6e210 Mon Sep 17 00:00:00 2001 From: Adam Novak Date: Thu, 12 Feb 2026 15:42:14 -0800 Subject: [PATCH 2/5] Satisfy MyPy --- src/toil/lib/threading.py | 10 +++++++--- src/toil/test/src/jobTest.py | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/toil/lib/threading.py b/src/toil/lib/threading.py index ab20f080d0..90faea876f 100644 --- a/src/toil/lib/threading.py +++ b/src/toil/lib/threading.py @@ -316,9 +316,13 @@ def cpu_count() -> int: proc_info = psutil.Process() if hasattr(proc_info, "cpu_affinity"): try: - logger.debug("CPU affinity available") - affinity_size = len(proc_info.cpu_affinity()) - logger.debug("CPU affinity is restricted to %d cores", affinity_size) + affinity = proc_info.cpu_affinity() + if affinity is not None: + affinity_size = len(affinity) + logger.debug("CPU affinity is restricted to %d cores", affinity_size) + else: + # Somehow this returned None, which MyPy thinks it might + logger.debug("CPU affinity appears available but isn't") except: # We can't actually read this even though it exists. logger.debug( diff --git a/src/toil/test/src/jobTest.py b/src/toil/test/src/jobTest.py index 299722c696..a1beb34b16 100644 --- a/src/toil/test/src/jobTest.py +++ b/src/toil/test/src/jobTest.py @@ -119,7 +119,7 @@ def testStatic2(self, tmp_path: Path) -> None: # Check output assert open(outFile).readline() == "ABCDE" - def test_retry_backoff(self, tmp_path: Path): + def test_retry_backoff(self, tmp_path: Path) -> None: """ Make sure jobs retry with exponential backoff. """ @@ -843,7 +843,7 @@ def fn2Test(pStrings: list[str], s: str, outputFile: Path) -> str: fH.write(" ".join(pStrings) + " " + s) return s -def time_recording_fn(path) -> None: +def time_recording_fn(path: Path) -> None: """ Function to record the times it runs to a file, and fail. """ From 7c71459cb679a44e66f08832036f023d1bdae5a4 Mon Sep 17 00:00:00 2001 From: Adam Novak Date: Thu, 12 Feb 2026 15:46:26 -0800 Subject: [PATCH 3/5] Actually save changes to the backoff help text --- src/toil/options/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/toil/options/common.py b/src/toil/options/common.py index c39ed5bdbb..d1644fc7cb 100644 --- a/src/toil/options/common.py +++ b/src/toil/options/common.py @@ -872,7 +872,7 @@ def __call__( action=make_open_interval_action(0), metavar="FLOAT", help=f"Number of seconds to wait when first retrying a job. " - f"default={10}", + f"default={2}", ) job_options.add_argument( "--retryBackoffFactor", @@ -882,7 +882,7 @@ def __call__( action=make_open_interval_action(1), metavar="FLOAT", help=f"Factor to increase retry backof time by for each " - f"additional retry. default={2}", + f"additional retry. 
default={3}", ) job_options.add_argument( "--stopOnFirstFailure", From 519c4f25be5e0caa7b37cdb11bc8779f2db54799 Mon Sep 17 00:00:00 2001 From: Adam Novak Date: Thu, 19 Feb 2026 15:31:08 -0500 Subject: [PATCH 4/5] Finish renaming field to _retry_backoff_seconds --- src/toil/job.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/toil/job.py b/src/toil/job.py index fb6ef6432e..75abfcd2d0 100644 --- a/src/toil/job.py +++ b/src/toil/job.py @@ -1482,10 +1482,10 @@ def resetRetries(self) -> bool: :returns: True if a modification to the JobDescription was made, and False otherwise. """ - if self._remainingTryCount is not None or self._retryBackoff is not None: + if self._remainingTryCount is not None or self._retry_backoff_seconds is not None: # We had a value stored self._remainingTryCount = None - self._retryBackoff = None + self._retry_backoff_seconds = None return True else: # No change needed From aa9fe05b47711b0d0caab48c2760cdd7aa2ccf9e Mon Sep 17 00:00:00 2001 From: Adam Novak Date: Thu, 19 Feb 2026 20:28:04 -0500 Subject: [PATCH 5/5] Add semi-synthetic typing to job.py I tried to get Anthropic Claude to un-ignore typing in job.py. It thought for half an hour, and then I spent 4 hours cleaning up after it. But it successfully tricked me into typing job.py, so there's that, I guess. --- contrib/admin/mypy-with-ignore.py | 1 - src/toil/cwl/cwltoil.py | 2 - src/toil/deferred.py | 6 +- src/toil/job.py | 485 +++++++++++++++---------- src/toil/jobStores/abstractJobStore.py | 4 +- src/toil/lib/expando.py | 3 +- src/toil/statsAndLogging.py | 6 +- src/toil/test/src/jobServiceTest.py | 4 +- src/toil/test/src/jobTest.py | 11 +- src/toil/test/src/retainTempDirTest.py | 6 +- src/toil/toilState.py | 5 +- src/toil/wdl/wdltoil.py | 3 +- src/toil/worker.py | 14 +- 13 files changed, 330 insertions(+), 220 deletions(-) diff --git a/contrib/admin/mypy-with-ignore.py b/contrib/admin/mypy-with-ignore.py index 9a9cec926c..15004c611b 100755 --- a/contrib/admin/mypy-with-ignore.py +++ b/contrib/admin/mypy-with-ignore.py @@ -29,7 +29,6 @@ def main(): 'docs/conf.py', 'docs/vendor/sphinxcontrib/fulltoc.py', 'docs/vendor/sphinxcontrib/__init__.py', - 'src/toil/job.py', 'src/toil/leader.py', 'src/toil/__init__.py', 'src/toil/deferred.py', diff --git a/src/toil/cwl/cwltoil.py b/src/toil/cwl/cwltoil.py index e941074c68..b53d0f9f73 100644 --- a/src/toil/cwl/cwltoil.py +++ b/src/toil/cwl/cwltoil.py @@ -2939,14 +2939,12 @@ def makeRootJob( # Get metadata for non-tool input files input_metadata = get_file_sizes( input_filenames, - toil._jobStore, include_remote_files=options.reference_inputs, ) # Also get metadata for tool input files, so we can resilve them to candidate URIs tool_metadata = get_file_sizes( input_filenames, - toil._jobStore, include_remote_files=options.reference_inputs, ) diff --git a/src/toil/deferred.py b/src/toil/deferred.py index 838d0842d9..e0dcecaf94 100644 --- a/src/toil/deferred.py +++ b/src/toil/deferred.py @@ -17,6 +17,8 @@ import tempfile from collections import namedtuple from contextlib import contextmanager +from typing import Any, Callable +from typing_extensions import ParamSpec import dill @@ -40,8 +42,10 @@ class DeferredFunction( True """ + P = ParamSpec("P") + @classmethod - def create(cls, function, *args, **kwargs): + def create(cls, function: Callable[P, Any], *args: P.args, **kwargs: P.kwargs): """ Capture the given callable and arguments as an instance of this class. 
diff --git a/src/toil/job.py b/src/toil/job.py index 75abfcd2d0..268181cce1 100644 --- a/src/toil/job.py +++ b/src/toil/job.py @@ -30,17 +30,23 @@ from collections.abc import Callable, Iterator, Mapping, Sequence from contextlib import contextmanager from io import BytesIO +from types import ModuleType from typing import ( TYPE_CHECKING, Any, + ContextManager, + IO, + Iterable, Literal, NamedTuple, + NoReturn, TypedDict, TypeVar, Union, cast, overload, ) +from typing_extensions import reveal_type, ParamSpec from urllib.error import HTTPError from urllib.parse import unquote, urljoin, urlsplit @@ -50,6 +56,7 @@ from toil.lib.io import is_remote_url from toil.lib.memoize import memoize from toil.lib.misc import StrPath +from toil.lib.url import URLAccess if sys.version_info < (3, 11): from typing_extensions import NotRequired @@ -66,7 +73,7 @@ from toil.lib.expando import Expando from toil.lib.resources import ResourceMonitor from toil.resource import ModuleDescriptor -from toil.statsAndLogging import set_logging_from_options +from toil.statsAndLogging import set_logging_from_options, StatsDict if TYPE_CHECKING: from optparse import OptionParser @@ -210,7 +217,7 @@ class AcceleratorRequirement(TypedDict): def parse_accelerator( - spec: int | str | dict[str, str | int], + spec: ParseableSingleAcceleratorRequirement, ) -> AcceleratorRequirement: """ Parse an AcceleratorRequirement specified by user code. @@ -253,7 +260,9 @@ def parse_accelerator( APIS = {"cuda", "rocm", "opencl"} parsed: AcceleratorRequirement = {"count": 1, "kind": "gpu"} - + + if isinstance(spec, bytes): + spec = spec.decode("utf-8") if isinstance(spec, int): parsed["count"] = spec elif isinstance(spec, str): @@ -296,7 +305,8 @@ def parse_accelerator( elif isinstance(spec, dict): # It's a dict, so merge with the defaults. parsed.update(cast(AcceleratorRequirement, spec)) - # TODO: make sure they didn't misspell keys or something + # TODO: make sure they didn't misspell keys or provide the wrong types + # of values for them. else: raise TypeError( f"Cannot parse value of type {type(spec)} as an AcceleratorRequirement" @@ -337,7 +347,10 @@ def accelerator_satisfies( :returns: True if the given candidate at least partially satisfies the given requirement (i.e. check all fields other than count). """ - for key in ["kind", "brand", "api", "model"]: + # MyPy needs a lot of cajoling to understand you're allowed to loop over + # TypedDict keys. TODO: Is there a better way to tell it the list contains + # only allowed keys without duplicating them all? + for key in cast(list[Literal["kind", "brand", "api", "model"]], ["kind", "brand", "api", "model"]): if key in ignore: # Skip this aspect. 
continue @@ -414,15 +427,19 @@ class RequirementsDict(TypedDict): ParsedRequirement = Union[int, float, bool, list[AcceleratorRequirement]] # We define some types for things we can parse into different kind of requirements -ParseableIndivisibleResource = Union[str, int] -ParseableDivisibleResource = Union[str, int, float] -ParseableFlag = Union[str, int, bool] -ParseableAcceleratorRequirement = Union[ +ParseableIndivisibleResource = Union[str, bytes, int] +ParseableDivisibleResource = Union[str, bytes, int, float] +ParseableFlag = Union[str, bytes, int, bool] +ParseableSingleAcceleratorRequirement = Union[ str, + bytes, int, - Mapping[str, Any], - AcceleratorRequirement, - Sequence[Union[str, int, Mapping[str, Any], AcceleratorRequirement]], + Mapping[str, int | str], + AcceleratorRequirement +] +ParseableAcceleratorRequirement = Union[ + ParseableSingleAcceleratorRequirement, + Sequence[ParseableSingleAcceleratorRequirement] ] ParseableRequirement = Union[ @@ -442,7 +459,7 @@ class Requirer: _requirementOverrides: RequirementsDict - def __init__(self, requirements: Mapping[str, ParseableRequirement]) -> None: + def __init__(self, requirements: Mapping[str, ParseableRequirement | None]) -> None: """ Parse and save the given requirements. @@ -464,11 +481,11 @@ def __init__(self, requirements: Mapping[str, ParseableRequirement]) -> None: # Save requirements, parsing and validating anything that needs parsing # or validating. Don't save Nones. - self._requirementOverrides = { + self._requirementOverrides = cast(RequirementsDict, { k: Requirer._parseResource(k, v) for (k, v) in requirements.items() if v is not None - } + }) def assignConfig(self, config: Config) -> None: """ @@ -528,7 +545,7 @@ def __deepcopy__(self, memo: Any) -> Requirer: @overload @staticmethod def _parseResource( - name: Literal["memory"] | Literal["disks"], + name: Literal["memory"] | Literal["disk"], value: ParseableIndivisibleResource, ) -> int: ... @@ -544,6 +561,12 @@ def _parseResource( name: Literal["accelerators"], value: ParseableAcceleratorRequirement ) -> list[AcceleratorRequirement]: ... + @overload + @staticmethod + def _parseResource( + name: Literal["preemptible"], value: ParseableFlag + ) -> bool: ... + @overload @staticmethod def _parseResource(name: str, value: ParseableRequirement) -> ParsedRequirement: ... @@ -589,7 +612,9 @@ def _parseResource( if name in ("memory", "disk", "cores"): # These should be numbers that accept things like "5G". - if isinstance(value, (str, bytes)): + if isinstance(value, bytes): + value = value.decode("utf-8") + if isinstance(value, str): value = human2bytes(value) if isinstance(value, int): return value @@ -626,9 +651,11 @@ def _parseResource( f"The '{name}' requirement does not accept values that are of type {type(value)}" ) elif name == "accelerators": - # The type checking for this is delegated to the - # AcceleratorRequirement class. 
- if isinstance(value, list): + if isinstance(value, float): + raise TypeError( + f"The '{name}' requirement does not accept values that are of type {type(value)}" + ) + if not isinstance(value, (str, bytes)) and isinstance(value, Sequence): return [ parse_accelerator(v) for v in value ] # accelerators={'kind': 'gpu', 'brand': 'nvidia', 'count': 2} @@ -649,14 +676,14 @@ def _fetchRequirement(self, requirement: str) -> ParsedRequirement | None: :param requirement: The name of the resource """ if requirement in self._requirementOverrides: - value = self._requirementOverrides[requirement] + value: ParsedRequirement | None = self._requirementOverrides[requirement] # type: ignore[literal-required] if value is None: raise AttributeError( f"Encountered explicit None for '{requirement}' requirement of {self}" ) return value elif self._config is not None: - values = [ + values: list[ParsedRequirement | None] = [ getattr(self._config, "default_" + requirement, None), getattr(self._config, "default" + requirement.capitalize(), None), ] @@ -675,7 +702,7 @@ def _fetchRequirement(self, requirement: str) -> ParsedRequirement | None: @property def requirements(self) -> RequirementsDict: """Get dict containing all non-None, non-defaulted requirements.""" - return dict(self._requirementOverrides) + return copy.copy(self._requirementOverrides) @property def disk(self) -> int: @@ -743,7 +770,7 @@ def scale(self, requirement: str, factor: float) -> Requirer: # Make a shallow copy scaled = copy.copy(self) # But make sure it has its own override dictionary - scaled._requirementOverrides = dict(scaled._requirementOverrides) + scaled._requirementOverrides = copy.copy(scaled._requirementOverrides) original_value = getattr(scaled, requirement) if isinstance(original_value, (int, float)): @@ -760,9 +787,9 @@ def scale(self, requirement: str, factor: float) -> Requirer: def requirements_string(self) -> str: """Get a nice human-readable string of our requirements.""" - parts = [] + parts: list[str] = [] for k in REQUIREMENT_NAMES: - v = self._fetchRequirement(k) + v: str | ParsedRequirement | None = self._fetchRequirement(k) if v is not None: if isinstance(v, (int, float)) and v > 1000: # Make large numbers readable @@ -780,8 +807,8 @@ class JobBodyReference(NamedTuple): file_store_id: str """File ID (or special shared file name for the root job) of the job's body.""" - module_string: str - """Stringified description of the module needed to load the body.""" + module_command: Sequence[str] + """Description of the module needed to load the body.""" class JobDescription(Requirer): @@ -804,7 +831,7 @@ class JobDescription(Requirer): def __init__( self, - requirements: Mapping[str, int | str | float | bool | list], + requirements: Mapping[str, ParseableRequirement | None], jobName: str, unitName: str | None = "", displayName: str | None = "", @@ -869,11 +896,11 @@ def makeString(x: str | bytes | None) -> str: # is reduced each time the job is run, until it is zero, and then no # further attempts to run the job are made. If None, taken as the # default value for this workflow execution. - self._remainingTryCount = None + self._remainingTryCount: int | None = None # The number of seconds to back off before retry. # Gets increased each time the job fails, and reset at workflow restart. 
- self._retry_backoff_seconds = None + self._retry_backoff_seconds: float | None = None # Holds FileStore FileIDs of the files that should be seen as deleted, # as part of a transaction with the writing of this version of the job @@ -889,7 +916,9 @@ def makeString(x: str | bytes | None) -> str: # # This will be empty at all times except when a new version of a job is # in the process of being committed. - self.filesToDelete = [] + # + # TODO: We should probably refactor to use FileID here. + self.filesToDelete: list[str] = [] # Holds job names and IDs of the jobs that have been chained into this # job, and which should be deleted when this job finally is deleted @@ -912,7 +941,7 @@ def makeString(x: str | bytes | None) -> str: # after the job is scheduled, so we don't have to worry about # conflicting updates from workers. # TODO: Move into ToilState itself so leader stops mutating us so much? - self.predecessorsFinished = set() + self.predecessorsFinished: set[str] = set() # Note that we don't hold IDs of our predecessors. Predecessors know # about us, and not the other way around. Otherwise we wouldn't be able @@ -921,25 +950,25 @@ def makeString(x: str | bytes | None) -> str: # The IDs of all child jobs of the described job. # Children which are done must be removed with filterSuccessors. - self.childIDs: set[str] = set() + self.childIDs: set[str | TemporaryID] = set() # The IDs of all follow-on jobs of the described job. # Follow-ons which are done must be removed with filterSuccessors. - self.followOnIDs: set[str] = set() + self.followOnIDs: set[str | TemporaryID] = set() # We keep our own children and follow-ons in a list of successor # phases, along with any successors adopted from jobs we have chained # from. When we finish our own children and follow-ons, we may have to # go back and finish successors for those jobs. - self.successor_phases: list[set[str]] = [self.followOnIDs, self.childIDs] + self.successor_phases: list[set[str | TemporaryID]] = [self.followOnIDs, self.childIDs] # Dict from ServiceHostJob ID to list of child ServiceHostJobs that start after it. # All services must have an entry, if only to an empty list. - self.serviceTree = {} + self.serviceTree: dict[str | TemporaryID, list[str | TemporaryID]] = {} # A jobStoreFileID of the log file for a job. This will be None unless # the job failed and the logging has been captured to be reported on the leader. - self.logJobStoreFileID = None + self.logJobStoreFileID: str | None = None # Every time we update a job description in place in the job store, we # increment this. @@ -985,9 +1014,13 @@ def serviceHostIDsInBatches(self) -> Iterator[list[str]]: (in the order they need to start in) """ + # At this point, nothing in serviceTree should still have a TemporaryID. + # But we can't really explain that to MyPy with narrowing. + # TODO: Is there a way? + tree = cast(dict[str, list[str]], self.serviceTree) # First start all the jobs with no parent - roots = set(self.serviceTree.keys()) - for _parent, children in self.serviceTree.items(): + roots = set(tree.keys()) + for parent, children in tree.items(): for child in children: roots.remove(child) batch = list(roots) @@ -999,7 +1032,7 @@ def serviceHostIDsInBatches(self) -> Iterator[list[str]]: nextBatch = [] for started in batch: # Go find all the children that can start now that we have started. 
- for child in self.serviceTree[started]: + for child in tree[started]: nextBatch.append(child) batch = nextBatch @@ -1009,8 +1042,11 @@ def serviceHostIDsInBatches(self) -> Iterator[list[str]]: def successorsAndServiceHosts(self) -> Iterator[str]: """Get an iterator over all child, follow-on, and service job IDs.""" - - return itertools.chain(self.allSuccessors(), self.serviceTree.keys()) + # At this point, nothing in serviceTree should still have a TemporaryID. + # But we can't really explain that to MyPy with narrowing. + # TODO: Is there a way? + tree = cast(dict[str, list[str]], self.serviceTree) + return itertools.chain(self.allSuccessors(), tree.keys()) def allSuccessors(self) -> Iterator[str]: """ @@ -1020,7 +1056,9 @@ def allSuccessors(self) -> Iterator[str]: """ for phase in self.successor_phases: - yield from phase + # At this point, TemporaryIDs should have been removed. + # TODO: enforce this + yield from cast(Iterable[str], phase) def successors_by_phase(self) -> Iterator[tuple[int, str]]: """ @@ -1031,17 +1069,25 @@ def successors_by_phase(self) -> Iterator[tuple[int, str]]: for i, phase in enumerate(self.successor_phases): for successor in phase: + assert not isinstance(successor, TemporaryID) yield i, successor @property - def services(self): + def services(self) -> list[str]: """ Get a collection of the IDs of service host jobs for this job, in arbitrary order. Will be empty if the job has no unfinished services. """ - return list(self.serviceTree.keys()) + # At this point, nothing in serviceTree should still have a TemporaryID. + # But we can't really explain that to MyPy with narrowing. + # TODO: Is there a way? + tree = cast(dict[str, list[str]], self.serviceTree) + return list(tree.keys()) + # TODO: This should narrow _body to not-None, but we can't make this a + # TypeGuard or TypeIs because they aren't allowed to work on the self + # parameter. def has_body(self) -> bool: """ Returns True if we have a job body associated, and False otherwise. @@ -1079,9 +1125,10 @@ def get_body(self) -> tuple[str, ModuleDescriptor]: if not self.has_body(): raise RuntimeError(f"Cannot load the body of a job {self} without one") - - return self._body.file_store_id, ModuleDescriptor.fromCommand( - self._body.module_string + # TODO: We can't get has_body() to narrow _body for us. + body = cast(JobBodyReference, self._body) + return body.file_store_id, ModuleDescriptor.fromCommand( + body.module_command ) def nextSuccessors(self) -> set[str] | None: @@ -1103,7 +1150,9 @@ def nextSuccessors(self) -> set[str] | None: for phase in reversed(self.successor_phases): if len(phase) > 0: # Rightmost phase that isn't empty - return phase + # At this point, TemporaryIDs should have been removed. + # TODO: enforce this + return cast(set[str], phase) # If no phase isn't empty, we're done. return None @@ -1118,6 +1167,8 @@ def filterSuccessors(self, predicate: Callable[[str], bool]) -> None: for phase in self.successor_phases: for successor_id in list(phase): + # At this point, TemporaryIDs should have been removed. 
+ assert not isinstance(successor_id, TemporaryID) if not predicate(successor_id): phase.remove(successor_id) self.successor_phases = [p for p in self.successor_phases if len(p) > 0] @@ -1273,15 +1324,15 @@ def is_updated_by(self, other: JobDescription) -> bool: return True - def addChild(self, childID: str) -> None: + def addChild(self, childID: str | TemporaryID) -> None: """Make the job with the given ID a child of the described job.""" self.childIDs.add(childID) - def addFollowOn(self, followOnID: str) -> None: + def addFollowOn(self, followOnID: str | TemporaryID) -> None: """Make the job with the given ID a follow-on of the described job.""" self.followOnIDs.add(followOnID) - def addServiceHostJob(self, serviceID, parentServiceID=None): + def addServiceHostJob(self, serviceID: str | TemporaryID, parentServiceID: str | TemporaryID | None = None) -> None: """ Make the ServiceHostJob with the given ID a service of the described job. @@ -1295,15 +1346,15 @@ def addServiceHostJob(self, serviceID, parentServiceID=None): if parentServiceID is not None: self.serviceTree[parentServiceID].append(serviceID) - def hasChild(self, childID: str) -> bool: + def hasChild(self, childID: str | TemporaryID) -> bool: """Return True if the job with the given ID is a child of the described job.""" return childID in self.childIDs - def hasFollowOn(self, followOnID: str) -> bool: + def hasFollowOn(self, followOnID: str | TemporaryID) -> bool: """Test if the job with the given ID is a follow-on of the described job.""" return followOnID in self.followOnIDs - def hasServiceHostJob(self, serviceID) -> bool: + def hasServiceHostJob(self, serviceID: str | TemporaryID) -> bool: """Test if the ServiceHostJob is a service of the described job.""" return serviceID in self.serviceTree @@ -1323,9 +1374,17 @@ def renameReferences(self, renames: dict[TemporaryID, str]) -> None: # Replace each renamed item one at a time to preserve set identity phase.remove(item) phase.add(renames[item]) + # TODO: MyPy won't let us .get() with types other than the dict key + # type, even though that sort of makes sense because only objects of + # the actual type can *be* in the dict. Apparently it thinks dict + # doesn't guarantee membership testing for arbitrary types. This makes + # applying the renames fiddly. + # + # We should see if we can find a way to get away from this terrible + # codegen'd code. self.serviceTree = { - renames.get(parent, parent): [ - renames.get(child, child) for child in children + (renames[parent] if isinstance(parent, TemporaryID) and parent in renames else parent): [ + (renames[child] if isinstance(child, TemporaryID) and child in renames else child) for child in children ] for parent, children in self.serviceTree.items() } @@ -1354,8 +1413,11 @@ def chargeRetry(self) -> None: On completion, self.retry_backoff_seconds will be the time to wait before the next retry. + + Can only be used once the config has been attached. """ self.remainingTryCount = max(0, self.remainingTryCount - 1) + assert self._config is not None if self._retry_backoff_seconds is None: # This was the first retry self._retry_backoff_seconds = self._config.retry_backoff_seconds @@ -1434,16 +1496,17 @@ def setupJobAfterFailure( self.disk, ) - def getLogFileHandle(self, jobStore): + def getLogFileHandle(self, jobStore: AbstractJobStore) -> Any: """ Create a context manager that yields a file handle to the log file. Assumes logJobStoreFileID is set. 
""" + assert self.logJobStoreFileID is not None return jobStore.read_file_stream(self.logJobStoreFileID) @property - def remainingTryCount(self): + def remainingTryCount(self) -> int: """ Get the number of tries remaining. @@ -1460,7 +1523,7 @@ def remainingTryCount(self): raise AttributeError(f"Try count for {self} cannot be determined") @remainingTryCount.setter - def remainingTryCount(self, val): + def remainingTryCount(self, val: int) -> None: self._remainingTryCount = val @property @@ -1508,7 +1571,7 @@ def __str__(self) -> str: # There really should only ever be one true version of a JobDescription at # a time, keyed by jobStoreID. - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}( **{self.__dict__!r} )" def reserve_versions(self, count: int) -> None: @@ -1533,7 +1596,7 @@ def pre_update_hook(self) -> None: class ServiceJobDescription(JobDescription): """A description of a job that hosts a service.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Create a ServiceJobDescription to describe a ServiceHostJob.""" # Make the base JobDescription super().__init__(*args, **kwargs) @@ -1552,7 +1615,7 @@ def __init__(self, *args, **kwargs): # should terminate signaling an error. self.errorJobStoreID: str | None = None - def onRegistration(self, jobStore): + def onRegistration(self, jobStore: AbstractJobStore) -> None: """ Setup flag files. @@ -1570,7 +1633,7 @@ class CheckpointJobDescription(JobDescription): A description of a job that is a checkpoint. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Create a CheckpointJobDescription to describe a checkpoint job.""" # Make the base JobDescription super().__init__(*args, **kwargs) @@ -1581,9 +1644,9 @@ def __init__(self, *args, **kwargs): self.checkpoint: JobBodyReference | None = None # Files that can not be deleted until the job and its successors have completed - self.checkpointFilesToDelete = [] + self.checkpointFilesToDelete: list[str] = [] - def set_checkpoint(self) -> str: + def set_checkpoint(self) -> None: """ Save a body checkpoint into self.checkpoint """ @@ -1614,7 +1677,7 @@ def restartCheckpoint(self, jobStore: AbstractJobStore) -> list[str]: raise RuntimeError( "Cannot restart a checkpoint job. The checkpoint was never set." ) - successorsDeleted = [] + successorsDeleted: list[str] = [] all_successors = list(self.allSuccessors()) if len(all_successors) > 0 or self.serviceTree or self.has_body(): if self.has_body(): @@ -1638,7 +1701,7 @@ def restartCheckpoint(self, jobStore: AbstractJobStore) -> list[str]: # Delete everything on the stack, as these represent successors to clean # up as we restart the queue - def recursiveDelete(jobDesc): + def recursiveDelete(jobDesc: JobDescription) -> None: # Recursive walk the stack to delete all remaining jobs for otherJobID in jobDesc.successorsAndServiceHosts(): if jobStore.job_exists(otherJobID): @@ -1651,6 +1714,7 @@ def recursiveDelete(jobDesc): "Checkpoint is deleting old successor job: %s", jobDesc.jobStoreID, ) + assert not isinstance(jobDesc.jobStoreID, TemporaryID) jobStore.delete_job(jobDesc.jobStoreID) successorsDeleted.append(jobDesc.jobStoreID) @@ -1663,6 +1727,9 @@ def recursiveDelete(jobDesc): jobStore.update_job(self) return successorsDeleted +# Job methods to add children/follow-ons return the child/follow-on again, so +# we need a TypeVar to track that. 
+JobType = TypeVar("JobType", bound="Job") class Job: """ @@ -1680,7 +1747,7 @@ def __init__( unitName: str | None = "", checkpoint: bool | None = False, displayName: str | None = "", - descriptionClass: type | None = None, + descriptionClass: type[JobDescription] | None = None, local: bool | None = None, files: set[FileID] | None = None, ) -> None: @@ -1743,7 +1810,7 @@ def __init__( # Create the JobDescription that owns all the scheduling information. # Make it with a temporary ID until we can be assigned a real one by # the JobStore. - self._description = descriptionClass( + self._description: JobDescription | None = descriptionClass( requirements, jobName, unitName=unitName, @@ -1767,7 +1834,7 @@ def __init__( # while the user is creating the job graphs, to check for duplicate # relationships and to let EncapsulatedJob magically add itself as a # child. Note that this stores actual Job objects, to call addChild on. - self._directPredecessors = set() + self._directPredecessors: set[Job] = set() # Note that self.__module__ is not necessarily this module, i.e. job.py. It is the module # defining the class self is an instance of, which may be a subclass of Job that may be @@ -1780,11 +1847,11 @@ def __init__( # traverses a nested data structure of lists, dicts, tuples or any other type supporting # the __getitem__() protocol.. The special key `()` (the empty tuple) represents the # entire return value. - self._rvs = collections.defaultdict(list) - self._promiseJobStore = None - self._fileStore = None - self._defer = None - self._tempDir = None + self._rvs: dict[tuple[Any, ...], list[str]] = collections.defaultdict(list) + self._promiseJobStore: AbstractJobStore | None = None + self._fileStore: AbstractFileStore | None = None + self._defer: Callable[[Any], None] | None = None + self._tempDir: str | None = None # Holds flags set by set_debug_flag() self._debug_flags: set[str] = set() @@ -1818,11 +1885,13 @@ def check_initialized(self) -> None: def jobStoreID(self) -> str | TemporaryID: """Get the ID of this Job.""" # This is managed by the JobDescription. + assert self._description is not None return self._description.jobStoreID @property def description(self) -> JobDescription: """Expose the JobDescription that describes this job.""" + assert self._description is not None return self._description # Instead of being a Requirer ourselves, we pass anything about @@ -1861,21 +1930,23 @@ def accelerators(self) -> list[AcceleratorRequirement]: return self.description.accelerators @accelerators.setter - def accelerators(self, val: list[ParseableAcceleratorRequirement]) -> None: + def accelerators(self, val: ParseableAcceleratorRequirement) -> None: self.description.accelerators = val @property def preemptible(self) -> bool: """Whether the job can be run on a preemptible node.""" return self.description.preemptible - - @deprecated(new_function_name="preemptible") - def preemptable(self) -> bool: - return self.description.preemptible - + @preemptible.setter def preemptible(self, val: bool) -> None: self.description.preemptible = val + + # Note that unless the two halves of a property are *immediately* adjacent, + # MyPy throws an error. So the old version has to come later. 
+ @deprecated(new_function_name="preemptible") + def preemptable(self) -> bool: + return self.description.preemptible @property def checkpoint(self) -> bool: @@ -1952,7 +2023,7 @@ def _jobGraphsJoined(self, other: Job) -> None: # Point all their jobs at the new combined registry job._registry = self._registry - def addChild(self, childJob: Job) -> Job: + def addChild(self, childJob: JobType) -> JobType: """ Add a childJob to be run as child of this job. @@ -1971,7 +2042,7 @@ def addChild(self, childJob: Job) -> Job: # Join the job graphs self._jobGraphsJoined(childJob) # Remember the child relationship - self._description.addChild(childJob.jobStoreID) + self.description.addChild(childJob.jobStoreID) # Record the temporary back-reference childJob._addPredecessor(self) @@ -1983,9 +2054,9 @@ def hasChild(self, childJob: Job) -> bool: :return: True if childJob is a child of the job, else False. """ - return self._description.hasChild(childJob.jobStoreID) + return self.description.hasChild(childJob.jobStoreID) - def addFollowOn(self, followOnJob: Job) -> Job: + def addFollowOn(self, followOnJob: JobType) -> JobType: """ Add a follow-on job. @@ -2003,7 +2074,7 @@ def addFollowOn(self, followOnJob: Job) -> Job: # Join the job graphs self._jobGraphsJoined(followOnJob) # Remember the follow-on relationship - self._description.addFollowOn(followOnJob.jobStoreID) + self.description.addFollowOn(followOnJob.jobStoreID) # Record the temporary back-reference followOnJob._addPredecessor(self) @@ -2019,7 +2090,7 @@ def hasFollowOn(self, followOnJob: Job) -> bool: :return: True if the followOnJob is a follow-on of this job, else False. """ - return self._description.hasChild(followOnJob.jobStoreID) + return self.description.hasChild(followOnJob.jobStoreID) def addService( self, service: Job.Service, parentService: Job.Service | None = None @@ -2056,7 +2127,7 @@ def addService( self._jobGraphsJoined(hostingJob) # Record the relationship to the hosting job, with its parent if any. - self._description.addServiceHostJob( + self.description.addServiceHostJob( hostingJob.jobStoreID, parentService.hostID if parentService is not None else None, ) @@ -2071,13 +2142,22 @@ def addService( def hasService(self, service: Job.Service) -> bool: """Return True if the given Service is a service of this job, and False otherwise.""" - return service.hostID is None or self._description.hasServiceHostJob( + return service.hostID is None or self.description.hasServiceHostJob( service.hostID ) # Convenience functions for creating jobs - def addChildFn(self, fn: Callable, *args, **kwargs) -> FunctionWrappingJob: + + # TODO: We want to take the Callable here, and accept all its arguments and + # keyword arguments *plus* the extra Toil keyword arguments the job types + # have, and also we need promised versions of the arguments to be allowed, + # with the promise inserted at any level of a recursive data structure. + # Neither of those is really possible in MyPy, and we're specifically not + # allowed to fiddle with the kwargs with Concatenate and ParamSpec (see + # https://peps.python.org/pep-0612/#concatenating-keyword-parameters) + + def addChildFn(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> "Job": """ Add a function as a child job. 
@@ -2093,7 +2173,7 @@ def addChildFn(self, fn: Callable, *args, **kwargs) -> FunctionWrappingJob: else: return self.addChild(FunctionWrappingJob(fn, *args, **kwargs)) - def addFollowOnFn(self, fn: Callable, *args, **kwargs) -> FunctionWrappingJob: + def addFollowOnFn(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> "Job": """ Add a function as a follow-on job. @@ -2109,7 +2189,7 @@ def addFollowOnFn(self, fn: Callable, *args, **kwargs) -> FunctionWrappingJob: else: return self.addFollowOn(FunctionWrappingJob(fn, *args, **kwargs)) - def addChildJobFn(self, fn: Callable, *args, **kwargs) -> FunctionWrappingJob: + def addChildJobFn(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> "Job": """ Add a job function as a child job. @@ -2127,7 +2207,7 @@ def addChildJobFn(self, fn: Callable, *args, **kwargs) -> FunctionWrappingJob: else: return self.addChild(JobFunctionWrappingJob(fn, *args, **kwargs)) - def addFollowOnJobFn(self, fn: Callable, *args, **kwargs) -> FunctionWrappingJob: + def addFollowOnJobFn(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> "Job": """ Add a follow-on job function. @@ -2154,15 +2234,17 @@ def tempDir(self) -> str: :return: Path to tempDir. See `job.fileStore.getLocalTempDir` """ if self._tempDir is None: + assert self._fileStore is not None self._tempDir = self._fileStore.getLocalTempDir() return self._tempDir - def log(self, text: str, level=logging.INFO) -> None: + def log(self, text: str, level: int = logging.INFO) -> None: """Log using :func:`fileStore.log_to_leader`.""" + assert self._fileStore is not None self._fileStore.log_to_leader(text, level) @staticmethod - def wrapFn(fn, *args, **kwargs) -> FunctionWrappingJob: + def wrapFn(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Job: """ Makes a Job out of a function. @@ -2179,7 +2261,7 @@ def wrapFn(fn, *args, **kwargs) -> FunctionWrappingJob: return FunctionWrappingJob(fn, *args, **kwargs) @staticmethod - def wrapJobFn(fn, *args, **kwargs) -> JobFunctionWrappingJob: + def wrapJobFn(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Job: """ Makes a Job out of a job function. @@ -2211,7 +2293,7 @@ def encapsulate(self, name: str | None = None) -> EncapsulatedJob: # job run functions #################################################### - def rv(self, *path) -> Promise: + def rv(self, *path: Any) -> Promise: """ Create a *promise* (:class:`toil.job.Promise`). @@ -2236,7 +2318,7 @@ def rv(self, *path) -> Promise: """ return Promise(self, path) - def registerPromise(self, path): + def registerPromise(self, path: tuple[Any, ...]) -> tuple[str, str]: if self._promiseJobStore is None: # We haven't had a job store set to put our return value into, so # we must not have been hit yet in job topological order. @@ -2267,7 +2349,7 @@ def prepareForPromiseRegistration(self, jobStore: AbstractJobStore) -> None: """ self._promiseJobStore = jobStore - def _disablePromiseRegistration(self): + def _disablePromiseRegistration(self) -> None: """ Called when the job data is about to be saved in the JobStore. 
@@ -2280,7 +2362,7 @@ def _disablePromiseRegistration(self): # Cycle/connectivity checking #################################################### - def checkJobGraphForDeadlocks(self): + def checkJobGraphForDeadlocks(self) -> None: """ Ensures that a graph of Jobs (that hasn't yet been saved to the JobStore) doesn't contain any pathological relationships between jobs @@ -2361,11 +2443,11 @@ def checkJobGraphAcylic(self) -> None: extraEdges = self._getImpliedEdges(roots) # Check for directed cycles in the augmented graph - visited = set() + visited: set[Job] = set() for root in roots: root._checkJobGraphAcylicDFS([], visited, extraEdges) - def _checkJobGraphAcylicDFS(self, stack, visited, extraEdges): + def _checkJobGraphAcylicDFS(self, stack: list[Job], visited: set[Job], extraEdges: dict[Job, list[Job]]) -> None: """DFS traversal to detect cycles in augmented job graph.""" if self not in visited: visited.add(self) @@ -2386,7 +2468,7 @@ def _checkJobGraphAcylicDFS(self, stack, visited, extraEdges): ) @staticmethod - def _getImpliedEdges(roots) -> dict[Job, list[Job]]: + def _getImpliedEdges(roots: set[Job]) -> dict[Job, list[Job]]: """ Gets the set of implied edges (between children and follow-ons of a common job). @@ -2397,14 +2479,14 @@ def _getImpliedEdges(roots) -> dict[Job, list[Job]]: :returns: dict from Job object to list of Job objects that must be done before it can start. """ # Get nodes (Job objects) in job graph - nodes = set() + nodes: set[Job] = set() for root in roots: root._collectAllSuccessors(nodes) ##For each follow-on edge calculate the extra implied edges # Adjacency list of implied edges, i.e. map of jobs to lists of jobs # connected by an implied edge - extraEdges = {n: [] for n in nodes} + extraEdges: dict[Job, list[Job]] = {n: [] for n in nodes} for job in nodes: # Get all the nonempty successor phases phases = [p for p in job.description.successor_phases if len(p) > 0] @@ -2416,7 +2498,7 @@ def _getImpliedEdges(roots) -> dict[Job, list[Job]]: lower = phases[depth - 1] # Find everything in the upper subtree - reacheable = set() + reacheable: set[Job] = set() for upperID in upper: if upperID in job._registry: # This is a locally added job, not an already-saved job @@ -2455,7 +2537,7 @@ def checkNewCheckpointsAreLeafVertices(self) -> None: ) # Roots jobs of component, these are preexisting jobs in the graph # All jobs in the component of the job graph containing self - jobs = set() + jobs: set[Job] = set() list(map(lambda x: x._collectAllSuccessors(jobs), roots)) # Check for each job for which checkpoint is true that it is a cut vertex or leaf @@ -2470,7 +2552,8 @@ def checkNewCheckpointsAreLeafVertices(self) -> None: # Deferred function system #################################################### - def defer(self, function, *args, **kwargs) -> None: + P = ParamSpec("P") + def defer(self, function: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: """ Register a deferred function, i.e. a callable that will be invoked after the current attempt at running this job concludes. A job attempt is said to conclude when the job @@ -2569,10 +2652,10 @@ def addToilOptions( :param parser: Options object to add toil options to. :param jobstore_as_flag: make the job store option a --jobStore flag instead of a required jobStore positional argument. 
""" - addOptions(parser, jobstore_as_flag=jobstore_as_flag) + addOptions(cast(ArgumentParser, parser), jobstore_as_flag=jobstore_as_flag) @staticmethod - def startToil(job: Job, options) -> Any: + def startToil(job: Job, options: Namespace) -> Any: """ Run the toil workflow using the given options. @@ -2632,7 +2715,7 @@ def __init__( self.jobName = self.__class__.__name__ # Record that we have as of yet no ServiceHostJob - self.hostID = None + self.hostID: str | TemporaryID | None = None @abstractmethod def start(self, job: ServiceHostJob) -> Any: @@ -2657,6 +2740,7 @@ def stop(self, job: ServiceHostJob) -> None: the fileStore for creating temporary files. """ + @abstractmethod def check(self) -> bool: """ Checks the service is still running. @@ -2667,28 +2751,30 @@ def check(self) -> bool: RuntimeError, not return False! """ - def _addPredecessor(self, predecessorJob): + def _addPredecessor(self, predecessorJob: Job) -> None: """Adds a predecessor job to the set of predecessor jobs.""" if predecessorJob in self._directPredecessors: raise ConflictingPredecessorError(predecessorJob, self) self._directPredecessors.add(predecessorJob) # Record the need for the predecessor to finish - self._description.addPredecessor() + self.description.addPredecessor() @staticmethod - def _isLeafVertex(job): + def _isLeafVertex(job: Job) -> bool: return next(job.description.successorsAndServiceHosts(), None) is None @classmethod - def _loadUserModule(cls, userModule: ModuleDescriptor): + def _loadUserModule(cls, userModule: ModuleDescriptor) -> ModuleType: """ Imports and returns the module object represented by the given module descriptor. """ - return userModule.load() + result = userModule.load() + assert result is not None, f"Failed to load module {userModule}" + return result @classmethod - def _unpickle(cls, userModule, fileHandle, requireInstanceOf=None): + def _unpickle(cls, userModule: ModuleType, fileHandle: Any, requireInstanceOf: type | None = None) -> Any: """ Unpickles an object graph from the given file handle while loading symbols \ referencing the __main__ module from the given userModule instead. @@ -2699,7 +2785,7 @@ def _unpickle(cls, userModule, fileHandle, requireInstanceOf=None): :returns: """ - def filter_main(module_name, class_name): + def filter_main(module_name: str, class_name: str) -> Any: try: if module_name == "__main__": return getattr(userModule, class_name) @@ -2717,7 +2803,7 @@ def filter_main(module_name, class_name): raise class FilteredUnpickler(pickle.Unpickler): - def find_class(self, module, name): + def find_class(self, module: str, name: str) -> Any: return filter_main(module, name) unpickler = FilteredUnpickler(fileHandle) @@ -2733,7 +2819,7 @@ def find_class(self, module, name): def getUserScript(self) -> ModuleDescriptor: return self.userModule - def _fulfillPromises(self, returnValues, jobStore): + def _fulfillPromises(self, returnValues: Any, jobStore: AbstractJobStore) -> None: """ Set the values for promises using the return values from this job's run() function. """ @@ -2784,7 +2870,7 @@ def _fulfillPromises(self, returnValues, jobStore): # Functions associated with Job.checkJobGraphAcyclic to establish that the job graph does not # contain any cycles of dependencies: - def _collectAllSuccessors(self, visited): + def _collectAllSuccessors(self, visited: set[Job]) -> None: """ Add the job and all jobs reachable on a directed path from current node to the given set. 
@@ -2852,7 +2938,7 @@ def getTopologicalOrderingOfJobs(self) -> list[Job]: # Storing Jobs into the JobStore #################################################### - def _register(self, jobStore) -> list[tuple[TemporaryID, str]]: + def _register(self, jobStore: AbstractJobStore) -> list[tuple[TemporaryID, str]]: """ If this job lacks a JobStore-assigned ID, assign this job an ID. Must be called for each job before it is saved to the JobStore for the first time. @@ -2871,6 +2957,7 @@ def _register(self, jobStore) -> list[tuple[TemporaryID, str]]: # Replace it with a real ID jobStore.assign_job_id(self.description) + assert not isinstance(self.description.jobStoreID, TemporaryID) # Make sure the JobDescription can do its JobStore-related setup. self.description.onRegistration(jobStore) @@ -2893,7 +2980,7 @@ def _renameReferences(self, renames: dict[TemporaryID, str]) -> None: """ # Do renames in the description - self._description.renameReferences(renames) + self.description.renameReferences(renames) def saveBody(self, jobStore: AbstractJobStore) -> None: """ @@ -2922,6 +3009,7 @@ def saveBody(self, jobStore: AbstractJobStore) -> None: # Remember fields we will overwrite description = self._description + assert description is not None registry = self._registry directPredecessors = self._directPredecessors @@ -2936,6 +3024,7 @@ def saveBody(self, jobStore: AbstractJobStore) -> None: self._directPredecessors = set() # Save the body of the job + assert not isinstance(description.jobStoreID, TemporaryID) with jobStore.write_file_stream( description.jobStoreID, cleanup=True ) as (fileHandle, fileStoreID): @@ -2962,14 +3051,14 @@ def saveBody(self, jobStore: AbstractJobStore) -> None: userScript = self.getUserScript().globalize() # Connect the body of the job to the JobDescription - self._description.attach_body(fileStoreID, userScript) + self.description.attach_body(fileStoreID, userScript) def _saveJobGraph( self, jobStore: AbstractJobStore, saveSelf: bool = False, - returnValues: bool = None, - ): + returnValues: Any = None, + ) -> None: """ Save job data and new JobDescriptions to the given job store for this job and all descending jobs, including services. @@ -2992,7 +3081,7 @@ def _saveJobGraph( # and has an ID. Also rewrite ID references. allJobs = list(self._registry.values()) # We use one big dict from fake ID to corresponding real ID to rewrite references. - fakeToReal = {} + fakeToReal: dict[TemporaryID, str] = {} for job in allJobs: # Register the job, get the old ID to new ID pair if any, and save that in the fake to real mapping fakeToReal.update(job._register(jobStore)) @@ -3092,6 +3181,7 @@ def saveAsRootJob(self, jobStore: AbstractJobStore) -> JobDescription: # Store the name of the first job in a file in case of restart. Up to this point the # root job is not recoverable. FIXME: "root job" or "first job", which one is it? + assert not isinstance(self.jobStoreID, TemporaryID) jobStore.set_root_job(self.jobStoreID) # Assign the config from the JobStore as if we were loaded. @@ -3116,17 +3206,16 @@ def loadJob( logger.debug("Loading user module %s.", user_module_descriptor) user_module = cls._loadUserModule(user_module_descriptor) - # Loads context manager using file stream if file_store_id == "firstJob": # This one is actually a shared file name and not a file ID. 
- manager = job_store.read_shared_file_stream(file_store_id) + stream: ContextManager[IO[bytes]] = job_store.read_shared_file_stream(file_store_id) else: - manager = job_store.read_file_stream(file_store_id) + stream = job_store.read_file_stream(file_store_id) - # Open and unpickle - with manager as file_handle: + # Enter closing context and unpickle + with stream as file_handle: - job = cls._unpickle(user_module, file_handle, requireInstanceOf=Job) + job: Job = cls._unpickle(user_module, file_handle, requireInstanceOf=Job) # Fill in the current description job._description = job_description @@ -3135,7 +3224,7 @@ def loadJob( return job - def _run(self, jobGraph=None, fileStore=None, **kwargs): + def _run(self, jobGraph: Any = None, fileStore: AbstractFileStore | None = None, **kwargs: Any) -> Any: """ Function which worker calls to ultimately invoke a job's Job.run method, and then handle created @@ -3157,10 +3246,13 @@ def _run(self, jobGraph=None, fileStore=None, **kwargs): :param toil.fileStores.abstractFileStore.AbstractFileStore fileStore: the FileStore to use to access files when running the job. Required. """ + # TODO: Can't we drop compatibility with extremely old Cactus and make + # fileStore positional and non-nullably typed? + assert fileStore is not None return self.run(fileStore) @contextmanager - def _executor(self, stats, fileStore): + def _executor(self, stats: StatsDict, fileStore: AbstractFileStore) -> Iterator[None]: """ This is the core wrapping method for running the job within a worker. It sets up the stats and logging before yielding. After completion of the body, the function will finish up the @@ -3193,6 +3285,7 @@ def _executor(self, stats, fileStore): fileStore.deleteGlobalFile(FileID(jobStoreFileID, 0)) else: # Else copy them to the job description to delete later + assert isinstance(self.description, CheckpointJobDescription) self.description.checkpointFilesToDelete = list(Promise.filesToDelete) Promise.filesToDelete.clear() # Now indicate the asynchronous update of the job can happen @@ -3248,7 +3341,7 @@ def _runner( jobStore: AbstractJobStore, fileStore: AbstractFileStore, defer: Callable[[Any], None], - **kwargs, + **kwargs: Any, ) -> None: """ Run the job, and serialise the next jobs. @@ -3293,11 +3386,11 @@ def _runner( # That and the new child/follow-on relationships will need to be # recorded later by an update() of the JobDescription. - def _jobName(self): + def _jobName(self) -> str: """ :rtype : string, used as identifier of the job class in the stats report. """ - return self._description.displayName + return self.description.displayName def set_debug_flag(self, flag: str) -> None: """ @@ -3347,7 +3440,7 @@ class JobGraphDeadlockException(JobException): dependency, such as a cycle. See :func:`toil.job.Job.checkJobGraphForDeadlocks`. """ - def __init__(self, string): + def __init__(self, string: str) -> None: super().__init__(string) @@ -3357,7 +3450,7 @@ class FunctionWrappingJob(Job): """ def __init__( - self, userFunction: Callable[[...], Any], *args: Any, **kwargs: Any + self, userFunction: Callable[..., Any], *args: Any, **kwargs: Any ) -> None: """ :param callable userFunction: The function to wrap. 
It will be called with ``*args`` and @@ -3381,7 +3474,7 @@ def __init__( list(zip(argSpec.args[-len(argSpec.defaults) :], argSpec.defaults)) ) - def resolve(key, default: Any | None = None, dehumanize: bool = False) -> Any: + def resolve(key: str, default: Any | None = None, dehumanize: bool = False) -> Any: try: # First, try constructor arguments, ... value = kwargs.pop(key) @@ -3422,13 +3515,16 @@ def _getUserFunction(self) -> Callable[..., Any]: self.userFunctionModule, ) userFunctionModule = self._loadUserModule(self.userFunctionModule) - return getattr(userFunctionModule, self.userFunctionName) + user_function = getattr(userFunctionModule, self.userFunctionName) + assert callable(user_function) + # TODO: The assert dhould narrow, but doesn't. See https://github.com/python/mypy/issues/20748 + return cast(Callable[..., Any], user_function) def run(self, fileStore: AbstractFileStore) -> Any: userFunction = self._getUserFunction() return userFunction(*self._args, **self._kwargs) - def getUserScript(self) -> str: + def getUserScript(self) -> ModuleDescriptor: return self.userFunctionModule def _jobName(self) -> str: @@ -3470,6 +3566,7 @@ class JobFunctionWrappingJob(FunctionWrappingJob): @property def fileStore(self) -> AbstractFileStore: + assert self._fileStore is not None return self._fileStore def run(self, fileStore: AbstractFileStore) -> Any: @@ -3485,7 +3582,7 @@ class PromisedRequirementFunctionWrappingJob(FunctionWrappingJob): resource requirements. """ - def __init__(self, userFunction, *args, **kwargs): + def __init__(self, userFunction: Callable[..., Any], *args: Any, **kwargs: Any) -> None: self._promisedKwargs = kwargs.copy() # Replace resource requirements in intermediate job with small values. kwargs.update( @@ -3501,7 +3598,7 @@ def __init__(self, userFunction, *args, **kwargs): super().__init__(userFunction, *args, **kwargs) @classmethod - def create(cls, userFunction, *args, **kwargs): + def create(cls, userFunction: Callable[..., Any], *args: Any, **kwargs: Any) -> EncapsulatedJob: """ Creates an encapsulated Toil job function with unfulfilled promised resource requirements. After the promises are fulfilled, a child job function is created @@ -3511,13 +3608,13 @@ def create(cls, userFunction, *args, **kwargs): """ return EncapsulatedJob(cls(userFunction, *args, **kwargs)) - def run(self, fileStore): + def run(self, fileStore: AbstractFileStore) -> Any: # Assumes promises are fulfilled when parent job is run self.evaluatePromisedRequirements() userFunction = self._getUserFunction() return self.addChildFn(userFunction, *self._args, **self._promisedKwargs).rv() - def evaluatePromisedRequirements(self): + def evaluatePromisedRequirements(self) -> None: # Fulfill resource requirement promises for requirement in REQUIREMENT_NAMES: try: @@ -3535,7 +3632,7 @@ class PromisedRequirementJobFunctionWrappingJob(PromisedRequirementFunctionWrapp See :class:`toil.job.JobFunctionWrappingJob` """ - def run(self, fileStore): + def run(self, fileStore: AbstractFileStore) -> Any: self.evaluatePromisedRequirements() userFunction = self._getUserFunction() return self.addChildJobFn( @@ -3574,6 +3671,9 @@ def __init__(self, job: Job | None, unitName: str | None = None) -> None: :param str unitName: human-readable name to identify this job instance. 
""" + self.encapsulatedJob: Job | None + self.encapsulatedFollowOn: Job | None + if job is not None: # Initial construction, when encapsulating a job @@ -3602,14 +3702,14 @@ def __init__(self, job: Job | None, unitName: str | None = None) -> None: self.encapsulatedJob = None self.encapsulatedFollowOn = None - def addChild(self, childJob: Job) -> Job: + def addChild(self, childJob: JobType) -> JobType: if self.encapsulatedFollowOn is None: raise RuntimeError( "Children cannot be added to EncapsulatedJob while it is running" ) return Job.addChild(self.encapsulatedFollowOn, childJob) - def addService(self, service, parentService=None): + def addService(self, service: Job.Service, parentService: Job.Service | None = None) -> Promise: if self.encapsulatedFollowOn is None: raise RuntimeError( "Services cannot be added to EncapsulatedJob while it is running" @@ -3618,19 +3718,19 @@ def addService(self, service, parentService=None): self.encapsulatedFollowOn, service, parentService=parentService ) - def addFollowOn(self, followOnJob: Job) -> Job: + def addFollowOn(self, followOnJob: JobType) -> JobType: if self.encapsulatedFollowOn is None: raise RuntimeError( "Follow-ons cannot be added to EncapsulatedJob while it is running" ) return Job.addFollowOn(self.encapsulatedFollowOn, followOnJob) - def rv(self, *path) -> Promise: + def rv(self, *path: Any) -> Promise: if self.encapsulatedJob is None: raise RuntimeError("The encapsulated job was not set.") return self.encapsulatedJob.rv(*path) - def prepareForPromiseRegistration(self, jobStore): + def prepareForPromiseRegistration(self, jobStore: AbstractJobStore) -> None: # This one will be called after execution when re-serializing the # (unchanged) graph of jobs rooted here. super().prepareForPromiseRegistration(jobStore) @@ -3638,13 +3738,13 @@ def prepareForPromiseRegistration(self, jobStore): # Running where the job was created. self.encapsulatedJob.prepareForPromiseRegistration(jobStore) - def _disablePromiseRegistration(self): + def _disablePromiseRegistration(self) -> None: if self.encapsulatedJob is None: raise RuntimeError("The encapsulated job was not set.") super()._disablePromiseRegistration() self.encapsulatedJob._disablePromiseRegistration() - def __reduce__(self): + def __reduce__(self) -> tuple[type, tuple[None]]: """ Called during pickling to define the pickled representation of the job. @@ -3655,7 +3755,7 @@ def __reduce__(self): return self.__class__, (None,) - def getUserScript(self): + def getUserScript(self) -> ModuleDescriptor: if self.encapsulatedJob is None: raise RuntimeError("The encapsulated job was not set.") return self.encapsulatedJob.getUserScript() @@ -3666,7 +3766,7 @@ class ServiceHostJob(Job): Job that runs a service. Used internally by Toil. Users should subclass Service instead of using this. """ - def __init__(self, service): + def __init__(self, service: Job.Service) -> None: """ This constructor should not be called by a user. @@ -3698,47 +3798,49 @@ def __init__(self, service): # The service to run, or None if it is still pickled. # We can't just pickle as part of ourselves because we may need to load # an additional module. - self.service = service + self.service: Job.Service | None = service # The pickled service, or None if it isn't currently pickled. # We can't just pickle right away because we may owe promises from it. - self.pickledService = None + self.pickledService: bytes | None = None # Pick up our name from the service. 
self.description.jobName = service.jobName @property - def fileStore(self): + def fileStore(self) -> AbstractFileStore: """ Return the file store, which the Service may need. """ + assert self._fileStore is not None return self._fileStore - def _renameReferences(self, renames): + def _renameReferences(self, renames: dict[TemporaryID, str]) -> None: # When the job store finally hads out IDs we have to fix up the # back-reference from our Service to us. super()._renameReferences(renames) if self.service is not None: + assert isinstance(self.service.hostID, TemporaryID) self.service.hostID = renames[self.service.hostID] # Since the running service has us, make sure they don't try to tack more # stuff onto us. - def addChild(self, child): + def addChild(self, child: Job) -> NoReturn: raise RuntimeError( "Service host jobs cannot have children, follow-ons, or services" ) - def addFollowOn(self, followOn): + def addFollowOn(self, followOn: Job) -> NoReturn: raise RuntimeError( "Service host jobs cannot have children, follow-ons, or services" ) - def addService(self, service, parentService=None): + def addService(self, service: Job.Service, parentService: Job.Service | None = None) -> NoReturn: raise RuntimeError( "Service host jobs cannot have children, follow-ons, or services" ) - def saveBody(self, jobStore): + def saveBody(self, jobStore: AbstractJobStore) -> None: """ Serialize the service itself before saving the host job's body. """ @@ -3756,10 +3858,13 @@ def saveBody(self, jobStore): self.service = service self.pickledService = None - def run(self, fileStore): + def run(self, fileStore: AbstractFileStore) -> None: + # Narrow the description type for access to service-specific fields + assert isinstance(self.description, ServiceJobDescription) # Unpickle the service logger.debug("Loading service module %s.", self.serviceModule) userModule = self._loadUserModule(self.serviceModule) + assert self.pickledService is not None service = self._unpickle( userModule, BytesIO(self.pickledService), requireInstanceOf=Job.Service ) @@ -3773,7 +3878,7 @@ def run(self, fileStore): # the service, to do this while the run method is running we # cheat and set the return value promise within the run method self._fulfillPromises(startCredentials, fileStore.jobStore) - self._rvs = ( + self._rvs = cast(dict[tuple[Any, ...], list[str]], {} ) # Set this to avoid the return values being updated after the # run method has completed! @@ -3801,6 +3906,7 @@ def run(self, fileStore): logger.debug( "Detected that the terminate jobStoreID has been removed so exiting" ) + assert self.description.errorJobStoreID is not None if not fileStore.jobStore.file_exists( self.description.errorJobStoreID ): @@ -3835,7 +3941,7 @@ def run(self, fileStore): # The stop function is always called service.stop(self) - def getUserScript(self): + def getUserScript(self) -> ModuleDescriptor: return self.serviceModule @@ -3929,7 +4035,6 @@ def potential_absolute_uris( def get_file_sizes( filenames: list[str], - file_source: AbstractJobStore, search_paths: list[str] | None = None, include_remote_files: bool = True, execution_dir: str | None = None, @@ -3939,7 +4044,6 @@ def get_file_sizes( to a tuple of the normalized URI, parent directory ID, and size of the file. The size of the file may be None, which means unknown size. :param filenames: list of filenames to evaluate on - :param file_source: Context to search for files with :param task_path: Dotted WDL name of the user-level code doing the importing (probably the workflow name). 
:param search_paths: If set, try resolving input location relative to the URLs or @@ -3959,13 +4063,13 @@ def get_filename_size(filename: str) -> FileMetadata: try: if not include_remote_files and is_remote_url(candidate_uri): # Use remote URIs in place. But we need to find the one that exists. - if not file_source.url_exists(candidate_uri): + if not URLAccess.url_exists(candidate_uri): # Wasn't found there continue # Now we know this exists, so pass it through # Get filesizes - filesize = file_source.get_size(candidate_uri) + filesize = URLAccess.get_size(candidate_uri) except UnimplementedURLException as e: # We can't find anything that can even support this URL scheme. # Report to the user, they are probably missing an extra. @@ -4032,7 +4136,7 @@ class CombineImportsJob(Job): Combine the outputs of multiple WorkerImportJobs into one promise """ - def __init__(self, d: Sequence[Promised[dict[str, FileID]]], **kwargs): + def __init__(self, d: Sequence[Promised[dict[str, FileID]]], **kwargs: Any) -> None: """ :param d: Sequence of dictionaries to merge """ @@ -4107,13 +4211,15 @@ class ImportsJob(Job): """ Job to organize and delegate files to individual WorkerImportJobs. + Only works on files of known size. + For the CWL/WDL runners, this is only used when runImportsOnWorkers is enabled """ def __init__( self, file_to_data: dict[str, FileMetadata], - max_batch_size: ParseableIndivisibleResource, + max_batch_size: str, import_worker_disk: ParseableIndivisibleResource, **kwargs: Any, ): @@ -4123,7 +4229,8 @@ def __init__( This class is only used when runImportsOnWorkers is enabled. :param file_to_data: mapping of file source name to file metadata - :param max_batch_size: maximum cumulative file size of a batched import + :param max_batch_size: maximum cumulative file size of a batched + import, as a disk amount specification. """ super().__init__(local=True, **kwargs) self._file_to_data = file_to_data @@ -4146,7 +4253,8 @@ def run( to FileMetadata mapping). The candidate URI is stored in FileMetadata.source. """ - max_batch_size = self._max_batch_size + # Parse the disk amount for the batch size. + max_batch_size: int = human2bytes(self._max_batch_size) file_to_data = self._file_to_data # Run WDL imports on a worker instead @@ -4158,12 +4266,13 @@ def run( file_batches = [] # List of filenames for each batch - per_batch_files = [] + per_batch_files: list[str] = [] per_batch_size = 0 while len(filenames) > 0: filename = filenames.pop(0) - # See if adding this to the queue will make the batch job too big + # See if adding this to the queue will make the batch job too big. filesize = file_to_data[filename][2] + assert filesize is not None if per_batch_size + filesize >= max_batch_size: # batch is too big now, store to schedule the batch if len(per_batch_files) == 0: @@ -4219,7 +4328,7 @@ class Promise: for each promise """ - filesToDelete = set() + filesToDelete: set[str] = set() """ A set of IDs of files containing promised values when we know we won't need them anymore """ @@ -4239,7 +4348,7 @@ def __init__(self, job: Job, path: Any): self.job = job self.path = path - def __reduce__(self): + def __reduce__(self) -> tuple[type, tuple[str, str]]: """ Return the Promise class and construction arguments. 
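Reviewer note on ImportsJob now taking max_batch_size as a human-readable disk string: the limit is only parsed (via human2bytes) inside run(), and files are then packed greedily so each batch stays under it. A simplified, stand-alone sketch of that arithmetic (file names and sizes invented; this mirrors the spirit of the loop above rather than its exact edge-case handling):

# Simplified illustration of the batching rule; not the actual ImportsJob code.
from toil.lib.conversions import human2bytes


def batch_files(sizes: dict[str, int], max_batch_size: str) -> list[list[str]]:
    limit = human2bytes(max_batch_size)   # e.g. "1 GiB" -> 1073741824 bytes
    batches: list[list[str]] = []
    current: list[str] = []
    current_size = 0
    for name, size in sizes.items():
        if current and current_size + size >= limit:
            # This file would push the batch over the limit; start a new batch.
            batches.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        batches.append(current)
    return batches


# [['a.bam'], ['b.bam', 'c.vcf']] for a 1 GiB limit:
print(batch_files({"a.bam": 700_000_000, "b.bam": 500_000_000, "c.vcf": 1_000}, "1 GiB"))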
@@ -4260,7 +4369,7 @@ def __reduce__(self): return self.__class__, (jobStoreLocator, jobStoreFileID) @staticmethod - def __new__(cls, *args) -> Promise: + def __new__(cls, *args: Any) -> Any: """Instantiate this Promise.""" if len(args) != 2: raise RuntimeError( @@ -4274,7 +4383,7 @@ def __new__(cls, *args) -> Promise: return cls._resolve(*args) @classmethod - def _resolve(cls, jobStoreLocator, jobStoreFileID): + def _resolve(cls, jobStoreLocator: str, jobStoreFileID: str) -> Any: # Initialize the cached job store if it was never initialized in the current process or # if it belongs to a different workflow that was run earlier in the current process. if cls._jobstore is None or cls._jobstore.config.jobStore != jobStoreLocator: @@ -4325,7 +4434,7 @@ def unwrap_all(p: Sequence[Promised[T]]) -> Sequence[T]: raise TypeError( f"Attempted to unwrap a value at index {i} that is still a Promise: {item}" ) - return p + return cast(Sequence[T], p) class PromisedRequirement: @@ -4345,14 +4454,14 @@ class PromisedRequirement: C = B.addChildFn(h, cores=PromisedRequirement(lambda x: 2*x, B.rv())) """ + # TODO: Type this whole class better so it understands Promised[] def __init__(self, valueOrCallable: Any, *args: Any) -> None: """ Initialize this Promised Requirement. :param valueOrCallable: A single Promise instance or a function that takes args as input parameters. - :param args: variable length argument list - :type args: int or .Promise + :param args: variable length argument list for a callable first argument """ if hasattr(valueOrCallable, "__call__"): if len(args) == 0: @@ -4364,7 +4473,7 @@ def __init__(self, valueOrCallable: Any, *args: Any) -> None: "Define a PromisedRequirement function to handle multiple arguments." ) func = lambda x: x - args = [valueOrCallable] + args = (valueOrCallable,) self._func = dill.dumps(func) self._args = list(args) diff --git a/src/toil/jobStores/abstractJobStore.py b/src/toil/jobStores/abstractJobStore.py index c01865e071..b5058d457b 100644 --- a/src/toil/jobStores/abstractJobStore.py +++ b/src/toil/jobStores/abstractJobStore.py @@ -263,11 +263,11 @@ def setRootJob(self, rootJobStoreID: FileID) -> None: """Set the root job of the workflow backed by this job store.""" return self.set_root_job(rootJobStoreID) - def set_root_job(self, job_id: FileID) -> None: + def set_root_job(self, job_id: str) -> None: """ Set the root job of the workflow backed by this job store. 
- :param job_id: The ID of the job to set as root + :param job_id: The job store ID of the job to set as root """ with self.write_shared_file_stream(self.rootJobStoreIDFileName) as f: f.write(job_id.encode("utf-8")) diff --git a/src/toil/lib/expando.py b/src/toil/lib/expando.py index 4b8e9727c9..c484ed1212 100644 --- a/src/toil/lib/expando.py +++ b/src/toil/lib/expando.py @@ -14,6 +14,7 @@ # 5.14.2018: copied into Toil from https://github.com/BD2KGenomics/bd2k-python-lib +from typing import Any class Expando(dict): """ @@ -101,7 +102,7 @@ class Expando(dict): True """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.__slots__ = None self.__dict__ = self diff --git a/src/toil/statsAndLogging.py b/src/toil/statsAndLogging.py index 16965ef305..ccce1c1f76 100644 --- a/src/toil/statsAndLogging.py +++ b/src/toil/statsAndLogging.py @@ -24,7 +24,7 @@ from typing import IO, TYPE_CHECKING, Any, Union from toil.lib.conversions import strtobool -from toil.lib.expando import Expando +from toil.lib.expando import Expando, MagicExpando from toil.lib.history import HistoryManager from toil.lib.resources import ResourceMonitor @@ -44,6 +44,10 @@ logging.addLevelName(TRACE, "TRACE") +class StatsDict(MagicExpando): + """Subclass of MagicExpando for type-checking purposes.""" + + jobs: list[Expando] class StatsAndLogging: """A thread to aggregate statistics and logging.""" diff --git a/src/toil/test/src/jobServiceTest.py b/src/toil/test/src/jobServiceTest.py index d5aeb69ce1..289d16e4dc 100644 --- a/src/toil/test/src/jobServiceTest.py +++ b/src/toil/test/src/jobServiceTest.py @@ -473,8 +473,8 @@ def serviceAccessor( # Try reading an integer from the input file and writing out the message with job.fileStore.jobStore.read_file_stream(outJobStoreFileID) as fH: - fH = codecs.getreader("utf-8")(fH) - line = fH.readline() + reader = codecs.getreader("utf-8")(fH) + line = reader.readline() tokens = line.split() if len(tokens) != 2: diff --git a/src/toil/test/src/jobTest.py b/src/toil/test/src/jobTest.py index a1beb34b16..0525af660b 100644 --- a/src/toil/test/src/jobTest.py +++ b/src/toil/test/src/jobTest.py @@ -24,7 +24,6 @@ from toil.common import Toil from toil.exceptions import FailedJobsException from toil.job import ( - FunctionWrappingJob, Job, JobFunctionWrappingJob, JobGraphDeadlockException, @@ -406,7 +405,7 @@ def testNewCheckpointIsLeafVertexNonRootCase( """ - def createWorkflow() -> tuple[Job, FunctionWrappingJob]: + def createWorkflow() -> tuple[Job, Job]: rootJob = Job.wrapJobFn(simpleJobFn, "Parent") childCheckpointJob = rootJob.addChildJobFn( simpleJobFn, "Child", checkpoint=True @@ -731,9 +730,9 @@ def makeJobGraph( followOn edges. 
""" # Map of jobs to the list of promises they have - jobsToPromisesMap: dict[FunctionWrappingJob, list[Promise]] = {} + jobsToPromisesMap: dict[Job, list[Promise]] = {} - def makeJob(string: str) -> FunctionWrappingJob: + def makeJob(string: str) -> Job: promises: list[Promise] = [] job = Job.wrapFn( fn2Test, @@ -751,7 +750,7 @@ def makeJob(string: str) -> FunctionWrappingJob: jobs = [makeJob(str(i)) for i in range(nodeNumber)] # Record predecessors for sampling - predecessors: dict[FunctionWrappingJob, list[FunctionWrappingJob]] = ( + predecessors: dict[Job, list[Job]] = ( collections.defaultdict(list) ) @@ -777,7 +776,7 @@ def makeJob(string: str) -> FunctionWrappingJob: for job in jobs } - def getRandomPredecessor(job: FunctionWrappingJob) -> FunctionWrappingJob: + def getRandomPredecessor(job: Job) -> Job: predecessor = random.choice(list(predecessors[job])) while random.random() > 0.5 and len(predecessors[predecessor]) > 0: predecessor = random.choice(list(predecessors[predecessor])) diff --git a/src/toil/test/src/retainTempDirTest.py b/src/toil/test/src/retainTempDirTest.py index a6429fccd8..7cf8688c3d 100644 --- a/src/toil/test/src/retainTempDirTest.py +++ b/src/toil/test/src/retainTempDirTest.py @@ -80,7 +80,7 @@ def testOnSuccessWithSuccess(self, tmp_path: Path) -> None: "The worker's temporary workspace was not deleted despite " "a successful job execution and cleanWorkDir being set to 'onSuccesss'" ) - + def _runAndReturnWorkDir( self, tmp_path: Path, @@ -107,12 +107,12 @@ def _runAndReturnWorkDir( return os.listdir(workdir) def _launchRegular( - self, A: JobFunctionWrappingJob, options: argparse.Namespace + self, A: Job, options: argparse.Namespace ) -> None: Job.Runner.startToil(A, options) def _launchError( - self, A: JobFunctionWrappingJob, options: argparse.Namespace + self, A: Job, options: argparse.Namespace ) -> None: try: Job.Runner.startToil(A, options) diff --git a/src/toil/toilState.py b/src/toil/toilState.py index a563c27e5b..84eb3f7a45 100644 --- a/src/toil/toilState.py +++ b/src/toil/toilState.py @@ -15,7 +15,7 @@ import time from toil.bus import JobUpdatedMessage, MessageBus -from toil.job import CheckpointJobDescription, JobDescription +from toil.job import CheckpointJobDescription, JobDescription, TemporaryID from toil.jobStores.abstractJobStore import AbstractJobStore, NoSuchJobException logger = logging.getLogger(__name__) @@ -352,6 +352,9 @@ def _buildToilState(self, jobDesc: JobDescription) -> None: def processSuccessorWithMultiplePredecessors( successor: JobDescription, ) -> None: + # TODO: Can we hide the fact that TemporaryID exists better + # from the type system??? 
+ assert not isinstance(jobDesc.jobStoreID, TemporaryID) # If jobDesc is not reported as complete by the successor if jobDesc.jobStoreID not in successor.predecessorsFinished: diff --git a/src/toil/wdl/wdltoil.py b/src/toil/wdl/wdltoil.py index 43be7bfc03..d722736e3b 100755 --- a/src/toil/wdl/wdltoil.py +++ b/src/toil/wdl/wdltoil.py @@ -6002,7 +6002,7 @@ def __init__( wdl_options: WDLContext, inputs_search_path: list[str], import_remote_files: bool, - import_workers_batchsize: ParseableIndivisibleResource, + import_workers_batchsize: str, import_workers_disk: ParseableIndivisibleResource, **kwargs: Any, ): @@ -6023,7 +6023,6 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]: filenames = extract_inode_values(self._inputs) file_to_metadata = get_file_sizes( filenames, - file_store.jobStore, self._inputs_search_path, include_remote_files=self._import_remote_files, execution_dir=self._wdl_options.get("execution_dir"), diff --git a/src/toil/worker.py b/src/toil/worker.py index 1838ba5165..16c36db9bb 100644 --- a/src/toil/worker.py +++ b/src/toil/worker.py @@ -49,20 +49,13 @@ JobDescription, ) from toil.jobStores.abstractJobStore import AbstractJobStore -from toil.lib.expando import MagicExpando from toil.lib.io import make_public_dir, path_union from toil.lib.resources import ResourceMonitor -from toil.statsAndLogging import configure_root_logger, install_log_color, set_log_level +from toil.statsAndLogging import StatsDict, configure_root_logger, install_log_color, set_log_level logger = logging.getLogger(__name__) -class StatsDict(MagicExpando): - """Subclass of MagicExpando for type-checking purposes.""" - - jobs: list[MagicExpando] - - def nextChainable( predecessor: JobDescription, job_store: AbstractJobStore, config: Config ) -> JobDescription | None: @@ -455,7 +448,7 @@ def workerScript( jobAttemptFailed = False failure_exit_code = 1 first_job_cores = None - statsDict = StatsDict() # type: ignore[no-untyped-call] + statsDict = StatsDict() statsDict.jobs = [] statsDict.workers.logs_to_leader = [] statsDict.workers.logging_user_streams = [] @@ -702,7 +695,8 @@ def blockFn() -> bool: max_bytes = 0 for job_stats in statsDict.jobs: if "disk" in job_stats: - max_bytes = max(max_bytes, int(job_stats.disk)) + # TODO: MyPy doesn't know this type-narrows our Expando to something that has .disk + max_bytes = max(max_bytes, int(job_stats.disk)) # type:ignore[attr-defined] statsDict.workers.disk = str(max_bytes) # Count the jobs executed. # TODO: toil stats could compute this but its parser is too general to hook into simply.
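Reviewer note on StatsDict moving into toil.statsAndLogging: it is still a MagicExpando underneath, so nested sections spring into existence on attribute access, which is exactly what the worker relies on above. A small hypothetical sketch of that access pattern and of the max-over-jobs disk aggregation that the type:ignore comment refers to (all values invented):

# Hypothetical values; shows the Expando-style access the worker code uses.
from toil.statsAndLogging import StatsDict


def worker_max_disk(stats: StatsDict) -> str:
    max_bytes = 0
    for job_stats in stats.jobs:
        if "disk" in job_stats:          # not every job reports disk usage
            max_bytes = max(max_bytes, int(job_stats.disk))
    return str(max_bytes)


stats = StatsDict()
stats.jobs = []
stats.workers.logs_to_leader = []        # nested section auto-created by MagicExpando
stats.jobs.append(StatsDict(disk="2048", memory="4096"))
stats.jobs.append(StatsDict(memory="1024"))
stats.workers.disk = worker_max_disk(stats)
assert stats.workers.disk == "2048"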