Add epoch and token limits to train.py #675
Conversation
Pull Request Overview
This PR introduces epoch-based and token-based training duration modes as alternatives to the existing iteration-based training. It adds new command-line arguments (`--max_epochs`, `--max_tokens`, `--eval_interval_epochs`, `--eval_interval_tokens`) to allow users to specify training schedules in terms of epochs or tokens rather than iterations. A new utility script is included to help users calculate epoch and token counts for their datasets.
Key changes:
- Adds epoch- and token-based training duration modes with automatic conversion to iterations (sketched after this list)
- Implements dynamic evaluation scheduling based on epoch/token progress
- Includes a utility script for calculating training duration parameters
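The conversion to iterations mentioned above reduces to arithmetic over the per-step token budget. A minimal sketch, with all names illustrative rather than taken from the PR's code:

```python
import math

def limits_to_max_iters(tokens_per_epoch, batch_size, block_size,
                        grad_accum=1, max_epochs=None, max_tokens=None):
    # Tokens consumed per optimizer step (illustrative accounting; the
    # PR's actual bookkeeping, e.g. across DDP ranks, may differ).
    tokens_per_iter = batch_size * block_size * grad_accum
    if max_epochs is not None:
        max_tokens = max_epochs * tokens_per_epoch
    return math.ceil(max_tokens / tokens_per_iter)
```

For example, a 1M-token dataset trained for 2 epochs at `batch_size=12`, `block_size=1024` works out to `ceil(2_000_000 / 12_288) = 163` iterations.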
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| utils/epoch_token_helper.py | New utility script to compute tokens-per-epoch and iteration estimates for datasets (see the sketch after this table) |
| train_args.py | Adds new arguments for epoch/token-based training and validation logic for mutually exclusive options |
| train.py | Implements progress tracking infrastructure, dynamic evaluation scheduling, and training termination based on epoch/token limits |
| explorations/training_duration_modes.yaml | Example configuration demonstrating the three training duration modes |
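The core quantity such a helper needs, tokens per epoch, can be read off the size of a dataset's train.bin. A sketch, assuming nanoGPT-style uint16 token storage (the PR's script may compute this differently):

```python
import os
import numpy as np

def tokens_in_train_bin(data_dir: str, dataset: str, dtype=np.uint16) -> int:
    # Token count = file size in bytes divided by bytes per token;
    # uint16 (2 bytes per token) is assumed, as in nanoGPT-style data prep.
    train_path = os.path.join(data_dir, dataset, 'train.bin')
    return os.path.getsize(train_path) // np.dtype(dtype).itemsize
```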
```python
limit_eval_stop = False
if self.master_process:
    while self.should_eval_for_limits():
        should_stop = run_validation_eval()
        eval_happened = True
        if should_stop:
            limit_eval_stop = True
            break
if limit_eval_stop:
    break
```
Copilot AI commented on Nov 5, 2025:
In a distributed training setup (DDP), only the master process evaluates `should_eval_for_limits()` and runs validation, but the `eval_happened` flag is set only on the master process. This creates a synchronization issue where non-master processes won't know that an evaluation occurred, potentially leading to desynchronized state. Consider synchronizing `eval_happened` across all processes using a distributed barrier or broadcast operation.
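For reference, a broadcast along the lines this comment suggests could look like the following sketch; it assumes an initialized process group and that `device` is this rank's device (CUDA when using the NCCL backend), none of which is code from the PR:

```python
import torch
import torch.distributed as dist

def sync_eval_happened(eval_happened: bool, device) -> bool:
    # Broadcast the master's flag so every rank agrees whether an eval ran.
    flag = torch.tensor([int(eval_happened)], device=device)
    dist.broadcast(flag, src=0)  # rank 0 (the master) is the source of truth
    return bool(flag.item())
```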
```python
param_group['lr'] = self.lr

if self.iter_num % self.args.eval_interval == 0 and self.master_process:
    losses = self.last_eval_losses
```
Copilot AI commented on Nov 5, 2025:
The `losses` variable is initialized with `self.last_eval_losses` (line 1620), which is initialized to `None` in `__init__` (line 112). If training starts without an initial evaluation (e.g., `iter_num=0` doesn't trigger eval), then line 1909's `self.scheduler.step(losses["val"])` will fail with a `TypeError` when `losses` is `None`. The check on line 1909 guards against `None`, but the initial assignment on line 1620 should handle the case where `self.last_eval_losses` is `None`.
Suggested change:
```diff
- losses = self.last_eval_losses
+ losses = self.last_eval_losses if self.last_eval_losses is not None else {}
```
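Note that with the empty-dict fallback, the later `None` guard no longer fires, so the call site would presumably pair with a membership check such as `if "val" in losses:` before calling `self.scheduler.step(losses["val"])`.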
```python
def _tokens_in_dataset(data_dir: str, dataset: str) -> int:
    train_path = os.path.join(data_dir, dataset, 'train.bin')
    if not os.path.exists(train_path):
        raise FileNotFoundError(f"Could not find train.bin for dataset '{dataset}' in {train_path}")
```
Copilot AI commented on Nov 5, 2025:
The error message says `'in {train_path}'`, but `train_path` is the full path to the file, not a directory. The message should say `'at {train_path}'` for clarity.
Suggested change:
```diff
- raise FileNotFoundError(f"Could not find train.bin for dataset '{dataset}' in {train_path}")
+ raise FileNotFoundError(f"Could not find train.bin for dataset '{dataset}' at {train_path}")
```
```python
if self.master_process and (
    (self.iter_num % self.args.eval_interval == 0) or limit_eval_due
):
    should_stop = run_validation_eval()
    eval_happened = True
    if should_stop:
        break
```
Copilot AI commented on Nov 5, 2025:
The variable `eval_happened` is initialized to `False` at the beginning of each training loop iteration, but when a limit-based evaluation occurs at lines 1916-1925, `eval_happened` is set to `True` only within the master process check. This could lead to inconsistency in the ETA calculation on line 1937, where `eval_happened` is used. Consider ensuring this flag's behavior is clearly defined for both master and non-master processes.
Suggested change:
```diff
- if self.master_process and (
-     (self.iter_num % self.args.eval_interval == 0) or limit_eval_due
- ):
-     should_stop = run_validation_eval()
-     eval_happened = True
-     if should_stop:
-         break
+ eval_due = (self.iter_num % self.args.eval_interval == 0) or limit_eval_due
+ if self.master_process and eval_due:
+     should_stop = run_validation_eval()
+     if should_stop:
+         eval_happened = True
+         break
+ if eval_due:
+     eval_happened = True
```
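The shape of this suggestion is the key point: `eval_due` is a property of the iteration that every rank can compute locally, while `eval_happened` records that an evaluation actually took place, so the ETA bookkeeping after the loop sees the same value on master and non-master ranks.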
```python
self.training_args['eval_interval'] = self.args.eval_interval

def get_limit_tokens_trained(self):
    if self.args.dataset_list is not None and hasattr(self, 'tokens_trained_dict'):
```
Copilot AI commented on Nov 5, 2025:
The method checks `if self.args.dataset_list is not None` but doesn't check whether the list is empty. If `dataset_list` is an empty list, line 599 will fail with an `IndexError` when trying to access `self.args.dataset_list[0]`. Add a check for an empty list: `if self.args.dataset_list is not None and self.args.dataset_list and hasattr(self, 'tokens_trained_dict'):`.
Suggested change:
```diff
- if self.args.dataset_list is not None and hasattr(self, 'tokens_trained_dict'):
+ if self.args.dataset_list is not None and self.args.dataset_list and hasattr(self, 'tokens_trained_dict'):
```
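Since both `None` and an empty list are falsy in Python, the condition could equivalently be collapsed to `if self.args.dataset_list and hasattr(self, 'tokens_trained_dict'):`.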
This pull request introduces flexible training duration controls to the training pipeline, allowing users to specify stopping criteria and evaluation cadence based on epochs or tokens, in addition to the existing iteration-based approach. The changes also add validation for new CLI arguments, improve tracking of training progress, and provide a utility for estimating epoch and iteration counts for datasets.
Training duration and evaluation control improvements:
- Added new command-line arguments in `train_args.py` for `--max_epochs`, `--max_tokens`, `--eval_interval_epochs`, and `--eval_interval_tokens`, enabling training to stop after a specified number of epochs or tokens and setting evaluation frequency accordingly. Includes validation to ensure these options are used correctly and exclusively (see the sketch after this list). [1] [2]
- Updated `train.py` to initialize and track training progress based on the selected duration mode (iterations, epochs, or tokens), including logic for determining stopping points and evaluation boundaries. [1] [2]
- Modified the training loop in `train.py` to support early stopping and evaluation triggers based on epochs/tokens, with consistent progress tracking for both single and multi-dataset scenarios. [1] [2] [3] [4]
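A validation guard of roughly this shape would enforce that exclusivity; this is a sketch only, since the PR's actual check in `train_args.py` is not reproduced here, and only the flag names come from the PR:

```python
def validate_duration_args(args):
    # Reject combinations of duration modes; the flag names are the PR's,
    # the validation logic itself is assumed.
    modes = [("--max_epochs", args.max_epochs), ("--max_tokens", args.max_tokens)]
    requested = [name for name, value in modes if value is not None]
    if len(requested) > 1:
        raise ValueError(f"training duration modes are mutually exclusive: {requested}")
```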
Utility and configuration enhancements:
- Added `explorations/training_duration_modes.yaml` to provide example hyperparameter configurations for different training duration modes, illustrating practical usage of the new options.
- Added `utils/epoch_token_helper.py`, a utility script to estimate the number of tokens and iterations per epoch for datasets, aiding in planning and configuring training runs.
Minor fixes and code quality improvements:
These changes collectively make training configuration more flexible and robust, allowing for easier experimentation and reproducibility across different training regimes.