[Module] Add configurable RMSNorm that inherits from Module #2434
Draft
fegin wants to merge 2 commits into gh/fegin/82/base from
Conversation
fegin added a commit that referenced this pull request on Feb 24, 2026

**BC Breaking**
This is a breaking change for any downstream code that passes `norm_eps` directly to model configs.

**Why the Change**
Same as #2428. This PR changes RMSNorm to inherit from Module.

**Summary**
Introduces RMSNorm, a configurable wrapper around nn.RMSNorm that inherits from Module. This enables RMSNorm to participate in the Config.build() pattern and provides a standardized init_weights() method.

All model families (Llama3, Llama4, Qwen3, DeepSeekV3, GptOss, Flux) are updated to:
- Use RMSNorm.Config fields in their dataclass configs instead of a bare norm_eps: float
- Build norms via config.build(normalized_shape=dim) instead of nn.RMSNorm(dim, eps=...)
- Call norm.init_weights() instead of norm.reset_parameters()

The norm_eps field is removed from TransformerBlock.Config and Decoder.Config since the eps value is now encapsulated in RMSNorm.Config.

ghstack-source-id: bf4d8ec
Pull-Request: #2434
tianyu-l
reviewed
Feb 24, 2026
# Shared config object: safe because RMSNorm.Config is an immutable-style
# dataclass (slots=True, no mutable fields). If mutable fields are ever
# added, each model variant should get its own instance instead.
norm_config = RMSNorm.Config(eps=1e-6)
Contributor
question 1: why is only Qwen3 using this shared norm config, and not the other models?
question 2: if it's shared by multiple nodes in the config tree, including the root node, why don't we let the root node pass the config to its children, just like hidden dimension / vocab_size / etc.?
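For context, the alternative raised in question 2 can be sketched as follows. This is a hypothetical illustration, not the PR's code: the class names mirror the PR's config tree, and `__post_init__` stands in for whatever propagation hook the real root config would use.

```python
from dataclasses import dataclass, field

@dataclass
class NormConfig:  # hypothetical stand-in for RMSNorm.Config
    eps: float = 1e-6

@dataclass
class TransformerBlockConfig:
    attention_norm: NormConfig = field(default_factory=NormConfig)
    ffn_norm: NormConfig = field(default_factory=NormConfig)

@dataclass
class DecoderConfig:
    norm: NormConfig = field(default_factory=NormConfig)
    block: TransformerBlockConfig = field(default_factory=TransformerBlockConfig)

    def __post_init__(self):
        # Root pushes its single norm config down the tree, the same way
        # hidden dimension / vocab_size are usually propagated.
        self.block.attention_norm = self.norm
        self.block.ffn_norm = self.norm

cfg = DecoderConfig(norm=NormConfig(eps=1e-5))
# Every norm in the tree now sees eps=1e-5 from the root.
```

With this shape, overriding eps in one place at the root is enough, at the cost of the root knowing which children carry norm configs.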
torchtitan/models/common/decoder.py
Outdated
Comment on lines 53 to 54:
attention_norm: RMSNorm.Config = field(default_factory=RMSNorm.Config)
ffn_norm: RMSNorm.Config = field(default_factory=RMSNorm.Config)
Contributor
don't need default_factory here?
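The question touches a Python dataclasses rule worth spelling out: a class-level default that is unhashable (which any plain, non-frozen dataclass instance is) gets rejected as a "mutable default", so default_factory is required. A small sketch, using a hypothetical NormConfig in place of RMSNorm.Config:

```python
from dataclasses import dataclass, field

@dataclass
class NormConfig:  # hypothetical stand-in for RMSNorm.Config
    eps: float = 1e-6

# A plain dataclass instance defines __eq__ and is therefore unhashable,
# so dataclasses refuses it as a shared class-level default:
try:
    @dataclass
    class Broken:
        norm: NormConfig = NormConfig()
except ValueError as err:
    print(err)  # mentions "mutable default" and suggests default_factory

# The accepted spelling, matching the lines under review: each instance
# of the containing config gets its own fresh NormConfig.
@dataclass
class BlockConfig:
    attention_norm: NormConfig = field(default_factory=NormConfig)
    ffn_norm: NormConfig = field(default_factory=NormConfig)

b = BlockConfig()
```

If Config were declared frozen=True it would be hashable and a plain default would be accepted, but then every containing config would share one instance; default_factory sidesteps both issues.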
fegin added a commit that referenced this pull request on Feb 24, 2026

(Same commit message as the commit above.)

ghstack-source-id: 207cf6a
Pull-Request: #2434
Stack from ghstack (oldest at bottom):
BC Breaking

This is a breaking change for any downstream code that passes norm_eps directly to model configs.

Why the Change

Same as #2428. This PR changes RMSNorm to inherit from Module.

Summary

Introduces RMSNorm, a configurable wrapper around nn.RMSNorm that inherits from Module. This enables RMSNorm to participate in the Config.build() pattern and provides a standardized init_weights() method.

All model families (Llama3, Llama4, Qwen3, DeepSeekV3, GptOss, Flux) are updated to:
- Use RMSNorm.Config fields in their dataclass configs instead of a bare norm_eps: float
- Build norms via config.build(normalized_shape=dim) instead of nn.RMSNorm(dim, eps=...)
- Call norm.init_weights() instead of norm.reset_parameters()

The norm_eps field is removed from TransformerBlock.Config and Decoder.Config since the eps value is now encapsulated in RMSNorm.Config.
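The Config.build() pattern described above can be sketched roughly as below. This is a minimal illustration, not torchtitan's code: the real RMSNorm inherits from the project's Module base class and wraps torch.nn.RMSNorm, whereas here the normalization math is done in pure Python so the example stays self-contained.

```python
import math
from dataclasses import dataclass

class RMSNorm:
    @dataclass
    class Config:
        eps: float = 1e-6

        def build(self, normalized_shape: int) -> "RMSNorm":
            # Shape is supplied at build time; eps comes from the config,
            # replacing the old bare norm_eps field on the model config.
            return RMSNorm(normalized_shape, eps=self.eps)

    def __init__(self, dim: int, eps: float = 1e-6):
        self.dim = dim
        self.eps = eps
        self.weight = [1.0] * dim  # learnable scale

    def init_weights(self) -> None:
        # Standardized replacement for reset_parameters().
        self.weight = [1.0] * self.dim

    def __call__(self, x: list[float]) -> list[float]:
        # y_i = w_i * x_i / sqrt(mean(x^2) + eps)
        rms = math.sqrt(sum(v * v for v in x) / self.dim + self.eps)
        return [w * v / rms for w, v in zip(self.weight, x)]

# Usage mirroring the PR: config.build(normalized_shape=dim), then init_weights().
norm = RMSNorm.Config(eps=1e-5).build(normalized_shape=4)
norm.init_weights()
out = norm([1.0, 2.0, 3.0, 4.0])
```

The point of the pattern is that model configs hold an RMSNorm.Config field and never touch nn.RMSNorm directly, so the norm implementation can be swapped or extended behind build().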