10 changes: 10 additions & 0 deletions training/post_gradient_callbacks.py
@@ -0,0 +1,10 @@
from datasets.base import DataLoader
from foundations import hparams
from foundations.step import Step


def post_gradient_callbacks(training_hparams: hparams.TrainingHparams, train_set_loader: DataLoader,
                            test_set_loader: DataLoader, eval_on_train: bool = False, verbose: bool = True,
                            start_step: Step = None, evaluate_every_epoch: bool = True):
    # Factory for the post-gradient hooks consumed by train(). It currently returns an
    # empty list; populate the list to run custom logic between loss.backward() and
    # optimizer.step().
    result = []
    return result
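Any callback placed in result must match the call site in train.py, which invokes each hook as callback(output_location, step, model, optimizer, logger). Below is a minimal sketch of a gradient-saving callback that this factory could return; the name save_gradients and the file layout are hypothetical, and step.iteration is assumed to follow the Step API used elsewhere in the repository:

import os
import torch

def save_gradients(output_location, step, model, optimizer, logger):
    # Hypothetical example: dump every parameter's gradient after backward() so it
    # can be inspected offline. Gradients are populated but not yet applied here.
    grads = {name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None}
    torch.save(grads, os.path.join(output_location, 'gradients_{}.pt'.format(step.iteration)))

The factory would then return [save_gradients] instead of an empty list.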
14 changes: 13 additions & 1 deletion training/train.py
@@ -17,6 +17,7 @@
from training.checkpointing import restore_checkpoint
from training import optimizers
from training import standard_callbacks
from training import post_gradient_callbacks
from training.metric_logger import MetricLogger

try:
@@ -32,6 +33,7 @@ def train(
train_loader: DataLoader,
output_location: str,
callbacks: typing.List[typing.Callable] = [],
post_gradient_cbs: typing.List[typing.Callable] = [],
start_step: Step = None,
end_step: Step = None
):
@@ -49,6 +51,12 @@ def train(
Callbacks are used for running the test set, saving the logger, saving the state of the
model, etc. They provide hooks into the training loop for customization so that the
training loop itself can remain simple.
* post_gradient_cbs: A list of functions that are called after the gradients are computed and
  before the optimizer step is taken. Each function takes five arguments: the output location,
  the current step, the model, the optimizer, and the logger.
  These callbacks can be used, e.g., to save gradients after each step or only after the last one.
  They provide hooks into the training loop for customization so that the
  training loop itself can remain simple.
* start_step: The step at which the training data and learning rate schedule should begin.
Defaults to step 0.
* end_step: The step at which training should cease. Otherwise, training will go for the
@@ -122,6 +130,7 @@ def train(
else:
loss.backward()

# Run post-gradient hooks while gradients are populated but not yet applied.
for callback in post_gradient_cbs:
    callback(output_location, step, model, optimizer, logger)
# Step forward. Ignore extraneous warnings that the lr_schedule generates.
step_optimizer.step()
with warnings.catch_warnings(): # Filter unnecessary warning.
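The hunk above fixes where the new hooks fire: after backpropagation populates the gradients and before the optimizer consumes them. Schematically (a simplified sketch of the surrounding loop, not the full training code):

loss.backward()                  # gradients now live on model.parameters()
for callback in post_gradient_cbs:
    callback(output_location, step, model, optimizer, logger)
step_optimizer.step()            # weights change only after every callback has run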
@@ -153,4 +162,7 @@ def standard_train(
callbacks = standard_callbacks.standard_callbacks(
training_hparams, train_loader, test_loader, start_step=start_step,
verbose=verbose, evaluate_every_epoch=evaluate_every_epoch)
train(training_hparams, model, train_loader, output_location, callbacks, start_step=start_step)
post_gradient_cbs = post_gradient_callbacks.post_gradient_callbacks(
training_hparams, train_loader, test_loader, start_step=start_step,
verbose=verbose, evaluate_every_epoch=evaluate_every_epoch)
train(training_hparams, model, train_loader, output_location, callbacks, post_gradient_cbs, start_step=start_step)
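Since post_gradient_callbacks currently returns an empty list, standard_train behaves exactly as before; the new parameter only takes effect once the factory is populated or a caller passes its own list to train(). A hypothetical direct use, where log_grad_norm is illustrative only and MetricLogger is assumed to expose the add(name, step, value) method used by the standard callbacks:

import torch

def log_grad_norm(output_location, step, model, optimizer, logger):
    # Hypothetical example: record the global gradient norm at each step.
    norms = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    if norms:
        logger.add('train_grad_norm', step, torch.norm(torch.stack(norms)).item())

train(training_hparams, model, train_loader, output_location,
      callbacks, post_gradient_cbs=[log_grad_norm], start_step=start_step)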