5 changes: 5 additions & 0 deletions configs/llm_config.py
@@ -13,6 +13,11 @@ class BlueberryConfig:
# GQA parameters
n_kv_heads: int = 4

# Hyper-connections parameters
use_hyper_connections: bool = False
hyper_rate: int = 4
hyper_dynamic: bool = True

# Data params
# ⚠️ WARNING: For simplicity, I recommend not changing max_seq_len
# If you change max_seq_len, you MUST re-run data preparation!
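The three new fields can be exercised directly from the config; a minimal sketch of turning hyper-connections on (assuming the remaining BlueberryConfig defaults are usable as-is, and passing compile_model=False as the test further below does):

from configs.llm_config import BlueberryConfig
from models.llm import MinimalLLM

# hyper_rate sets the number of parallel hidden streams N;
# hyper_dynamic switches between static and input-dependent mixing weights.
config = BlueberryConfig(
    use_hyper_connections=True,
    hyper_rate=4,
    hyper_dynamic=True,
    compile_model=False,
)
model = MinimalLLM(config)
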
95 changes: 95 additions & 0 deletions models/layers.py
@@ -136,3 +136,98 @@ def forward(self, x):
ff_out = self.feed_forward(self.norm2(x))
x = x + self.dropout(ff_out)
return x


class HyperConnection(nn.Module):
def __init__(self, dim, rate, layer_id, dynamic, device=None):
super().__init__()
self.rate = rate
self.layer_id = layer_id
self.dynamic = dynamic
self.static_beta = nn.Parameter(torch.ones((rate,), device=device))

init_alpha0 = torch.zeros((rate, 1), device=device)
init_alpha0[layer_id % rate, 0] = 1.
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(rate, device=device)], dim=1))

if self.dynamic:
self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim, rate + 1), device=device))
self.dynamic_alpha_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim,), device=device))
self.dynamic_beta_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)

self.layer_norm = nn.LayerNorm(dim)

def width_connection(self, h):
# h: (B, L, N, D)
if self.dynamic:
norm_h = self.layer_norm(h)

wc_weight = norm_h @ self.dynamic_alpha_fn
wc_weight = torch.tanh(wc_weight)
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
alpha = dynamic_alpha + self.static_alpha[None, None, ...]

dc_weight = norm_h @ self.dynamic_beta_fn
dc_weight = torch.tanh(dc_weight)
dynamic_beta = dc_weight * self.dynamic_beta_scale
beta = dynamic_beta + self.static_beta[None, None, ...]
else:
alpha = self.static_alpha[None, None, ...]
beta = self.static_beta[None, None, ...]

# width connection
mix_h = alpha.transpose(-1, -2) @ h
return mix_h, beta

def depth_connection(self, mix_h, h_o, beta):
# h_o: output of sub-layer (B, L, D)
# beta: (B, L, N)
# mix_h: (B, L, N+1, D)
h = torch.einsum("blh,bln->blnh", h_o, beta) + mix_h[..., 1:, :]
return h
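
As a shape sanity check, a minimal standalone sketch of the width/depth round-trip (the batch, sequence, and model dimensions below are arbitrary):

import torch
from models.layers import HyperConnection

B, L, D, N = 2, 8, 32, 4
hc = HyperConnection(dim=D, rate=N, layer_id=0, dynamic=True)

h = torch.randn(B, L, N, D)               # N parallel hidden streams
mix_h, beta = hc.width_connection(h)      # mix_h: (B, L, N + 1, D), beta: (B, L, N)
sublayer_out = torch.randn(B, L, D)       # stand-in for an attention or feed-forward output
h_next = hc.depth_connection(mix_h, sublayer_out, beta)
assert h_next.shape == (B, L, N, D)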


class HyperTransformerBlock(nn.Module):
"""Transformer block with Hyper-connections"""

def __init__(
self,
d_model: int,
n_heads: int,
d_ff: int,
max_seq_len: int,
layer_id: int,
rate: int,
dynamic: bool,
dropout: float = 0.1,
n_kv_heads: int | None = None,
):
super().__init__()
self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout, n_kv_heads)
self.feed_forward = SquaredReLUFeedForward(d_model, d_ff, dropout)

self.norm1 = nn.RMSNorm(d_model)
self.norm2 = nn.RMSNorm(d_model)
self.dropout = nn.Dropout(dropout)

self.hc1 = HyperConnection(d_model, rate, layer_id * 2, dynamic)
self.hc2 = HyperConnection(d_model, rate, layer_id * 2 + 1, dynamic)

def forward(self, h):
# h shape: (B, L, N, D)

# --- Sub-layer 1: Attention ---
mix_h, beta = self.hc1.width_connection(h)
# The sub-layer input is the mixed representation in the first slot of mix_h
x = mix_h[..., 0, :]
attn_out = self.attention(self.norm1(x))
h = self.hc1.depth_connection(mix_h, self.dropout(attn_out), beta)

# --- Sub-layer 2: Feed-forward ---
mix_h, beta = self.hc2.width_connection(h)
x = mix_h[..., 0, :]
ff_out = self.feed_forward(self.norm2(x))
h = self.hc2.depth_connection(mix_h, self.dropout(ff_out), beta)

return h
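
A quick usage sketch of the block-level contract, (B, L, N, D) in and (B, L, N, D) out (dimensions chosen arbitrarily; n_kv_heads is left at its default):

import torch
from models.layers import HyperTransformerBlock

B, L, D, N = 2, 16, 64, 4
block = HyperTransformerBlock(
    d_model=D, n_heads=4, d_ff=256, max_seq_len=64,
    layer_id=0, rate=N, dynamic=True,
)

h = torch.zeros(B, L, N, D)
h[..., 0, :] = torch.randn(B, L, D)   # token representations go in the first slot, as in MinimalLLM.forward
h = block(h)                          # still (B, L, N, D)
assert h.shape == (B, L, N, D)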
46 changes: 37 additions & 9 deletions models/llm.py
@@ -3,7 +3,7 @@
import math
from typing import Optional
from configs.llm_config import BlueberryConfig
from models.layers import TransformerBlock
from models.layers import TransformerBlock, HyperTransformerBlock


class MinimalLLM(nn.Module):
@@ -18,19 +18,30 @@ def __init__(self, config: BlueberryConfig):
self.position_dropout = nn.Dropout(config.dropout)

# Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
self.transformer_blocks = nn.ModuleList()
for i in range(config.n_layers):
if config.use_hyper_connections:
block = HyperTransformerBlock(
config.d_model,
config.n_heads,
config.d_ff,
config.max_seq_len,
layer_id=i,
rate=config.hyper_rate,
dynamic=config.hyper_dynamic,
dropout=config.dropout,
n_kv_heads=config.n_kv_heads,
)
else:
block = TransformerBlock(
config.d_model,
config.n_heads,
config.d_ff,
config.max_seq_len,
config.dropout,
n_kv_heads=config.n_kv_heads,
)
for i in range(config.n_layers)
]
)
self.transformer_blocks.append(block)

# Output layers
self.norm = nn.RMSNorm(config.d_model)
@@ -56,8 +67,25 @@ def forward(self, x):
x = self.position_dropout(x)

# Pass through transformer blocks
for block in self.transformer_blocks:
x = block(x)
if self.config.use_hyper_connections:
# Initialize the hyper hidden matrix h: (B, L, N, D),
# placing x in the first slot and zeros in the remaining N - 1 slots
h = torch.zeros(
(x.size(0), x.size(1), self.config.hyper_rate, x.size(2)),
device=x.device,
dtype=x.dtype,
)
h[..., 0, :] = x

for block in self.transformer_blocks:
h = block(h)

# Final output is taken from the first slot
x = h[..., 0, :]
else:
for block in self.transformer_blocks:
x = block(x)

# Output projection
x = self.norm(x)
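Because every stream reuses the same attention and feed-forward weights, the extra parameters are limited to the per-connection alpha/beta terms and layer norms; a rough way to confirm this (config defaults assumed, so instantiation may be slow):

from configs.llm_config import BlueberryConfig
from models.llm import MinimalLLM

def n_params(m):
    return sum(p.numel() for p in m.parameters())

base = MinimalLLM(BlueberryConfig(use_hyper_connections=False, compile_model=False))
hyper = MinimalLLM(BlueberryConfig(use_hyper_connections=True, hyper_rate=4, compile_model=False))
print(f"extra parameters from hyper-connections: {n_params(hyper) - n_params(base):,}")
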
52 changes: 52 additions & 0 deletions tests/test_hyper_connections.py
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
from configs.llm_config import BlueberryConfig
from models.llm import MinimalLLM

def test_hyper_connections():
print("Testing Hyper-connections...")

# Configure model with hyper-connections
config = BlueberryConfig(
d_model=128,
n_heads=4,
n_layers=2,
d_ff=512,
max_seq_len=64,
use_hyper_connections=True,
hyper_rate=4,
hyper_dynamic=True,
compile_model=False
)

model = MinimalLLM(config)
batch_size = 2
seq_len = 16

# Create dummy input
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

# Forward pass
print("Running forward pass...")
logits = model(input_ids)

print(f"Logits shape: {logits.shape}")
assert logits.shape == (batch_size, seq_len, config.vocab_size), "Unexpected logits shape"

# Backward pass
print("Running backward pass...")
loss = logits.mean()
loss.backward()

# Check if hyper-connection parameters have gradients
found_hc_grad = False
for name, param in model.named_parameters():
if "hc" in name and param.grad is not None:
found_hc_grad = True
break

assert found_hc_grad, "No gradients found for hyper-connection parameters"
print("Success! Hyper-connections forward and backward passes work as expected.")

if __name__ == "__main__":
test_hyper_connections()
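
The test above only exercises the dynamic branch of width_connection; a sketch of a companion check for the static path (hyper_dynamic=False) could look like:

import torch
from configs.llm_config import BlueberryConfig
from models.llm import MinimalLLM

def test_hyper_connections_static():
    config = BlueberryConfig(
        d_model=128,
        n_heads=4,
        n_layers=2,
        d_ff=512,
        max_seq_len=64,
        use_hyper_connections=True,
        hyper_rate=4,
        hyper_dynamic=False,
        compile_model=False,
    )
    model = MinimalLLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    logits = model(input_ids)
    assert logits.shape == (2, 16, config.vocab_size), "Unexpected logits shape"
    logits.mean().backward()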