Add norm ft experiments #733
base: master
Conversation
Pull request overview
This pull request adds a new normalization layer RMSHyperSphereNorm that enables fine-tuning from RMSNorm to HyperSphereNorm over a configurable iteration window. The implementation allows models to start with RMSNorm behavior and gradually transition to HyperSphereNorm by linearly interpolating between the two normalization methods.
Note: The PR description mentions "PyraNet-Verilog dataset preprocessing pipeline" with Tree-sitter highlighting, but the actual code changes implement normalization fine-tuning experiments. This is a significant mismatch.
Changes:
- Added `RMSHyperSphereNorm` class that blends RMSNorm and HyperSphereNorm based on the training iteration
- Added infrastructure to propagate iteration numbers to normalization layers throughout the model
- Added configuration parameters for controlling the fine-tuning schedule
- Added experimental YAML configuration file for testing the new normalization approach
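Distilled, the "linear interpolation" described in the overview amounts to a blend weight that decays linearly over the fine-tuning window. The sketch below is a restatement of the `_alpha` helper reviewed later in this page, not code from the PR itself:

```python
def blend_weight(iter_num: int, start_iter: int, finetune_iters: int) -> float:
    """1.0 (pure RMSNorm) before start_iter, decaying linearly to
    0.0 (pure HyperSphereNorm) over finetune_iters iterations."""
    if iter_num < start_iter:
        return 1.0
    if finetune_iters <= 0:
        return 0.0
    return max(0.0, 1.0 - (iter_num - start_iter) / float(finetune_iters))

# output = alpha * rms_out + (1 - alpha) * hypersphere_out
```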
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| variations/norm_variations.py | Implements RMSHyperSphereNorm class with iteration-based blending between RMSNorm and HyperSphereNorm, and registers it in the norm dictionary |
| variations/block_variations.py | Adds _set_norm_iter method to propagate iteration numbers to normalization layers within transformer blocks |
| train_args.py | Adds command-line arguments for fine-tuning schedule configuration and registers the new normalization variant |
| model.py | Adds _set_norm_iter static method and calls it at appropriate points to propagate iteration numbers to norm layers |
| gpt_conf.py | Adds configuration fields for fine-tuning start iteration and duration |
| explorations/rmsnorm_hyperspherenorm_finetune.yaml | Provides experimental configuration for testing the new normalization approach with various fine-tuning schedules |
```python
radius = self.const_radius_factor * self.radius_init_factor
hypersphere_norm = x.norm(2, dim=-1, keepdim=True)
hypersphere_out = x / hypersphere_norm * radius
```
Copilot AI · Feb 1, 2026
In the RMSHyperSphereNorm implementation, the hypersphere_out computation does not include the gain parameter, unlike the HyperSphereNorm class (line 100). The hypersphere_out on line 141 should likely multiply by self.gain to be consistent with the standalone HyperSphereNorm implementation. Without this, when alpha=0 (fully hypersphere mode), the gain parameter would have no effect, which is inconsistent behavior.
Suggested change:
```diff
- hypersphere_out = x / hypersphere_norm * radius
+ hypersphere_out = x / hypersphere_norm * radius * self.gain
```
```python
gain_init = torch.ones(ndim)
self.gain = nn.Parameter(gain_init)

radius_init = gain_init.mean().item() * math.sqrt(ndim)
```
Copilot AI · Feb 1, 2026
The radius initialization uses gain_init.mean().item() which is always 1.0 (since gain_init is torch.ones(ndim)). This makes the computation on line 111 equivalent to just math.sqrt(ndim). If this was intentional, the code could be simplified. If it was meant to allow for different initialization, it should use a configurable value similar to HyperSphereNorm's hsnorm_radius parameter.
Suggested change:
```diff
- gain_init = torch.ones(ndim)
- self.gain = nn.Parameter(gain_init)
- radius_init = gain_init.mean().item() * math.sqrt(ndim)
+ self.gain = nn.Parameter(torch.ones(ndim))
+ radius_init = math.sqrt(ndim)
```
```python
radius_init = gain_init.mean().item() * math.sqrt(ndim)
self.const_radius_factor = config.hsnorm_scale
radius_init = radius_init / self.const_radius_factor
self.radius_init_factor = nn.Parameter(torch.tensor([radius_init]))
```
Copilot AI · Feb 1, 2026
The RMSHyperSphereNorm class always makes radius_init_factor a learnable parameter (line 114), unlike HyperSphereNorm which respects the config.hsnorm_radius_learning flag (lines 89-95 in the same file). This is inconsistent behavior. The class should either respect the same config flag or have its own dedicated configuration parameter for controlling whether the radius is learnable.
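One possible shape for a constructor that honors the same flag, sketched below; the `getattr` default and the `register_buffer` fallback are illustrative assumptions, not code from this PR:

```python
# Sketch only: gate radius learnability on the flag HyperSphereNorm already
# respects (config.hsnorm_radius_learning) instead of always learning it.
radius_init = math.sqrt(ndim) / self.const_radius_factor
if getattr(config, "hsnorm_radius_learning", False):
    self.radius_init_factor = nn.Parameter(torch.tensor([radius_init]))
else:
    # Fixed radius: kept in the state dict but excluded from gradient updates.
    self.register_buffer("radius_init_factor", torch.tensor([radius_init]))
```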
```python
def _set_norm_iter(self, iter_num: int):
    for module in self.modules():
        if hasattr(module, "set_iter_num"):
            module.set_iter_num(iter_num)
```
Copilot AI · Feb 1, 2026
The _set_norm_iter method calls self.modules() which iterates over all modules in the block, including nested submodules. This means it will set iter_num on norm layers within attention and MLP modules as well. However, when gradient checkpointing is enabled (line 477), the checkpointed function receives iter_num as a parameter, and any norm layers inside block_forward that need iter_num should get it through that parameter, not through set_iter_num. This could lead to the iter_num being set twice, and more importantly, during the backward pass of gradient checkpointing, the iter_num state will not be properly restored since set_iter_num mutates module state before the checkpoint boundary. Consider moving _set_norm_iter inside the block_forward function or passing iter_num through the forward computation instead of storing it as state.
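A sketch of the stateless alternative mentioned above, assuming the norm's forward is extended to accept the iteration directly; `attn`, `mlp`, `norm_1`, and `norm_2` are placeholder names, not identifiers from this repository:

```python
# Sketch only: thread iter_num through the forward call rather than caching it
# on the module, so recomputation under torch.utils.checkpoint sees the same
# value the original forward pass used.
def block_forward(self, x, iter_num=None):
    x = x + self.attn(self.norm_1(x, iter_num=iter_num))
    x = x + self.mlp(self.norm_2(x, iter_num=iter_num))
    return x
```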
```yaml
- named_group: "rmsnorm_wte"
  named_group_settings:
    norm_variant_wte: ["rmsnorm"]
    norm_wte_radius_learning: [true]
```
Copilot AI · Feb 1, 2026
The named groups "rmsnorm_wte" (line 29) and "hsnorm_wte" (line 35) both set norm_wte_radius_learning to true, but this parameter only has an effect when using HyperSphereNorm or RMSHyperSphereNorm. For the plain RMSNorm case (line 31), this parameter has no effect since RMSNorm doesn't have a radius parameter. This could be confusing - consider removing norm_wte_radius_learning from the "rmsnorm_wte" named group.
Suggested change:
```diff
-     norm_wte_radius_learning: [true]
```
```python
class RMSHyperSphereNorm(nn.Module):
    """Blend from RMSNorm into HyperSphereNorm over a finetuning window."""

    def __init__(self, config):
        super().__init__()
        ndim = config.n_embd
        gain_init = torch.ones(ndim)
        self.gain = nn.Parameter(gain_init)

        radius_init = gain_init.mean().item() * math.sqrt(ndim)
        self.const_radius_factor = config.hsnorm_scale
        radius_init = radius_init / self.const_radius_factor
        self.radius_init_factor = nn.Parameter(torch.tensor([radius_init]))

        self.finetune_start_iter = config.rms_hsnorm_finetune_start_iter
        self.finetune_iters = config.rms_hsnorm_finetune_iters
        self.iter_num = None

    def set_iter_num(self, iter_num):
        if isinstance(iter_num, torch.Tensor):
            iter_num = int(iter_num.item())
        self.iter_num = iter_num

    def _alpha(self):
        if self.iter_num is None:
            return 1.0
        if self.iter_num < self.finetune_start_iter:
            return 1.0
        if self.finetune_iters <= 0:
            return 0.0
        progress = (self.iter_num - self.finetune_start_iter) / float(self.finetune_iters)
        return max(0.0, 1.0 - progress)

    def forward(self, x):
        rms = x.norm(2, dim=-1, keepdim=True) / math.sqrt(x.size(-1))
        rms_out = x / rms * self.gain

        radius = self.const_radius_factor * self.radius_init_factor
        hypersphere_norm = x.norm(2, dim=-1, keepdim=True)
        hypersphere_out = x / hypersphere_norm * radius

        alpha = self._alpha()
        return alpha * rms_out + (1 - alpha) * hypersphere_out
```
Copilot AI · Feb 1, 2026
The PR description mentions "PyraNet-Verilog dataset preprocessing pipeline" with Tree-sitter highlighting, but the actual code changes implement a normalization fine-tuning feature (RMSHyperSphereNorm). This is a significant discrepancy - either the PR description is incorrect or the wrong changes were included in this PR.
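For reference, a minimal usage sketch of the blending schedule implemented by the class shown in the diff above; the `SimpleNamespace` config is a stand-in for `GPTConfig`, and all field values here are illustrative assumptions:

```python
import torch
from types import SimpleNamespace

# RMSHyperSphereNorm as defined in the diff above (its module imports math,
# torch, and torch.nn as nn).
cfg = SimpleNamespace(
    n_embd=8,
    hsnorm_scale=1.0,                    # assumed value for illustration
    rms_hsnorm_finetune_start_iter=100,  # assumed value for illustration
    rms_hsnorm_finetune_iters=200,       # assumed value for illustration
)
norm = RMSHyperSphereNorm(cfg)
x = torch.randn(2, 4, cfg.n_embd)

for it in (0, 100, 200, 300):
    norm.set_iter_num(it)
    y = norm(x)
    # alpha is 1.0, 1.0, 0.5, 0.0 at these iterations, so the output moves
    # from pure RMSNorm to pure HyperSphereNorm over the finetuning window.
    print(it, norm._alpha(), y.shape)
```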
This pull request introduces support for a new normalization variant, `RMSHyperSphereNorm`, which allows models to smoothly transition from RMSNorm to HyperSphereNorm during finetuning. The changes include the implementation of this new norm, configuration options for controlling the finetuning schedule, and integration with the model's forward pass to set the current iteration for the norm modules. Additionally, a new YAML configuration is added to facilitate experiments with these features.

**New normalization variant and finetuning support:**

- Added the `RMSHyperSphereNorm` class in `variations/norm_variations.py`, enabling a blend from RMSNorm to HyperSphereNorm over a configurable finetuning window. The class supports dynamic control via the training iteration number.

**Configuration and CLI integration:**

- Added new configuration options (`rms_hsnorm_finetune_start_iter`, `rms_hsnorm_finetune_iters`) to `GPTConfig` in `gpt_conf.py` and exposed them as command-line arguments in `train_args.py`. This enables users to specify when and how quickly the transition from RMSNorm to HyperSphereNorm should occur during training.
- Added `"rmsnorm_hyperspherenorm"` to the list of selectable norm variants in the CLI.

**Model integration and iteration handling:**

- Updated `model.py` to call a new `_set_norm_iter` method before applying normalization layers, ensuring that the current iteration is correctly propagated to norm modules that support dynamic behavior.
- Added `_set_norm_iter` logic to blocks in `block_variations.py` to propagate the iteration number to all submodules that require it.

**Experiment configuration:**

- Added `explorations/rmsnorm_hyperspherenorm_finetune.yaml`, which provides example setups for using the new norm variant and controlling its finetuning schedule.
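For orientation, a sketch of what the two new schedule fields could look like on `GPTConfig`; the dataclass shape and the default values are assumptions for illustration, not taken from this PR:

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    # ... existing fields elided ...
    # Iteration at which the RMSNorm -> HyperSphereNorm blend begins
    # (assumed default; the PR's actual default is not shown here).
    rms_hsnorm_finetune_start_iter: int = 0
    # Number of iterations over which the blend weight decays from 1.0 to 0.0
    # (assumed default; the PR's actual default is not shown here).
    rms_hsnorm_finetune_iters: int = 0
```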