diff --git a/training/post_gradient_callbacks.py b/training/post_gradient_callbacks.py
new file mode 100644
index 00000000..1a27a383
--- /dev/null
+++ b/training/post_gradient_callbacks.py
@@ -0,0 +1,14 @@
+from datasets.base import DataLoader
+from foundations import hparams
+from foundations.step import Step
+
+
+def post_gradient_callbacks(training_hparams: hparams.TrainingHparams, train_set_loader: DataLoader,
+                            test_set_loader: DataLoader, eval_on_train: bool = False, verbose: bool = True,
+                            start_step: Step = None, evaluate_every_epoch: bool = True):
+    """Create the list of callbacks to run between backward() and the optimizer step.
+
+    Currently empty; populate it to customize training, e.g., to record gradients.
+    """
+    result = []
+    return result
diff --git a/training/train.py b/training/train.py
index 99951b46..942ddc5a 100644
--- a/training/train.py
+++ b/training/train.py
@@ -17,6 +17,7 @@
 from training.checkpointing import restore_checkpoint
 from training import optimizers
 from training import standard_callbacks
+from training import post_gradient_callbacks
 from training.metric_logger import MetricLogger

 try:
@@ -32,6 +33,7 @@ def train(
     train_loader: DataLoader,
     output_location: str,
     callbacks: typing.List[typing.Callable] = [],
+    post_gradient_cbs: typing.List[typing.Callable] = [],
     start_step: Step = None,
     end_step: Step = None
 ):
@@ -49,6 +51,12 @@ def train(
       Callbacks are used for running the test set, saving the logger, saving the state of the model,
       etc. The provide hooks into the training loop for customization so that the training loop
       itself can remain simple.
+    * post_gradient_cbs: A list of functions that are called after the gradients are computed
+      and before the optimizer step. Each function takes five arguments: the output location,
+      the current step, the model, the optimizer, and the logger.
+      These callbacks can be used, e.g., to save the gradients after each step or after the
+      last step. They provide hooks into the training loop for customization so that the
+      training loop itself can remain simple.
     * start_step: The step at which the training data and learning rate schedule should begin.
       Defaults to step 0.
     * end_step: The step at which training should cease. Otherwise, training will go for the
@@ -122,6 +130,9 @@ def train(
         else:
             loss.backward()

+        # Run the post-gradient callbacks between backward() and the optimizer step.
+        for callback in post_gradient_cbs:
+            callback(output_location, step, model, optimizer, logger)
         # Step forward. Ignore extraneous warnings that the lr_schedule generates.
         step_optimizer.step()
         with warnings.catch_warnings():  # Filter unnecessary warning.
@@ -153,4 +164,7 @@ def standard_train(
     callbacks = standard_callbacks.standard_callbacks(
         training_hparams, train_loader, test_loader, start_step=start_step,
         verbose=verbose, evaluate_every_epoch=evaluate_every_epoch)
-    train(training_hparams, model, train_loader, output_location, callbacks, start_step=start_step)
+    post_gradient_cbs = post_gradient_callbacks.post_gradient_callbacks(
+        training_hparams, train_loader, test_loader, start_step=start_step,
+        verbose=verbose, evaluate_every_epoch=evaluate_every_epoch)
+    train(training_hparams, model, train_loader, output_location, callbacks, post_gradient_cbs, start_step=start_step)
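
Note: a post-gradient callback must accept the five arguments in the order used at the call site above, i.e. callback(output_location, step, model, optimizer, logger). As a minimal sketch of what post_gradient_callbacks could eventually return, the following callback snapshots every parameter's gradient to disk. The save_gradients name and the file-naming scheme are illustrative assumptions, not part of this change, and it assumes Step exposes ep and it as in foundations/step.py:

    import os

    import torch


    def save_gradients(output_location, step, model, optimizer, logger):
        """Illustrative post-gradient callback: snapshot each parameter's gradient."""
        # Collect gradients that exist after loss.backward(); detach and move to CPU
        # so the snapshot does not hold on to the autograd graph or GPU memory.
        # optimizer and logger are unused here but are part of the callback signature.
        grads = {name: p.grad.detach().cpu().clone()
                 for name, p in model.named_parameters() if p.grad is not None}
        torch.save(grads, os.path.join(output_location,
                                       'gradients_ep{}_it{}.pt'.format(step.ep, step.it)))

Returning [save_gradients] from post_gradient_callbacks (in place of the empty list) would then save a gradient snapshot at every training step.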