Add exploration comparing default vs mHC settings #707
Conversation
Pull request overview
This pull request introduces Manifold-Constrained Hyper-Connections (mHC) as a new architectural option for the transformer model. The implementation adds stream-based processing where embeddings are expanded into multiple streams, processed through attention/MLP layers with learned mappings, and then reduced back.
Key changes:
- New mHC implementation with Sinkhorn normalization and learnable hyper-connection mappings
- Integration of mHC logic throughout the forward pass with conditional stream expansion/reduction
- Configuration options for mHC parameters (expansion rate, alpha initialization, Sinkhorn iterations, RMS norm epsilon)
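For orientation, the stream round-trip looks roughly like the sketch below (sizes are hypothetical; the commented-out `pre_map`/`post_map` calls follow the diffs reviewed further down, and the actual wiring lives in `model.py` and `block_variations.py`):

```python
import torch

# Hypothetical sizes: batch b, time t, streams n, channels c.
b, t, n, c = 2, 8, 4, 64
x = torch.randn(b, t, c)

# Expand: replicate the embedding across n streams -> (b, t, n, c),
# mirroring _mhc_expand_stream in the diff.
stream = x.unsqueeze(2).expand(-1, -1, n, -1)

# Per block: pre_map folds the streams into one layer input, the layer
# runs as usual, and post_map writes the output back into every stream.
# x_in, cache = mhc.pre_map(stream)          # (b, t, c)
# out = layer(x_in)                          # attention or MLP
# stream = mhc.post_map(stream, out, cache)  # (b, t, n, c)

# Reduce: average the streams back down, mirroring _mhc_reduce_stream.
x_out = stream.mean(dim=2)
```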
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| variations/mhc_variations.py | Implements ManifoldHyperConnections class with pre_map/post_map methods for stream-based processing |
| variations/block_variations.py | Integrates mHC into parallel_mlp_forward and attn_then_mlp_forward; adds helper methods for stream reduction and post-LN |
| model.py | Adds mHC stream expansion/reduction helper methods and conditional mHC logic throughout forward passes |
| gpt_conf.py | Defines mHC configuration parameters (use_mhc, mhc_expansion_rate, mhc_alpha_init, mhc_sinkhorn_iters, mhc_rmsnorm_eps) |
| explorations.yaml | Adds exploration configuration comparing default settings against mHC-recommended settings |
```python
self.alpha_pre = nn.Parameter(torch.tensor(alpha_init))
self.alpha_post = nn.Parameter(torch.tensor(alpha_init))
self.alpha_res = nn.Parameter(torch.tensor(alpha_init))
```
Copilot AI · Jan 4, 2026
The alpha parameters are created using torch.tensor(alpha_init), which creates a 0-dimensional tensor. When these are used in operations like the one on line 52, gradients may not be properly tracked during distributed training or in certain edge cases. Use torch.tensor(alpha_init, dtype=torch.float32) or preferably torch.full([], alpha_init) to ensure proper gradient tracking and dtype specification.
Suggested change:

```diff
- self.alpha_pre = nn.Parameter(torch.tensor(alpha_init))
- self.alpha_post = nn.Parameter(torch.tensor(alpha_init))
- self.alpha_res = nn.Parameter(torch.tensor(alpha_init))
+ self.alpha_pre = nn.Parameter(torch.full([], alpha_init, dtype=torch.float32))
+ self.alpha_post = nn.Parameter(torch.full([], alpha_init, dtype=torch.float32))
+ self.alpha_res = nn.Parameter(torch.full([], alpha_init, dtype=torch.float32))
```
```python
mat = mat / (mat.sum(dim=-1, keepdim=True) + 1e-6)
mat = mat / (mat.sum(dim=-2, keepdim=True) + 1e-6)
```
Copilot AI · Jan 4, 2026
The hardcoded epsilon value of 1e-6 for numerical stability in the Sinkhorn normalization should be configurable or match the mhc_rmsnorm_eps parameter from the config. Hardcoding this value makes it difficult to adjust for different precision requirements (e.g., float16 vs float32) and creates inconsistency with other epsilon values in the module.
Suggested change:

```diff
- mat = mat / (mat.sum(dim=-1, keepdim=True) + 1e-6)
- mat = mat / (mat.sum(dim=-2, keepdim=True) + 1e-6)
+ mat = mat / (mat.sum(dim=-1, keepdim=True) + self.rmsnorm_eps)
+ mat = mat / (mat.sum(dim=-2, keepdim=True) + self.rmsnorm_eps)
```
```python
def pre_map(self, stream: torch.Tensor):
    h_pre, h_post, h_res = self._compute_mappings(stream)
    b, t, n, c = stream.shape
    stream_2d = stream.reshape(b * t, n, c)
    pre_weights = h_pre.view(b * t, 1, n)
    x_in = torch.bmm(pre_weights, stream_2d).view(b, t, c)
    cache = (h_post, h_res)
    return x_in, cache

def post_map(self, stream: torch.Tensor, out: torch.Tensor, cache):
    h_post, h_res = cache
    b, t, n, c = stream.shape
    stream_2d = stream.reshape(b * t, n, c)
    out_2d = out.reshape(b * t, 1, c)

    post_weights = h_post.view(b * t, n, 1)
    post_contrib = torch.bmm(post_weights, out_2d)
    res_contrib = torch.bmm(h_res, stream_2d)

    updated = res_contrib + post_contrib
    return updated.view(b, t, n, c)
```
Copilot AI · Jan 4, 2026
The pre_map and post_map methods lack docstrings explaining their parameters, return values, and expected tensor shapes. Given the complexity of the mHC mechanism and the specific shape requirements (4D input tensor with shape (batch, time, streams, channels)), comprehensive documentation is essential for maintainability and proper usage.
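For instance, docstrings along these lines (shape names taken from the diff; the wording is a suggestion, not the author's) would make the contract explicit:

```python
def pre_map(self, stream: torch.Tensor):
    """Fold the expanded streams into a single layer input.

    Args:
        stream: Tensor of shape (batch, time, streams, channels).

    Returns:
        x_in: Tensor of shape (batch, time, channels) fed to the
            attention/MLP layer.
        cache: Tuple (h_post, h_res) of mappings reused by post_map.
    """

def post_map(self, stream: torch.Tensor, out: torch.Tensor, cache):
    """Write the layer output back into every stream.

    Args:
        stream: Tensor of shape (batch, time, streams, channels).
        out: Layer output of shape (batch, time, channels).
        cache: The (h_post, h_res) tuple returned by pre_map.

    Returns:
        Updated stream tensor of shape (batch, time, streams, channels).
    """
```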
```python
# Manifold-Constrained Hyper-Connections (mHC)
use_mhc: bool = False
mhc_expansion_rate: int = 4
mhc_alpha_init: float = 0.01
mhc_sinkhorn_iters: int = 20
mhc_rmsnorm_eps: float = 1e-5
```
Copilot AI · Jan 4, 2026
The new mHC functionality lacks test coverage. Given that the repository has comprehensive test scripts for other model variations (e.g., test_ln_parallel_mlp_cpu.sh, test_gqa_variations_cpu.sh), a similar test should be added to verify that mHC works correctly with basic training and sampling. This is especially important since mHC introduces complex tensor reshaping operations and new forward pass logic that could fail in edge cases.
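As a starting point, a shape-level smoke test in this spirit would exercise the reshaping paths (pytest-style sketch; the GPTConfig keyword arguments and n_embd field are assumptions based on the diff, not verified against the repo):

```python
import torch

from gpt_conf import GPTConfig
from variations.mhc_variations import ManifoldHyperConnections


def test_mhc_round_trip_shapes():
    # Assumed config fields; adjust to the repo's actual defaults.
    config = GPTConfig(use_mhc=True, mhc_expansion_rate=4)
    mhc = ManifoldHyperConnections(config)

    b, t, n, c = 2, 8, config.mhc_expansion_rate, config.n_embd
    stream = torch.randn(b, t, n, c)

    x_in, cache = mhc.pre_map(stream)
    assert x_in.shape == (b, t, c)

    out = torch.randn_like(x_in)  # stand-in for an attention/MLP output
    updated = mhc.post_map(stream, out, cache)
    assert updated.shape == (b, t, n, c)
    assert torch.isfinite(updated).all()
```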
```python
def _sinkhorn(self, logits: torch.Tensor) -> torch.Tensor:
    logits = logits - logits.amax(dim=(-2, -1), keepdim=True)
    mat = torch.exp(logits)
    for _ in range(self.sinkhorn_iters):
        mat = mat / (mat.sum(dim=-1, keepdim=True) + 1e-6)
        mat = mat / (mat.sum(dim=-2, keepdim=True) + 1e-6)
    return mat
```
Copilot AI · Jan 4, 2026
The Sinkhorn algorithm performs iterative normalization with 20 iterations by default (configurable via mhc_sinkhorn_iters). This is computed for every forward pass in _compute_mappings, which could become a performance bottleneck during training, especially with large batch sizes or long sequences. Consider caching the Sinkhorn results when the input doesn't change significantly, or evaluate if fewer iterations would suffice for convergence in practice.
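One low-cost option is to keep the iteration cap but exit early once the matrix is close to doubly stochastic; a sketch (the tolerance value is illustrative):

```python
def _sinkhorn(self, logits: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
    logits = logits - logits.amax(dim=(-2, -1), keepdim=True)
    mat = torch.exp(logits)
    for _ in range(self.sinkhorn_iters):
        mat = mat / (mat.sum(dim=-1, keepdim=True) + 1e-6)
        mat = mat / (mat.sum(dim=-2, keepdim=True) + 1e-6)
        # After the column step, row sums drift from 1; once that drift
        # falls below tol, the matrix is effectively doubly stochastic.
        if (mat.sum(dim=-1) - 1.0).abs().max() < tol:
            break
    return mat
```

Note the convergence check forces a host-device sync every iteration on GPU, so whether it is a net win would need profiling.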
```python
def _compute_mappings(self, stream: torch.Tensor):
    b, t, n, c = stream.shape
    stream_flat = stream.reshape(b * t, n * c)
    stream_norm = self._rms_norm(stream_flat)

    pre = self.alpha_pre * self.phi_pre(stream_norm) + self.bias_pre
    post = self.alpha_post * self.phi_post(stream_norm) + self.bias_post
    res = self.alpha_res * self.phi_res(stream_norm) + self.bias_res

    h_pre = torch.sigmoid(pre)
    h_post = 2.0 * torch.sigmoid(post)
    h_res = self._sinkhorn(res.reshape(b * t, n, n))

    dtype = stream.dtype
    return h_pre.to(dtype), h_post.to(dtype), h_res.to(dtype)
```
Copilot AI · Jan 4, 2026
The _compute_mappings method computes all three mappings (h_pre, h_post, h_res) on every call, but pre_map only uses h_pre while caching h_post and h_res for later use in post_map. This means during the pre_map call, h_post and h_res are computed but not immediately used. Consider splitting this into separate methods or document why all three must be computed together (e.g., if they share expensive intermediate computations).
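If the flatten-plus-RMS-norm is the only shared work, one way to realize the split (a sketch under that assumption, not the author's implementation) is to cache the normalized activations and defer the post/res projections until post_map:

```python
def pre_map(self, stream: torch.Tensor):
    b, t, n, c = stream.shape
    # Shared intermediate, computed once per block.
    stream_norm = self._rms_norm(stream.reshape(b * t, n * c))
    pre = self.alpha_pre * self.phi_pre(stream_norm) + self.bias_pre
    h_pre = torch.sigmoid(pre).to(stream.dtype)
    x_in = torch.bmm(h_pre.view(b * t, 1, n),
                     stream.reshape(b * t, n, c)).view(b, t, c)
    # Cache the normalized activations instead of finished mappings.
    return x_in, stream_norm

def post_map(self, stream: torch.Tensor, out: torch.Tensor, cache):
    stream_norm = cache
    b, t, n, c = stream.shape
    post = self.alpha_post * self.phi_post(stream_norm) + self.bias_post
    res = self.alpha_res * self.phi_res(stream_norm) + self.bias_res
    h_post = (2.0 * torch.sigmoid(post)).to(stream.dtype)
    h_res = self._sinkhorn(res.reshape(b * t, n, n)).to(stream.dtype)
    post_contrib = torch.bmm(h_post.view(b * t, n, 1), out.reshape(b * t, 1, c))
    res_contrib = torch.bmm(h_res, stream.reshape(b * t, n, c))
    return (res_contrib + post_contrib).view(b, t, n, c)
```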
```python
if self.use_mhc:
    self.mhc_attn = ManifoldHyperConnections(config)
    self.mhc_mlp = ManifoldHyperConnections(config)
```
Copilot AI · Jan 4, 2026
When use_parallel_mlp is True, the mhc_mlp module is instantiated but never used. In parallel_mlp_forward, only mhc_attn is used to handle the combined attention+MLP output (lines 78 and 103 in block_variations.py). Consider conditionally creating mhc_mlp only when not in parallel mode to avoid wasting memory and parameters on an unused module.
Suggested change:

```diff
- self.mhc_mlp = ManifoldHyperConnections(config)
+ # Only instantiate mhc_mlp when not using parallel MLP mode,
+ # since in parallel_mlp_forward the combined path uses mhc_attn.
+ if not getattr(config, "use_parallel_mlp", False):
+     self.mhc_mlp = ManifoldHyperConnections(config)
```
```python
def _mhc_expand_stream(self, x: torch.Tensor) -> torch.Tensor:
    return x.unsqueeze(2).expand(-1, -1, self.mhc_streams, -1)

@staticmethod
def _mhc_reduce_stream(stream: torch.Tensor) -> torch.Tensor:
    return stream.mean(dim=2)

def _mhc_add_to_stream(self, stream: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    return stream + residual.unsqueeze(2)

def _mhc_apply_to_stream(self, stream: torch.Tensor, fn) -> torch.Tensor:
    reduced = self._mhc_reduce_stream(stream)
    updated = fn(reduced)
    return self._mhc_expand_stream(updated)
else:
```
Copilot AI · Jan 4, 2026
The method definitions inserted here break the if-else structure of the `__init__` method. Lines 227-240 define helper methods that are placed between an if statement (line 224: if config.n_embd_wte:) and its corresponding else clause (line 241), which is a syntax error that prevents the code from running. These method definitions should be moved outside of `__init__` or placed after the if-else block completes.
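Structurally, the fix is to keep the if/else contiguous and define the helpers at class scope, e.g. (layout sketch only; the class name and elided bodies are assumed, not copied from model.py):

```python
import torch
import torch.nn as nn

class GPT(nn.Module):  # assumed class name in model.py
    def __init__(self, config):
        super().__init__()
        ...
        if config.n_embd_wte:
            ...
        else:
            ...

    # mHC stream helpers live at class scope, after __init__ completes.
    def _mhc_expand_stream(self, x: torch.Tensor) -> torch.Tensor:
        return x.unsqueeze(2).expand(-1, -1, self.mhc_streams, -1)

    @staticmethod
    def _mhc_reduce_stream(stream: torch.Tensor) -> torch.Tensor:
        return stream.mean(dim=2)
```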
This pull request introduces support for Manifold-Constrained Hyper-Connections (mHC) in the model architecture. The changes add new configuration options, update the model to handle mHC logic throughout the forward pass, and integrate mHC into the block variations. Additionally, a new exploration YAML is added to compare default and mHC-recommended settings.
Major additions and changes:
1. Manifold-Constrained Hyper-Connections (mHC) Support
- New configuration options in GPTConfig, such as use_mhc, mhc_expansion_rate, and others.
- Updated the model to handle mHC logic throughout the forward pass (model.py). [1] [2] [3] [4] [5] [6]

2. Integration into Block Variations
- Added mHC modules (ManifoldHyperConnections) to block initialization and enforced incompatibility with certain modes. [1] [2]

3. Configuration and Experimentation
- Added a new explorations.yaml file to compare default model settings with mHC-recommended settings, enabling structured experimentation.

4. Miscellaneous
These changes collectively enable the use of mHC as an architectural option, with all necessary hooks for experimentation, configuration, and integration into the model and block logic.