From 77bcf4d2a74f22b5853e4b1b99953b81c8dfb26e Mon Sep 17 00:00:00 2001
From: vukrosic
Date: Sat, 31 Jan 2026 10:58:49 +0100
Subject: [PATCH] feat: introduce hyper-connections to the transformer
 architecture

Add new configuration options, dedicated layer implementations, and a
verification test.
---
 configs/llm_config.py           |  5 ++
 models/layers.py                | 95 +++++++++++++++++++++++++++++++++
 models/llm.py                   | 46 ++++++++++++----
 tests/test_hyper_connections.py | 52 ++++++++++++++++++
 4 files changed, 189 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_hyper_connections.py

diff --git a/configs/llm_config.py b/configs/llm_config.py
index 1384b013..219ea579 100644
--- a/configs/llm_config.py
+++ b/configs/llm_config.py
@@ -13,6 +13,11 @@ class BlueberryConfig:
     # GQA parameters
     n_kv_heads: int = 4
 
+    # Hyper-connections parameters
+    use_hyper_connections: bool = False
+    hyper_rate: int = 4
+    hyper_dynamic: bool = True
+
     # Data params
     # ⚠️ WARNING: For simplicity, I recommend not changing max_seq_len
     # If you change max_seq_len, you MUST re-run data preparation!
diff --git a/models/layers.py b/models/layers.py
index 2b903501..14ddda1b 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -136,3 +136,98 @@ def forward(self, x):
         ff_out = self.feed_forward(self.norm2(x))
         x = x + self.dropout(ff_out)
         return x
+
+
+class HyperConnection(nn.Module):
+    def __init__(self, dim, rate, layer_id, dynamic, device=None):
+        super(HyperConnection, self).__init__()
+        self.rate = rate
+        self.layer_id = layer_id
+        self.dynamic = dynamic
+        self.static_beta = nn.Parameter(torch.ones((rate,), device=device))
+
+        init_alpha0 = torch.zeros((rate, 1), device=device)
+        init_alpha0[layer_id % rate, 0] = 1.
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye((rate), device=device)], dim=1))
+
+        if self.dynamic:
+            self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim, rate + 1), device=device))
+            self.dynamic_alpha_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
+            self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim,), device=device))
+            self.dynamic_beta_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
+
+        self.layer_norm = nn.LayerNorm(dim)
+
+    def width_connection(self, h):
+        # h: (B, L, N, D)
+        if self.dynamic:
+            norm_h = self.layer_norm(h)
+
+            wc_weight = norm_h @ self.dynamic_alpha_fn
+            wc_weight = torch.tanh(wc_weight)
+            dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+            alpha = dynamic_alpha + self.static_alpha[None, None, ...]
+
+            dc_weight = norm_h @ self.dynamic_beta_fn
+            dc_weight = torch.tanh(dc_weight)
+            dynamic_beta = dc_weight * self.dynamic_beta_scale
+            beta = dynamic_beta + self.static_beta[None, None, ...]
+        else:
+            alpha = self.static_alpha[None, None, ...]
+            beta = self.static_beta[None, None, ...]
+
+        # width connection
+        mix_h = alpha.transpose(-1, -2) @ h
+        return mix_h, beta
+
+    def depth_connection(self, mix_h, h_o, beta):
+        # h_o: output of sub-layer (B, L, D)
+        # beta: (B, L, N)
+        # mix_h: (B, L, N+1, D)
+        h = torch.einsum("blh,bln->blnh", h_o, beta) + mix_h[..., 1:, :]
+        return h
+
+
+class HyperTransformerBlock(nn.Module):
+    """Transformer block with Hyper-connections"""
+
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        d_ff: int,
+        max_seq_len: int,
+        layer_id: int,
+        rate: int,
+        dynamic: bool,
+        dropout: float = 0.1,
+        n_kv_heads: int | None = None,
+    ):
+        super().__init__()
+        self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout, n_kv_heads)
+        self.feed_forward = SquaredReLUFeedForward(d_model, d_ff, dropout)
+
+        self.norm1 = nn.RMSNorm(d_model)
+        self.norm2 = nn.RMSNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.hc1 = HyperConnection(d_model, rate, layer_id * 2, dynamic)
+        self.hc2 = HyperConnection(d_model, rate, layer_id * 2 + 1, dynamic)
+
+    def forward(self, h):
+        # h shape: (B, L, N, D)
+
+        # --- Sub-layer 1: Attention ---
+        mix_h, beta = self.hc1.width_connection(h)
+        # Slot 0 of mix_h is the mixed input fed to the sub-layer
+        x = mix_h[..., 0, :]
+        attn_out = self.attention(self.norm1(x))
+        h = self.hc1.depth_connection(mix_h, self.dropout(attn_out), beta)
+
+        # --- Sub-layer 2: Feed-forward ---
+        mix_h, beta = self.hc2.width_connection(h)
+        x = mix_h[..., 0, :]
+        ff_out = self.feed_forward(self.norm2(x))
+        h = self.hc2.depth_connection(mix_h, self.dropout(ff_out), beta)
+
+        return h
diff --git a/models/llm.py b/models/llm.py
index 9e887372..e50717ef 100644
--- a/models/llm.py
+++ b/models/llm.py
@@ -3,7 +3,7 @@
 import math
 from typing import Optional
 from configs.llm_config import BlueberryConfig
-from models.layers import TransformerBlock
+from models.layers import TransformerBlock, HyperTransformerBlock
 
 
 class MinimalLLM(nn.Module):
@@ -18,9 +18,22 @@ def __init__(self, config: BlueberryConfig):
         self.position_dropout = nn.Dropout(config.dropout)
 
         # Transformer blocks
-        self.transformer_blocks = nn.ModuleList(
-            [
-                TransformerBlock(
+        self.transformer_blocks = nn.ModuleList()
+        for i in range(config.n_layers):
+            if config.use_hyper_connections:
+                block = HyperTransformerBlock(
+                    config.d_model,
+                    config.n_heads,
+                    config.d_ff,
+                    config.max_seq_len,
+                    layer_id=i,
+                    rate=config.hyper_rate,
+                    dynamic=config.hyper_dynamic,
+                    dropout=config.dropout,
+                    n_kv_heads=config.n_kv_heads,
+                )
+            else:
+                block = TransformerBlock(
                     config.d_model,
                     config.n_heads,
                     config.d_ff,
@@ -28,9 +41,7 @@ def __init__(self, config: BlueberryConfig):
                     config.dropout,
                     n_kv_heads=config.n_kv_heads,
                 )
-                for i in range(config.n_layers)
-            ]
-        )
+            self.transformer_blocks.append(block)
 
         # Output layers
         self.norm = nn.RMSNorm(config.d_model)
@@ -56,8 +67,25 @@ def forward(self, x):
         x = self.position_dropout(x)
 
         # Pass through transformer blocks
-        for block in self.transformer_blocks:
-            x = block(x)
+        if self.config.use_hyper_connections:
+            # Initialize the hyper hidden state h: (B, L, N, D).
+            # x is placed in the first stream; the remaining
+            # N - 1 streams start at zero.
+            h = torch.zeros(
+                (x.size(0), x.size(1), self.config.hyper_rate, x.size(2)),
+                device=x.device,
+                dtype=x.dtype,
+            )
+            h[..., 0, :] = x
+
+            for block in self.transformer_blocks:
+                h = block(h)
+
+            # Final output is taken from the first slot
+            x = h[..., 0, :]
+        else:
+            for block in self.transformer_blocks:
+                x = block(x)
 
         # Output projection
         x = self.norm(x)
diff --git a/tests/test_hyper_connections.py b/tests/test_hyper_connections.py
new file mode 100644
index 00000000..200f0144
--- /dev/null
+++ b/tests/test_hyper_connections.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+from configs.llm_config import BlueberryConfig
+from models.llm import MinimalLLM
+
+def test_hyper_connections():
+    print("Testing Hyper-connections...")
+
+    # Configure model with hyper-connections
+    config = BlueberryConfig(
+        d_model=128,
+        n_heads=4,
+        n_layers=2,
+        d_ff=512,
+        max_seq_len=64,
+        use_hyper_connections=True,
+        hyper_rate=4,
+        hyper_dynamic=True,
+        compile_model=False
+    )
+
+    model = MinimalLLM(config)
+    batch_size = 2
+    seq_len = 16
+
+    # Create dummy input
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
+
+    # Forward pass
+    print("Running forward pass...")
+    logits = model(input_ids)
+
+    print(f"Logits shape: {logits.shape}")
+    assert logits.shape == (batch_size, seq_len, config.vocab_size), "Unexpected logits shape"
+
+    # Backward pass
+    print("Running backward pass...")
+    loss = logits.mean()
+    loss.backward()
+
+    # Check if hyper-connection parameters have gradients
+    found_hc_grad = False
+    for name, param in model.named_parameters():
+        if "hc" in name and param.grad is not None:
+            found_hc_grad = True
+            break
+
+    assert found_hc_grad, "No gradients found for hyper-connection parameters"
+    print("Success! Hyper-connections forward and backward passes work as expected.")
+
+if __name__ == "__main__":
+    test_hyper_connections()
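
To make the tensor shapes in HyperConnection concrete, here is a minimal, self-contained sketch of the static (non-dynamic) path using plain tensor ops; the dimensions and the tanh stand-in for the attention/feed-forward sub-layer are illustrative only, not part of the patch:

import torch

B, L, N, D = 2, 8, 4, 16   # batch, sequence length, hyper rate (streams), model dim
layer_id = 0
h = torch.randn(B, L, N, D)

# static_alpha: (N, N + 1). Column 0 is initialised one-hot to pick the stream that
# feeds the sub-layer; the remaining N columns are an identity that carries every
# stream forward unchanged.
init_alpha0 = torch.zeros(N, 1)
init_alpha0[layer_id % N, 0] = 1.0
static_alpha = torch.cat([init_alpha0, torch.eye(N)], dim=1)
static_beta = torch.ones(N)

# Width connection: mix_h is (B, L, N + 1, D); slot 0 is the sub-layer input.
mix_h = static_alpha.transpose(-1, -2) @ h
x = mix_h[..., 0, :]                   # (B, L, D)

# Stand-in for the attention or feed-forward sub-layer.
y = torch.tanh(x)                      # (B, L, D)

# Depth connection: broadcast the sub-layer output back into all N streams, weighted
# by beta, and add the carried-forward streams mix_h[..., 1:, :].
h_next = torch.einsum("bld,bln->blnd", y, static_beta.expand(B, L, N)) + mix_h[..., 1:, :]
print(h_next.shape)                    # torch.Size([2, 8, 4, 16])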
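
For reference, enabling the feature from training code only requires the new config flags. A minimal sketch, assuming the remaining BlueberryConfig fields keep their defaults and the existing training loop consumes MinimalLLM unchanged:

from configs.llm_config import BlueberryConfig
from models.llm import MinimalLLM

# Illustrative values, not tuned; only the three hyper-connection flags are new.
config = BlueberryConfig(
    use_hyper_connections=True,  # route activations through N hidden streams
    hyper_rate=4,                # N: number of parallel streams per block
    hyper_dynamic=True,          # add input-dependent alpha/beta on top of the static ones
)
model = MinimalLLM(config)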