
[Module] Add configurable RMSNorm that inherits from Module#2434

Draft
fegin wants to merge 2 commits into gh/fegin/82/base from
gh/fegin/82/head

Conversation


@fegin fegin commented Feb 24, 2026

Stack from ghstack (oldest at bottom):

BC Breaking
This is a breaking change for any downstream code that passes norm_eps directly to model configs.

Why the Change
Same as #2428. This PR changes RMSNorm to inherit from Module.

Summary
Introduces RMSNorm, a configurable wrapper around nn.RMSNorm that inherits from Module. This enables RMSNorm to participate in the Config.build() pattern and provides a standardized init_weights() method.

All model families (Llama3, Llama4, Qwen3, DeepSeekV3, GptOss, Flux) are updated to:

  • Use RMSNorm.Config fields in their dataclass configs instead of bare norm_eps: float
  • Build norms via config.build(normalized_shape=dim) instead of nn.RMSNorm(dim, eps=...)
  • Call norm.init_weights() instead of norm.reset_parameters()

The norm_eps field is removed from TransformerBlock.Config and Decoder.Config since the eps value is now encapsulated in RMSNorm.Config.

[ghstack-poisoned]
@meta-cla meta-cla bot added the CLA Signed label Feb 24, 2026
fegin added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: bf4d8ec
Pull-Request: #2434
@fegin fegin marked this pull request as draft February 24, 2026 21:28
# Shared config object: safe because RMSNorm.Config is an immutable-style
# dataclass (slots=True, no mutable fields). If mutable fields are ever
# added, each model variant should get its own instance instead.
norm_config = RMSNorm.Config(eps=1e-6)
question 1: why is only Qwen3 using this shared norm config, and not the other models?
question 2: if it's shared by multiple nodes in the config tree, including the root node, why don't we let the root node pass the config down to its children, just like hidden dimension / vocab_size / etc.?

Comment on lines 53 to 54
attention_norm: RMSNorm.Config = field(default_factory=RMSNorm.Config)
ffn_norm: RMSNorm.Config = field(default_factory=RMSNorm.Config)
don't need default_factory here?
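For context on the default_factory question, a minimal sketch (illustrative names, not the PR's actual classes) of the dataclass rule involved: an unfrozen dataclass instance is unhashable, and on Python 3.11+ unhashable objects are rejected as class-level field defaults, so default_factory is the standard way to give each config its own instance.

```python
from dataclasses import dataclass, field

@dataclass
class NormConfig:
    eps: float = 1e-6

@dataclass
class BlockConfig:
    # `attention_norm: NormConfig = NormConfig()` would raise ValueError on
    # Python 3.11+ (unhashable default); default_factory sidesteps that and
    # gives every BlockConfig its own NormConfig instance.
    attention_norm: NormConfig = field(default_factory=NormConfig)
    ffn_norm: NormConfig = field(default_factory=NormConfig)

a, b = BlockConfig(), BlockConfig()
```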

[ghstack-poisoned]
fegin added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: 207cf6a
Pull-Request: #2434

Labels

ciflow/8gpu, CLA Signed (managed by the Meta Open Source bot)

2 participants