
Conversation

klei22 (Collaborator) commented Jan 4, 2026

This pull request introduces support for Manifold-Constrained Hyper-Connections (mHC) in the model architecture. The changes add new configuration options, update the model to handle mHC logic throughout the forward pass, and integrate mHC into the block variations. Additionally, a new exploration YAML is added to compare default and mHC-recommended settings.

Major additions and changes:

1. Manifold-Constrained Hyper-Connections (mHC) Support

  • Added mHC-related configuration options to GPTConfig: use_mhc, mhc_expansion_rate, mhc_alpha_init, mhc_sinkhorn_iters, and mhc_rmsnorm_eps.
  • Implemented mHC logic in the model, including stream expansion/reduction and integration into forward passes (model.py).
  • Added mHC-specific methods for handling streams and residual connections in both the model and block variations.

2. Integration into Block Variations

  • Integrated mHC logic into both parallel and sequential block forward functions, including pre- and post-mapping, skip connections, and layer normalization.
  • Added mHC modules (ManifoldHyperConnections) to block initialization and enforced incompatibility with certain modes.

3. Configuration and Experimentation

  • Added a new explorations.yaml file to compare default model settings with mHC-recommended settings, enabling structured experimentation.

4. Miscellaneous

  • Minor code cleanups and fixes to ensure correct initialization of mHC-related parameters in model components.

These changes collectively enable the use of mHC as an architectural option, with all necessary hooks for experimentation, configuration, and integration into the model and block logic.
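For orientation, a minimal usage sketch of turning the feature on (the mHC field names below come from this PR's gpt_conf.py; the GPT class and the remaining constructor arguments are assumptions about the surrounding repo, not part of this diff):

import torch
from gpt_conf import GPTConfig
from model import GPT

config = GPTConfig(
    block_size=256, n_layer=6, n_head=6, n_embd=384,  # placeholder sizes, not from the PR
    use_mhc=True,            # enable Manifold-Constrained Hyper-Connections
    mhc_expansion_rate=4,    # number of residual streams
    mhc_alpha_init=0.01,     # initial scale for the learned mappings
    mhc_sinkhorn_iters=20,   # Sinkhorn iterations for the residual mixing matrix
    mhc_rmsnorm_eps=1e-5,    # epsilon for the internal RMS norm
)
model = GPT(config)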

Copilot AI left a comment

Pull request overview

This pull request introduces Manifold-Constrained Hyper-Connections (mHC) as a new architectural option for the transformer model. The implementation adds stream-based processing where embeddings are expanded into multiple streams, processed through attention/MLP layers with learned mappings, and then reduced back.

Key changes:

  • New mHC implementation with Sinkhorn normalization and learnable hyper-connection mappings
  • Integration of mHC logic throughout the forward pass with conditional stream expansion/reduction
  • Configuration options for mHC parameters (expansion rate, alpha initialization, Sinkhorn iterations, RMS norm epsilon)
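To make the stream flow concrete, here is a schematic of one mHC-wrapped sublayer, assembled from the snippets reviewed below (illustrative only, not the repo's exact forward code; `mhc` is a ManifoldHyperConnections instance):

import torch

def mhc_block_sketch(x: torch.Tensor, mhc, sublayer, n_streams: int) -> torch.Tensor:
    # sublayer is any (b, t, c) -> (b, t, c) callable (attention or MLP)
    stream = x.unsqueeze(2).expand(-1, -1, n_streams, -1)  # (b, t, c) -> (b, t, n, c)
    x_in, cache = mhc.pre_map(stream)          # mix the n streams down to (b, t, c)
    out = sublayer(x_in)                       # ordinary attention/MLP on one stream
    stream = mhc.post_map(stream, out, cache)  # write output back with residual mixing
    return stream.mean(dim=2)                  # reduce streams back to (b, t, c)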

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.

Summary per file:

  • variations/mhc_variations.py: Implements the ManifoldHyperConnections class with pre_map/post_map methods for stream-based processing
  • variations/block_variations.py: Integrates mHC into parallel_mlp_forward and attn_then_mlp_forward; adds helper methods for stream reduction and post-LN
  • model.py: Adds mHC stream expansion/reduction helper methods and conditional mHC logic throughout the forward passes
  • gpt_conf.py: Defines the mHC configuration parameters (use_mhc, mhc_expansion_rate, mhc_alpha_init, mhc_sinkhorn_iters, mhc_rmsnorm_eps)
  • explorations.yaml: Adds an exploration configuration comparing default settings against mHC-recommended settings


Comment on lines +31 to +33
self.alpha_pre = nn.Parameter(torch.tensor(alpha_init))
self.alpha_post = nn.Parameter(torch.tensor(alpha_init))
self.alpha_res = nn.Parameter(torch.tensor(alpha_init))
Copilot AI commented Jan 4, 2026

The alpha parameters are created with torch.tensor(alpha_init), which produces a 0-dimensional tensor whose dtype is inferred from the Python float rather than stated explicitly. When these scalars multiply activations (e.g., on line 52), the implicit dtype can cause surprises under mixed precision or distributed training. Use torch.tensor(alpha_init, dtype=torch.float32), or preferably torch.full([], alpha_init, dtype=torch.float32), to make the dtype explicit.

Suggested change
self.alpha_pre = nn.Parameter(torch.tensor(alpha_init))
self.alpha_post = nn.Parameter(torch.tensor(alpha_init))
self.alpha_res = nn.Parameter(torch.tensor(alpha_init))
self.alpha_pre = nn.Parameter(torch.full([], alpha_init, dtype=torch.float32))
self.alpha_post = nn.Parameter(torch.full([], alpha_init, dtype=torch.float32))
self.alpha_res = nn.Parameter(torch.full([], alpha_init, dtype=torch.float32))

Comment on lines +43 to +44
mat = mat / (mat.sum(dim=-1, keepdim=True) + 1e-6)
mat = mat / (mat.sum(dim=-2, keepdim=True) + 1e-6)
Copilot AI commented Jan 4, 2026

The hardcoded epsilon value of 1e-6 for numerical stability in the Sinkhorn normalization should be configurable or match the mhc_rmsnorm_eps parameter from the config. Hardcoding this value makes it difficult to adjust for different precision requirements (e.g., float16 vs float32) and creates inconsistency with other epsilon values in the module.

Suggested change
mat = mat / (mat.sum(dim=-1, keepdim=True) + 1e-6)
mat = mat / (mat.sum(dim=-2, keepdim=True) + 1e-6)
mat = mat / (mat.sum(dim=-1, keepdim=True) + self.rmsnorm_eps)
mat = mat / (mat.sum(dim=-2, keepdim=True) + self.rmsnorm_eps)

Comment on lines +63 to +83
def pre_map(self, stream: torch.Tensor):
    h_pre, h_post, h_res = self._compute_mappings(stream)
    b, t, n, c = stream.shape
    stream_2d = stream.reshape(b * t, n, c)
    pre_weights = h_pre.view(b * t, 1, n)
    x_in = torch.bmm(pre_weights, stream_2d).view(b, t, c)
    cache = (h_post, h_res)
    return x_in, cache

def post_map(self, stream: torch.Tensor, out: torch.Tensor, cache):
    h_post, h_res = cache
    b, t, n, c = stream.shape
    stream_2d = stream.reshape(b * t, n, c)
    out_2d = out.reshape(b * t, 1, c)

    post_weights = h_post.view(b * t, n, 1)
    post_contrib = torch.bmm(post_weights, out_2d)
    res_contrib = torch.bmm(h_res, stream_2d)

    updated = res_contrib + post_contrib
    return updated.view(b, t, n, c)
Copilot AI commented Jan 4, 2026

The pre_map and post_map methods lack docstrings explaining their parameters, return values, and expected tensor shapes. Given the complexity of the mHC mechanism and the specific shape requirements (4D input tensor with shape (batch, time, streams, channels)), comprehensive documentation is essential for maintainability and proper usage.
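For instance, one possible docstring for pre_map, with shapes inferred from the snippet above (the wording is a suggestion, not code from the PR):

def pre_map(self, stream: torch.Tensor):
    """Mix the residual streams down to a single sublayer input.

    Args:
        stream: Tensor of shape (batch, time, n_streams, channels).

    Returns:
        x_in: Tensor of shape (batch, time, channels), the input for the
            attention/MLP sublayer.
        cache: Tuple (h_post, h_res) of mappings reused later by post_map.
    """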

Comment on lines +161 to +167
# Manifold-Constrained Hyper-Connections (mHC)
use_mhc: bool = False
mhc_expansion_rate: int = 4
mhc_alpha_init: float = 0.01
mhc_sinkhorn_iters: int = 20
mhc_rmsnorm_eps: float = 1e-5

Copilot AI commented Jan 4, 2026

The new mHC functionality lacks test coverage. Given that the repository has comprehensive test scripts for other model variations (e.g., test_ln_parallel_mlp_cpu.sh, test_gqa_variations_cpu.sh), a similar test should be added to verify that mHC works correctly with basic training and sampling. This is especially important since mHC introduces complex tensor reshaping operations and new forward pass logic that could fail in edge cases.
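A minimal shape-level smoke test could look like the following sketch (pytest style; the GPTConfig arguments and the assumption that ManifoldHyperConnections(config) needs nothing beyond the config are unverified here):

import torch
from gpt_conf import GPTConfig
from variations.mhc_variations import ManifoldHyperConnections

def test_mhc_round_trip_shapes():
    config = GPTConfig(n_embd=32, use_mhc=True, mhc_expansion_rate=4)
    mhc = ManifoldHyperConnections(config)
    b, t, n, c = 2, 8, config.mhc_expansion_rate, 32
    stream = torch.randn(b, t, n, c)
    x_in, cache = mhc.pre_map(stream)
    assert x_in.shape == (b, t, c)
    out = torch.randn(b, t, c)  # stand-in for the attn/MLP sublayer output
    updated = mhc.post_map(stream, out, cache)
    assert updated.shape == (b, t, n, c)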

Comment on lines +39 to +45
def _sinkhorn(self, logits: torch.Tensor) -> torch.Tensor:
    logits = logits - logits.amax(dim=(-2, -1), keepdim=True)
    mat = torch.exp(logits)
    for _ in range(self.sinkhorn_iters):
        mat = mat / (mat.sum(dim=-1, keepdim=True) + 1e-6)
        mat = mat / (mat.sum(dim=-2, keepdim=True) + 1e-6)
    return mat
Copilot AI commented Jan 4, 2026

The Sinkhorn algorithm performs iterative normalization with 20 iterations by default (configurable via mhc_sinkhorn_iters). This is computed for every forward pass in _compute_mappings, which could become a performance bottleneck during training, especially with large batch sizes or long sequences. Consider caching the Sinkhorn results when the input doesn't change significantly, or evaluate if fewer iterations would suffice for convergence in practice.
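One way to evaluate the iteration count is an early-exit variant like the sketch below (tol is hypothetical; the convergence check forces a host sync on GPU, so it is best used offline to pick a smaller mhc_sinkhorn_iters rather than in the hot training path):

def _sinkhorn_early_exit(self, logits: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
    logits = logits - logits.amax(dim=(-2, -1), keepdim=True)
    mat = torch.exp(logits)
    for _ in range(self.sinkhorn_iters):
        mat = mat / (mat.sum(dim=-1, keepdim=True) + 1e-6)
        mat = mat / (mat.sum(dim=-2, keepdim=True) + 1e-6)
        row_sums = mat.sum(dim=-1)
        # stop once rows already sum to ~1 after the column pass
        if torch.allclose(row_sums, torch.ones_like(row_sums), atol=tol):
            break
    return mat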

Comment on lines +47 to +61
def _compute_mappings(self, stream: torch.Tensor):
    b, t, n, c = stream.shape
    stream_flat = stream.reshape(b * t, n * c)
    stream_norm = self._rms_norm(stream_flat)

    pre = self.alpha_pre * self.phi_pre(stream_norm) + self.bias_pre
    post = self.alpha_post * self.phi_post(stream_norm) + self.bias_post
    res = self.alpha_res * self.phi_res(stream_norm) + self.bias_res

    h_pre = torch.sigmoid(pre)
    h_post = 2.0 * torch.sigmoid(post)
    h_res = self._sinkhorn(res.reshape(b * t, n, n))

    dtype = stream.dtype
    return h_pre.to(dtype), h_post.to(dtype), h_res.to(dtype)
Copilot AI commented Jan 4, 2026

The _compute_mappings method computes all three mappings (h_pre, h_post, h_res) on every call, but pre_map only uses h_pre while caching h_post and h_res for later use in post_map. This means during the pre_map call, h_post and h_res are computed but not immediately used. Consider splitting this into separate methods or document why all three must be computed together (e.g., if they share expensive intermediate computations).
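If the split is preferred, one sketch is to cache the shared normalized stream and defer the other two heads to post_map (illustrative; names mirror the snippet above, and this is not code from the PR):

def pre_map(self, stream: torch.Tensor):
    b, t, n, c = stream.shape
    stream_norm = self._rms_norm(stream.reshape(b * t, n * c))  # shared intermediate
    h_pre = torch.sigmoid(self.alpha_pre * self.phi_pre(stream_norm) + self.bias_pre)
    stream_2d = stream.reshape(b * t, n, c)
    x_in = torch.bmm(h_pre.to(stream.dtype).view(b * t, 1, n), stream_2d).view(b, t, c)
    # cache the normalized stream; compute h_post/h_res later in post_map
    return x_in, stream_norm

Note that total compute is unchanged when both maps always run, so simply documenting the shared stream_norm intermediate may be the simpler fix.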


if self.use_mhc:
    self.mhc_attn = ManifoldHyperConnections(config)
    self.mhc_mlp = ManifoldHyperConnections(config)
Copilot AI commented Jan 4, 2026

When use_parallel_mlp is True, the mhc_mlp module is instantiated but never used. In parallel_mlp_forward, only mhc_attn is used to handle the combined attention+MLP output (lines 78 and 103 in block_variations.py). Consider conditionally creating mhc_mlp only when not in parallel mode to avoid wasting memory and parameters on an unused module.

Suggested change
self.mhc_mlp = ManifoldHyperConnections(config)
# Only instantiate mhc_mlp when not using parallel MLP mode,
# since in parallel_mlp_forward the combined path uses mhc_attn.
if not getattr(config, "use_parallel_mlp", False):
    self.mhc_mlp = ManifoldHyperConnections(config)

Comment on lines +227 to 241
def _mhc_expand_stream(self, x: torch.Tensor) -> torch.Tensor:
    return x.unsqueeze(2).expand(-1, -1, self.mhc_streams, -1)

@staticmethod
def _mhc_reduce_stream(stream: torch.Tensor) -> torch.Tensor:
    return stream.mean(dim=2)

def _mhc_add_to_stream(self, stream: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    return stream + residual.unsqueeze(2)

def _mhc_apply_to_stream(self, stream: torch.Tensor, fn) -> torch.Tensor:
    reduced = self._mhc_reduce_stream(stream)
    updated = fn(reduced)
    return self._mhc_expand_stream(updated)
else:
Copilot AI commented Jan 4, 2026

The method definitions inserted here break the if/else structure of the __init__ method: lines 227-240 define helper methods between an if statement (line 224: if config.n_embd_wte:) and its corresponding else clause (line 241). As written, the else no longer pairs with its if, which is a syntax error that prevents the code from running. Move these method definitions outside of __init__, or at least place them after the if/else block completes.
