[Module] Add configurable Embedding with init_weights support #2428

fegin wants to merge 4 commits into gh/fegin/81/base
Conversation
**Why the Change**

This is the first step to make all submodules configurable and inherit from Module. We will further extend Module to support per-module sharding configurations.

**Summary**

- Introduce NNEmbedding(nn.Embedding, Module) in torchtitan/models/common/embedding.py: a configurable embedding that uses diamond inheritance to reuse nn.Embedding logic while satisfying the Module protocol (init_weights + Configurable.Config.build()).
- Replace raw nn.Embedding(vocab_size, dim) construction in Decoder with config.tok_embeddings.build(), making the token embedding configurable (e.g., custom init_mean/init_std) without changing any model subclass.
- Harden Module.init_weights with raise NotImplementedError for runtime enforcement, since @abstractmethod alone is insufficient (nn.Module uses type, not ABCMeta).
- Add unit tests for the Module diamond inheritance pattern and NNEmbedding config/build lifecycle.

ghstack-source-id: 8c4acde
Pull-Request: #2428
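The diamond-inheritance pattern described above can be sketched in plain Python. The classes below are illustrative stand-ins (BaseModule for nn.Module, BaseEmbedding for nn.Embedding, Module for torchtitan's Module protocol), not the actual torch or torchtitan API:

```python
from dataclasses import dataclass


class BaseModule:
    """Stand-in for nn.Module, the shared apex of the diamond."""


class BaseEmbedding(BaseModule):
    """Stand-in for nn.Embedding: owns the embedding shape."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim


class Module(BaseModule):
    """Stand-in for the Module protocol: Config.build() + init_weights()."""

    @dataclass
    class Config:
        def build(self):
            raise NotImplementedError

    def init_weights(self, **kwargs) -> None:
        # Runtime enforcement: @abstractmethod alone is not enough because
        # nn.Module is built on a plain `type` metaclass, not ABCMeta.
        raise NotImplementedError


class NNEmbedding(BaseEmbedding, Module):
    """Diamond inheritance: reuse BaseEmbedding logic, satisfy Module."""

    @dataclass
    class Config(Module.Config):
        num_embeddings: int = 0
        embedding_dim: int = 0
        init_mean: float = 0.0
        init_std: float = 0.02

        def build(self) -> "NNEmbedding":
            return NNEmbedding(self.num_embeddings, self.embedding_dim)

    def init_weights(self, **kwargs) -> None:
        # The real module would truncated-normal-initialize self.weight here.
        pass


emb = NNEmbedding.Config(num_embeddings=32, embedding_dim=8).build()
print(emb.num_embeddings, emb.embedding_dim)  # -> 32 8
```

The point of the `raise NotImplementedError` body is visible here: a subclass that forgets to override init_weights fails loudly at call time rather than silently inheriting a no-op.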
also at least need to have RMSNorm and Linear
```diff
@@ -20,5 +20,15 @@ class Module(nn.Module, Configurable):

     @abstractmethod
     def init_weights(self, **kwargs) -> None:
```
I think we should put this in the config instead of in the module, because of what I mentioned in the refactor PR summary #2386:

> Refactor init_weights into Module.Config instead of staying in Module

The benefit is that param init becomes configurable; otherwise we are coupling the module implementation to its weight init. This may require a refactor of the current TransformerBlock and its config. E.g., weight_init_std may need to be put in the config, with post_init determining its value. (See the related complaints/discussions on post_init by chz.)
Also, I would love to have a better name. Previously this was called init_weights to be distinct from reset_parameters, but "weights" is too narrow a term. FSDP1 called it param_init_fn (see https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md); internally it was called param_init.
I prefer init_states, as we also initialize buffers such as freqs_cis. As for moving it into the config, I prefer to do that after I modify all three modules (Embedding, RMSNorm, and Linear) to Module.
torchtitan/models/common/decoder.py
Outdated
```python
    n_layers: int
    vocab_size: int
    norm_eps: float = 1e-5
    tok_embeddings: Embedding.Config = field(default_factory=Embedding.Config)
```
I think we don't need a default here, unless it has to be CLI-configurable.
```diff
-    tok_embeddings: Embedding.Config = field(default_factory=Embedding.Config)
+    tok_embeddings: Embedding.Config
```
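For context on the default being discussed: dataclasses reject mutable instance defaults, which is why the field(default_factory=...) form appears at all. A minimal illustration with stand-in names (EmbeddingConfig, DecoderConfig), not the actual torchtitan classes:

```python
from dataclasses import dataclass, field


@dataclass
class EmbeddingConfig:
    init_std: float = 0.02


@dataclass
class DecoderConfig:
    # A bare `tok_embeddings: EmbeddingConfig = EmbeddingConfig()` would raise
    # ValueError at class-creation time (mutable default not allowed), so a
    # default requires default_factory; dropping the default entirely, as the
    # review suggests, makes the field required instead.
    tok_embeddings: EmbeddingConfig = field(default_factory=EmbeddingConfig)


cfg = DecoderConfig()
print(cfg.tok_embeddings.init_std)  # -> 0.02
```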
torchtitan/models/common/decoder.py
Outdated
```python
    def __post_init__(self):
        # NOTE: If subclasses override __post_init__, they MUST call
        # super().__post_init__() to ensure tok_embeddings is configured.
        self.tok_embeddings.num_embeddings = self.vocab_size
        self.tok_embeddings.embedding_dim = self.dim
```
This could be tricky, e.g. https://github.com/openai/chz/blob/main/docs/21_post_init.md
I'd like to further discuss how fields in a submodule should be set by parent module. cc @shuhuayu
I decided to use kwargs as I found some issues in a later PR.
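A hedged sketch of that kwargs approach (illustrative names, and a dict standing in for the constructed module): the parent passes structural fields at build() time instead of mutating the child config in __post_init__:

```python
from dataclasses import dataclass


@dataclass
class EmbeddingConfig:
    # Only init knobs live on the config; the shape comes from the parent.
    init_mean: float = 0.0
    init_std: float = 0.02

    def build(self, *, num_embeddings: int, embedding_dim: int) -> dict:
        # Stand-in for constructing the real embedding module.
        return {
            "num_embeddings": num_embeddings,
            "embedding_dim": embedding_dim,
            "init_std": self.init_std,
        }


# A Decoder would then call something like:
#   self.tok_embeddings = config.tok_embeddings.build(
#       num_embeddings=config.vocab_size, embedding_dim=config.dim)
tok = EmbeddingConfig().build(num_embeddings=128, embedding_dim=16)
print(tok["num_embeddings"], tok["init_std"])  # -> 128 0.02
```

This sidesteps the __post_init__ ordering hazards (subclasses forgetting super().__post_init__(), the chz complaints linked above) because the child config is never mutated.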
**BC Breaking**

This is a breaking change for any downstream code that passes `norm_eps` directly to model configs.

**Why the Change**

Same as #2428. This PR changes RMSNorm to inherit from Module.

**Summary**

Introduces RMSNorm, a configurable wrapper around nn.RMSNorm that inherits from Module. This enables RMSNorm to participate in the Config.build() pattern and provides a standardized init_weights() method. All model families (Llama3, Llama4, Qwen3, DeepSeekV3, GptOss, Flux) are updated to:

- Use RMSNorm.Config fields in their dataclass configs instead of a bare norm_eps: float
- Build norms via config.build(normalized_shape=dim) instead of nn.RMSNorm(dim, eps=...)
- Call norm.init_weights() instead of norm.reset_parameters()

The norm_eps field is removed from TransformerBlock.Config and Decoder.Config since the eps value is now encapsulated in RMSNorm.Config.

ghstack-source-id: bf4d8ec
Pull-Request: #2434
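For reference, the normalization that nn.RMSNorm performs (and that the new wrapper configures via eps) can be written in a few lines of pure Python; this is illustrative math, not the wrapper the PR adds:

```python
import math


def rms_norm(xs: list[float], eps: float = 1e-5) -> list[float]:
    # Divide each element by the root-mean-square of the inputs,
    # with eps added inside the root for numerical stability.
    rms = math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
    return [x / rms for x in xs]


out = rms_norm([3.0, 4.0])
print(round(out[0], 3), round(out[1], 3))  # -> 0.849 1.131
```

Encapsulating eps in RMSNorm.Config (rather than a loose norm_eps float on the parent config) keeps this one stability knob next to the layer it actually affects.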
Stack from ghstack (oldest at bottom):

**Why the Change**

This is the first step to make all submodules configurable and inherit from Module. We will further extend Module to support per-module sharding configurations.

**Summary**

- Introduce Embedding(nn.Embedding, Module) in torchtitan/models/common/embedding.py: a configurable embedding that uses diamond inheritance to reuse nn.Embedding logic while satisfying the Module protocol (init_weights + Configurable.Config.build()).
- Replace raw nn.Embedding(vocab_size, dim) construction in Decoder with config.tok_embeddings.build(), making the token embedding configurable (e.g., custom init_mean/init_std) without changing any model subclass.
- Harden Module.init_weights with raise NotImplementedError for runtime enforcement, since @abstractmethod alone is insufficient (nn.Module uses type, not ABCMeta).
- Add unit tests for the Module diamond inheritance pattern and Embedding config/build lifecycle.