Conversation
Have you tried asking Claude to help split this into a stack of PRs that can be reviewed/landed independently? I did see your comment/apology to reviewers, but I still think honestly nobody is going to review this PR in its entirety. So are you asking for an uncareful scan and a stamp, or do you want to break out the important pieces of the code that you want careful review on?
@wconstab While I understand how intimidating reviewing a huge PR can be, I would like to initially deliver the package as a whole instead of letting people only see incremental changes (if that's possible at all). Maybe I would like to achieve
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    *,
what's the rule for differentiating the positional args and kw args?
I'm inspired by https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md#avoid-multiple-positional-arguments. Here I'm moving parallel_dims to kwarg as well.
limit the number of positional arguments to <= 1 and use keyword arguments for the rest
what's the reason not making them all kwargs?
I don't know for sure. Likely because for some functions there would always be a "main" arg that is always there and doesn't introduce ambiguity / error-proneness. E.g. if a function only takes one arg, like parallelize(model), maybe it's fine? You can imagine that later on, when people add more and more optional kwargs to the function, the model part doesn't need to change.
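To illustrate the rule being discussed, here is a minimal sketch of "at most one positional (main) argument, everything else keyword-only". The `Model` / `ParallelDims` classes below are placeholder stand-ins, not torchtitan's actual types or signatures:

```python
# Placeholder types for illustration only.
class Model: ...
class ParallelDims: ...

def parallelize(model: Model, *, parallel_dims: ParallelDims, enable_compile: bool = False) -> Model:
    # `model` is the unambiguous "main" argument; everything after `*` must
    # be passed by keyword, so adding more options later never disturbs
    # existing call sites.
    return model

m = parallelize(Model(), parallel_dims=ParallelDims())
# parallelize(Model(), ParallelDims())  # would raise TypeError: too many positional args
```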
train_spec: TrainSpec,
def register_torchtitan_model_from_model_spec(
    model_spec: ModelSpec,
    model_name: str,
model_name should be part of model_spec
Model name refers to another thing: model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM"). Maybe need a more descriptive name here, but that could be done in a separate PR
wwwjn
left a comment
Mainly took a look at components, config, experiments/rl, models, train.py and trainer.py
    parallel_dims: ParallelDims,
    dump_folder: str = "./outputs",
    pp_schedule: str = "1F1B",
    ft_enable: bool = False,
    ft_replica_id: int = 0,
    config_dict: dict[str, Any] | None = None,
    tag: str | None = None,
I'm also worried about the kwargs being a loophole through which people pass configurations around.
An alternative approach is to require each component define these shared configurations and resolve the shared configurations when constructing the root configuration (Trainer.Config).
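A hedged sketch of this alternative, with illustrative names only (`CheckpointConfig`, `MetricsConfig`, `RootConfig`, and `dump_folder` are stand-ins, not the actual torchtitan classes): each component declares the shared field it needs locally, and the root config resolves all copies once at construction time.

```python
from dataclasses import dataclass

@dataclass
class CheckpointConfig:
    dump_folder: str = "./outputs"  # shared field, declared locally

@dataclass
class MetricsConfig:
    dump_folder: str = "./outputs"  # same shared field, declared locally

@dataclass
class RootConfig:
    checkpoint: CheckpointConfig
    metrics: MetricsConfig
    dump_folder: str = "./outputs"

    def __post_init__(self):
        # Resolve at construction: the root value wins and is pushed
        # down into every sub-config that declared the shared field.
        for sub in (self.checkpoint, self.metrics):
            sub.dump_folder = self.dump_folder

root = RootConfig(CheckpointConfig(), MetricsConfig(), dump_folder="/tmp/run1")
assert root.checkpoint.dump_folder == root.metrics.dump_folder == "/tmp/run1"
```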
Yeah, this is the top issue I put under "Longer-term Issues" in the PR summary, which I couldn't handle entirely in this initial PR.
First, we need to figure out the boundary between "shared config" and "runtime kwargs". We could use more shared config, but that would be "utilizing" (a.k.a. "abusing") the Python config power in a way that makes it harder to migrate to a pure YAML solution, which may be OK.
More importantly, we need to reconsider whether the current function-calling structure makes sense at all. Current metrics logging is limited and hard to customize -- e.g. in MoE, how do we log the number of tokens each expert processes? Such problems reflect design flaws we have in torchtitan -- in the past they were hidden by the usage of JobConfig everywhere. I think surfacing them is one of the good things about this refactor.
register_model_converter(Float8LinearConverter, "quantize.linear.float8")
register_model_converter(Float8GroupedMMConverter, "quantize.grouped_mm.float8")
I see this removes the converter names (quantize.grouped_mm.float8 etc) - what does the command-line API for this look like now?
We lose CLI capability for adjusting this, because there is no string attached to each converter anymore.
* support launching custom trainer;
* init trainer components through .build() (pytorch#2386);
* move data to GPU by micro-batch;
* remove rescale_accumulated_loss (pytorch#2206).
Torchtitan merged a BC-breaking config system refactor (pytorch/torchtitan#2386) that replaced TOML configs with Python dataclass configs and changed the CLI from CONFIG_FILE + --model.name to --module + --config. Updates the CI commands accordingly. Also fixes a runtime crash where aliased buffers (registered for user-facing API compat by #321) were being passed to the compiled graph, which only expects the canonical (deduplicated) set. The deepseek_v3 test is commented out as it's also disabled in torchtitan's own CI. Authored with Claude.
stack-info: PR: #325, branch: xmfan/stack/26
NOTE: This PR is a large refactor of the codebase. https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 is the latest release cut right before this PR was merged.
author's note
This refactor is mainly trying to address two issues:
- `JobConfig` is leaked everywhere

The main changes are:
- Each `Configurable` component owns its own `Config`, which builds the owner component. It achieves modularization via polymorphism and inheritance, both classic concepts in OOP.
- `_target_` in Hydra, but there are opinions not to couple with Hydra's other offerings. See "[Feature request] Use omegaconf or hydra for the config system" #1415
- (`config_registry.py` in each model)
- (`Trainer.Config` now, `JobConfig` in the past)

This PR also
Remaining work
- `TypeError: forward() missing 1 required positional argument: 'fwd_rng_state_2'` cc @yiming0416 please help take a look
- `docs/` to subfolders, as we are having more contents to cover in general
- `apply_rotary_emb_complex` and `apply_rotary_emb_single_complex`

Longer-term issues
- `model.py` and `parallelize.py` by `BaseModel.update_from_config` violates encapsulation by passing the Trainer config into Model config. This could be avoided by Python logic either at config construction time, or in trainer.
- `init_weights` into `Module.Config` instead of staying in `Module`
- `weight_init_std` may need to be put in config, with `__post_init__` determining its value. (See related complaints / discussions on `__post_init__` by chz)

Note to reviewer:
Although I believe the changes in this PR come naturally as a bundle, you may (or may not) find the stack of 16 commits easier to review, as I tried to split the changes in a logical manner. I apologize for the giant PR.
claude-generated summary
Summary
This PR refactors torchtitan's configuration and training infrastructure in 15 incremental, backwards-incompatible commits. The central change replaces TOML config files and a monolithic
`JobConfig` parser with typed Python dataclass configs, a `Configurable` base class pattern, and a `config_registry` module per model.

270 files changed, 10,025 insertions, 11,418 deletions.
Motivation
The previous system used TOML files parsed by a custom
`ConfigManager` that layered CLI overrides on top. While simple, this had several friction points:

- A misspelled key (e.g. `training.stpes`) silently becomes a default value.
- All sections (`[model]`, `[training]`, `[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class. Every component received the full `JobConfig` even when it only needed a few fields.
- Extending the config (e.g. `compile.graph_passes` or FaultTolerant's `fault_tolerance.*`) required a `custom_config_module` TOML key and a runtime `_merge_configs` call to graft new fields onto `JobConfig`.
- The `ModelArgs` dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that bundled model + parallelization + loss was registered separately, with no type-level link between them.

What Changed
1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`) establishes a universal pattern: every configurable component (Trainer, model, optimizer, tokenizer, dataloader, checkpoint manager, metrics, validators, quantization converters, ...) follows it, and calling `config.build()` constructs the owning class.
2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested dataclass that aggregates typed sub-configs. Each sub-config is the `Config` class of the component that consumes it (e.g., `CheckpointManager.Config` is defined inside `CheckpointManager`). Components receive only their own config, not the entire training config.
3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return complete `Trainer.Config` instances.
4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds only model-specific concerns (model config, parallelization function, loss function, state dict adapter). All training-level concerns (optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.
5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a nested `Config` hierarchy that mirrors the module tree.
6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:

- `train.py` (~60 lines): a thin entry point that calls `ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with the training-loop logic

7. Experiment Extension via Inheritance
Experiments extend the config system through dataclass subclassing instead of runtime config merging:
Their `config_registry.py` returns the subclassed config type, and `tyro` auto-generates CLI parsing for the extended fields.

UX Comparison
Launching Training
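The original description showed the launch commands here as code blocks; this is a hedged reconstruction from the migration notes in this PR (the TOML path is illustrative):

```shell
# Before (TOML-based):
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh

# After (config_registry-based):
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh

# CLI overrides keep the same dotted syntax:
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh --training.steps 100
```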
CLI Overrides
CLI override syntax is unchanged (`--section.field value`), but `tyro` now provides typed `--help` output generated from the dataclass tree.

Defining a New Model Config
Adding Experiment-Specific Config Fields
Float8 / Quantization Configuration
Limitations and Trade-offs
1. Configs are no longer declarative text files
TOML files were readable by anyone without Python knowledge. The new `config_registry` functions are Python code, which requires understanding imports, function calls, and dataclass construction. For users who only need to tweak hyperparameters, the CLI override syntax (`--training.steps 100`) works the same, but understanding the full config requires reading Python.

2. Steeper learning curve for contributors
Adding a new model now requires understanding the `Configurable` protocol, the nested `Config` dataclass hierarchy, and the `config_registry` pattern. The old approach of copying a TOML file and editing values had a lower barrier to entry.

3. Config serialization is more complex
TOML files were trivially serializable and diffable. The new system supports `to_dict()` + JSON serialization, but configs containing callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully round-tripped. The `model_spec` field is excluded from serialization and suppressed from CLI parsing.
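A small sketch of why callables block a full round-trip (`SpecLike` and its `parallelize_fn` field are illustrative stand-ins for a spec holding a function):

```python
import json
from dataclasses import dataclass, asdict

def parallelize(model):
    return model

@dataclass
class SpecLike:
    name: str = "llama3"
    parallelize_fn: object = parallelize  # a function, not data

d = asdict(SpecLike())
try:
    json.dumps(d)
except TypeError:
    # Functions are not JSON-serializable, so such fields must be
    # dropped from the serialized form (as `model_spec` is).
    d.pop("parallelize_fn")

assert json.loads(json.dumps(d)) == {"name": "llama3"}
```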
4. `tyro` dependency

The CLI parsing now depends on `tyro`, a third-party library. While `tyro` is well-maintained and provides typed CLI generation from dataclasses, it is an additional dependency that must be kept compatible with the dataclass patterns used here.
5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True, slots=True)` on all Config classes. While this provides memory efficiency and prevents accidental attribute assignment, `slots=True` prevents dynamic attribute addition and makes multiple inheritance with other slotted classes more constrained. Each Config subclass in a deep hierarchy must repeat the `@dataclass(kw_only=True, slots=True)` decorator.
6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file path/to/file.toml`. The new system requires two: `--module llama3 --config llama3_8b`. While this separates model identity from training recipe, it adds an extra argument.

Numerics Verification
All model configs were verified for numerical equivalence against the main branch (commit
`10d8a306`):

NOTE

- `eps` in final RMSNorm
- (`'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR

Migration Guide
| Before | After |
| --- | --- |
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | `Trainer.Config` |
| `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |