[Module] Add configurable Embedding with init_weights support#2428

Draft
fegin wants to merge 4 commits into gh/fegin/81/base from gh/fegin/81/head

Conversation


@fegin fegin commented Feb 24, 2026

Stack from ghstack (oldest at bottom):

Why the Change
This is the first step toward making all submodules configurable and having them inherit from Module. We will further extend Module to support per-module sharding configurations.

Summary

  • Introduce Embedding(nn.Embedding, Module) in torchtitan/models/common/embedding.py — a configurable embedding that uses diamond inheritance to reuse
    nn.Embedding logic while satisfying the Module protocol (init_weights + Configurable.Config.build()).

  • Replace raw nn.Embedding(vocab_size, dim) construction in Decoder with config.tok_embeddings.build(), making the token embedding configurable (e.g., custom
    init_mean/init_std) without changing any model subclass.

  • Harden Module.init_weights with raise NotImplementedError for runtime enforcement, since @abstractmethod alone is insufficient (nn.Module's metaclass is
    type, not ABCMeta).

  • Add unit tests for the Module diamond inheritance pattern and Embedding config/build lifecycle.
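
The pattern described in these bullets can be sketched roughly as follows. This is a hypothetical illustration inferred from the summary, not the actual torchtitan code: the exact shapes of `Configurable`, `Module`, and `Embedding.Config` here are assumptions.

```python
# Hypothetical sketch of the pattern described above; the real torchtitan
# classes may differ. Assumed names: Configurable, Module, Embedding.Config.
from abc import abstractmethod
from dataclasses import dataclass

import torch.nn as nn


class Configurable:
    """Mixin: each module class carries a nested Config that can build it."""

    @dataclass
    class Config:
        def build(self):
            raise NotImplementedError


class Module(nn.Module, Configurable):
    # @abstractmethod alone does not block instantiation here, because
    # nn.Module's metaclass is plain `type`, not ABCMeta -- hence the
    # explicit raise as a runtime backstop.
    @abstractmethod
    def init_weights(self, **kwargs) -> None:
        raise NotImplementedError


class Embedding(nn.Embedding, Module):
    """Diamond inheritance: reuse nn.Embedding, satisfy the Module protocol."""

    @dataclass
    class Config(Configurable.Config):
        num_embeddings: int = 0
        embedding_dim: int = 0
        init_mean: float = 0.0
        init_std: float = 0.02

        def build(self) -> "Embedding":
            return Embedding(self)

    def __init__(self, config: "Embedding.Config"):
        super().__init__(config.num_embeddings, config.embedding_dim)
        self.config = config

    def init_weights(self, **kwargs) -> None:
        nn.init.trunc_normal_(
            self.weight, mean=self.config.init_mean, std=self.config.init_std
        )


emb = Embedding.Config(num_embeddings=1000, embedding_dim=64).build()
emb.init_weights()
```

A Module subclass that forgets to override init_weights still constructs fine (no ABCMeta enforcement), but calling init_weights() raises NotImplementedError at runtime, which is the hardening described in the third bullet.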

[ghstack-poisoned]
fegin added a commit that referenced this pull request Feb 24, 2026
**Why the Change**
This is the first step toward making all submodules configurable and having them inherit from Module. We will further extend Module to support per-module sharding configurations.

**Summary**

- Introduce NNEmbedding(nn.Embedding, Module) in torchtitan/models/common/embedding.py — a configurable embedding that uses diamond inheritance to reuse
nn.Embedding logic while satisfying the Module protocol (init_weights + Configurable.Config.build()).

- Replace raw nn.Embedding(vocab_size, dim) construction in Decoder with config.tok_embeddings.build(), making the token embedding configurable (e.g., custom
init_mean/init_std) without changing any model subclass.

- Harden Module.init_weights with raise NotImplementedError for runtime enforcement, since abstractmethod alone is insufficient (nn.Module uses type, not
ABCMeta).

- Add unit tests for the Module diamond inheritance pattern and NNEmbedding config/build lifecycle.


ghstack-source-id: 8c4acde
Pull-Request: #2428
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 24, 2026
@fegin fegin marked this pull request as draft February 24, 2026 06:13
Contributor

also at least need to have RMSNorm and Linear

```diff
@@ -20,5 +20,15 @@ class Module(nn.Module, Configurable):

     @abstractmethod
     def init_weights(self, **kwargs) -> None:
```
Contributor

I think we should put this in config, instead of in module, because what I mentioned in the refactor PR summary #2386:

Refactor init_weights into Module.Config instead of staying in Module
The benefit is that param init can be configurable; o/w we are coupling module implementation and its weight init.
This may require refactor of current TransformerBlock and its config. E.g. weight_init_std may need to be put in config, with post_init determining its value. (See related complaints / discussions on post_init by chz)

Contributor

Also, would love to have a better name.
Previously this was called init_weights to be different from reset_parameters, but weights is too narrow a term. FSDP1 called it param_init_fn (see https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md); internally it was called param_init.

Contributor (Author)

I prefer init_states as we also initialize buffers such as freqs_cis. As for moving into config, I prefer to do that after I modify all the three modules (Embedding, RMSNorm and Linear) to Module.

[ghstack-poisoned]
fegin added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: 4756859
Pull-Request: #2428
@fegin fegin changed the title [Module] Add configurable NNEmbedding with init_weights support [Module] Add configurable Embedding with init_weights support Feb 24, 2026
```python
n_layers: int
vocab_size: int
norm_eps: float = 1e-5
tok_embeddings: Embedding.Config = field(default_factory=Embedding.Config)
```
Contributor

I think we don't need a default here, unless it has to be CLI-configurable.

Suggested change:

```diff
-tok_embeddings: Embedding.Config = field(default_factory=Embedding.Config)
+tok_embeddings: Embedding.Config
```

Comment on lines 77 to 81:

```python
def __post_init__(self):
    # NOTE: If subclasses override __post_init__, they MUST call
    # super().__post_init__() to ensure tok_embeddings is configured.
    self.tok_embeddings.num_embeddings = self.vocab_size
    self.tok_embeddings.embedding_dim = self.dim
```
Contributor

This could be tricky, e.g. https://github.com/openai/chz/blob/main/docs/21_post_init.md

I'd like to further discuss how fields in a submodule should be set by the parent module. cc @shuhuayu

Contributor (Author)

I decided to use kwargs as I found some issues in a later PR.
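
For illustration, the kwargs approach mentioned here might look like the sketch below: parent-derived values are passed at build() time instead of being written into the child config in __post_init__. All names (EmbeddingConfig, DecoderConfig, build_tok_embeddings) are hypothetical, not code from this PR.

```python
# Hypothetical sketch of the "kwargs at build time" alternative; names are
# illustrative, not from the actual PR. For simplicity, build() returns the
# resolved config rather than constructing an nn.Module.
from dataclasses import dataclass, field, replace


@dataclass
class EmbeddingConfig:
    num_embeddings: int = 0
    embedding_dim: int = 0
    init_std: float = 0.02

    def build(self, **overrides) -> "EmbeddingConfig":
        # Parent-derived values arrive as kwargs; `replace` produces a new
        # config, so the (possibly shared) original is never mutated.
        return replace(self, **overrides)


@dataclass
class DecoderConfig:
    dim: int = 512
    vocab_size: int = 32000
    tok_embeddings: EmbeddingConfig = field(default_factory=EmbeddingConfig)

    def build_tok_embeddings(self) -> EmbeddingConfig:
        return self.tok_embeddings.build(
            num_embeddings=self.vocab_size, embedding_dim=self.dim
        )


dec = DecoderConfig()
resolved = dec.build_tok_embeddings()
```

Unlike the __post_init__ approach, no config object is mutated after construction, which sidesteps the ordering pitfalls discussed in the chz post_init doc linked above.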

[ghstack-poisoned]
fegin added a commit that referenced this pull request Feb 24, 2026
**BC Breaking**
This is a breaking change for any downstream code that passes `norm_eps` directly to model configs.

**Why the Change**
Same as #2428. This PR changes RMSNorm to inherit from Module.

**Summary**
Introduces RMSNorm, a configurable wrapper around nn.RMSNorm that inherits from Module. This enables RMSNorm to participate in the Config.build() pattern and provides a standardized init_weights() method.

All model families (Llama3, Llama4, Qwen3, DeepSeekV3, GptOss, Flux) are updated to:
- Use RMSNorm.Config fields in their dataclass configs instead of bare norm_eps: float
- Build norms via config.build(normalized_shape=dim) instead of nn.RMSNorm(dim, eps=...)
- Call norm.init_weights() instead of norm.reset_parameters()

The norm_eps field is removed from TransformerBlock.Config and Decoder.Config since the eps value is now encapsulated in RMSNorm.Config.


ghstack-source-id: bf4d8ec
Pull-Request: #2434
[ghstack-poisoned]
fegin added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: 207cf6a
Pull-Request: #2434
Labels: ciflow/8gpu, CLA Signed