28 changes: 19 additions & 9 deletions models/layers.py
@@ -3,7 +3,10 @@
import torch.nn.functional as F
from torchtune.modules import RotaryPositionalEmbeddings
from .components import SquaredReLUFeedForward

from torch.nn.attention.flex_attention import (
flex_attention,
BlockMask,
)

class Rotary(nn.Module):
def __init__(self, dim: int, max_seq_len: int):
@@ -34,7 +37,10 @@ def __init__(
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
self.num_key_value_groups = self.n_heads // self.n_kv_heads
self.d_k = d_model // n_heads


# Sparse attention gate: per-head sigmoid gate computed from a slice of the input features
self.atten_gate = nn.Linear(12, n_heads, bias=False)
nn.init.zeros_(self.atten_gate.weight)
# ============ MERGED QKVO PROJECTION ============
# Instead of 4 separate Linear layers, use single merged projection
q_size = d_model
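
For reference, a standalone sketch of the per-head gate this hunk introduces (the class name and shapes are illustrative, not from the repo): a small Linear maps the first few input features to one logit per head, and the sigmoid of that logit later scales each head's attention output. Zero-initializing the weight makes every gate start at sigmoid(0) = 0.5, a uniform per-head scale that the output projection can absorb.

import torch
import torch.nn as nn

class HeadGate(nn.Module):
    """Illustrative per-head output gate (hypothetical helper, not repo code)."""
    def __init__(self, gate_in: int, n_heads: int):
        super().__init__()
        self.gate = nn.Linear(gate_in, n_heads, bias=False)
        nn.init.zeros_(self.gate.weight)  # sigmoid(0) = 0.5 for every head at init

    def forward(self, x: torch.Tensor, head_out: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_model); head_out: (batch, seq, n_heads, d_k)
        g = torch.sigmoid(self.gate(x[..., : self.gate.in_features]))  # (batch, seq, n_heads)
        return head_out * g.unsqueeze(-1)
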
@@ -62,7 +68,7 @@ def __init__(
self.rotary = Rotary(self.d_k, max_seq_len)
self.dropout = dropout

- def forward(self, x):
+ def forward(self, x, mask: BlockMask):
batch_size, seq_len = x.size(0), x.size(1)

# ============ MERGED QKV PROJECTION ============
@@ -91,15 +97,19 @@ def forward(self, x):
Q, K, V = Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)

# Compute attention
- attn_output = F.scaled_dot_product_attention(
- Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
- )
+ attn_output = flex_attention(Q, K, V, block_mask=mask)

# Reshape output
attn_output = attn_output.transpose(1, 2).reshape(
batch_size, seq_len, self.n_heads, self.d_k
)
# Sparse attention gate: scale each head's output by its sigmoid gate value
attn_output = attn_output * F.sigmoid(
self.atten_gate(x[..., : self.atten_gate.in_features])
).view(batch_size, seq_len, self.n_heads, 1)
attn_output = attn_output.contiguous().reshape(
batch_size, seq_len, self.d_model
)

# ============ MERGED O PROJECTION ============
# Use the last part of qkvo_proj for output projection
return F.linear(attn_output, self.qkvo_proj[self.qkv_size:])
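
Putting the pieces of this forward pass together, here is a minimal self-contained sketch of the pattern (no GQA, no rotary, no gate; the function name and the (4*d_model, d_model) merged-weight layout are assumptions for illustration): one F.linear call produces Q/K/V from the head rows of the merged weight, flex_attention consumes a precomputed BlockMask, and the tail rows of the same weight serve as the output projection. FlexAttention requires a recent PyTorch (2.5 or newer, as an assumption here).

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def merged_flex_attention(x, qkvo_weight, n_heads, block_mask):
    # x: (B, S, d_model); qkvo_weight: (4*d_model, d_model), an assumed layout
    B, S, d_model = x.shape
    d_k = d_model // n_heads
    qkv = F.linear(x, qkvo_weight[: 3 * d_model])                 # merged QKV projection
    q, k, v = qkv.chunk(3, dim=-1)
    q, k, v = (t.view(B, S, n_heads, d_k).transpose(1, 2) for t in (q, k, v))
    out = flex_attention(q, k, v, block_mask=block_mask)          # (B, n_heads, S, d_k)
    out = out.transpose(1, 2).reshape(B, S, d_model)
    return F.linear(out, qkvo_weight[3 * d_model:])               # merged O projection

# Usage with a plain causal BlockMask (document masking is added in models/llm.py below):
# S = x.size(1)
# mask = create_block_mask(lambda b, h, q, kv: q >= kv, None, None, S, S)
# y = merged_flex_attention(x, qkvo_weight, n_heads=8, block_mask=mask)
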
@@ -127,9 +137,9 @@ def __init__(
self.norm2 = nn.RMSNorm(d_model)
self.dropout = nn.Dropout(dropout)

- def forward(self, x):
+ def forward(self, x, mask: BlockMask):
# Self-attention
- attn_out = self.attention(self.norm1(x))
+ attn_out = self.attention(self.norm1(x), mask)
x = x + self.dropout(attn_out)

# Feed-forward
23 changes: 21 additions & 2 deletions models/llm.py
@@ -4,14 +4,16 @@
from typing import Optional
from configs.llm_config import BlueberryConfig
from models.layers import TransformerBlock
from torch.nn.attention.flex_attention import and_masks, create_block_mask


class MinimalLLM(nn.Module):
"""Minimal dense LLM"""

- def __init__(self, config: BlueberryConfig):
+ def __init__(self, config: BlueberryConfig, eos_token: int | None = None):
super().__init__()
self.config = config
self.eos_token = eos_token

# Token embeddings
self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
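
A short usage note: the new eos_token argument is what gen_mask (below) uses to detect document boundaries, so it should be the id that separates packed documents in the training batches. A hedged wiring sketch, where tokenizer and its eos_token_id attribute are assumptions rather than repo code:

# Hypothetical wiring; `tokenizer` is not part of this diff.
model = MinimalLLM(config, eos_token=tokenizer.eos_token_id)
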
@@ -50,14 +52,31 @@ def _init_weights(self, module):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

def gen_mask(self, x):
doc = (x == self.eos_token).cumsum(dim=1)

def document_mask(b, h, q_idx, kv_idx):
return doc[b, q_idx] == doc[b, kv_idx]

def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

return and_masks(causal_mask, document_mask)

def forward(self, x):
batch_size, seq_len = x.size(0), x.size(1)
mask_mod = self.gen_mask(x)
mask = create_block_mask(
mask_mod, batch_size, None, seq_len, seq_len, _compile=True
)

# Token embeddings
x = self.token_embedding(x) * math.sqrt(self.config.d_model)
x = self.position_dropout(x)

# Pass through transformer blocks
for block in self.transformer_blocks:
- x = block(x)
+ x = block(x, mask)

# Output projection
x = self.norm(x)
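
For reference, the document-masking idea above as a standalone sketch (toy sizes; the helper name is illustrative): a cumulative sum over EOS hits assigns a document id to every position, document_mask keeps attention within a single document, and and_masks combines it with the causal constraint before create_block_mask materializes the BlockMask that the transformer blocks consume.

import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

def build_doc_causal_mask(tokens: torch.Tensor, eos_id: int):
    # Document id per position: increments at every EOS token.
    doc = (tokens == eos_id).cumsum(dim=1)

    def document_mask(b, h, q_idx, kv_idx):
        return doc[b, q_idx] == doc[b, kv_idx]

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    B, S = tokens.shape
    return create_block_mask(and_masks(causal_mask, document_mask), B, None, S, S)

# Toy check:
# tokens = torch.randint(1, 100, (2, 256)); tokens[:, 127] = 0   # pretend 0 is EOS
# mask = build_doc_causal_mask(tokens, eos_id=0)
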
77 changes: 77 additions & 0 deletions plots/baseline.json
@@ -0,0 +1,77 @@
{
"final_metrics": {
"val_loss": 4.892065734863281,
"val_accuracy": 0.24040852467024915,
"val_perplexity": 133.22850475359996,
"train_loss": 4.8399858474731445
},
"setup_time_seconds": 12.699106931686401,
"active_training_time_seconds": 396.8685128688812,
"total_wall_time_seconds": 409.5676198005676,
"total_time_minutes": 6.826126996676127,
"actual_steps": 489,
"tokens_seen": 8011776,
"train_tokens": 8000000,
"history": {
"steps": [
0,
50,
100,
150,
200,
300,
400,
489
],
"val_losses": [
10.885406560897827,
6.553639316558838,
6.228558578491211,
5.9009589147567745,
5.614467568397522,
5.255833668708801,
5.036515588760376,
4.892065734863281
],
"val_accuracies": [
0.008136907669760626,
0.14678248656570592,
0.1532230092818759,
0.1738275525158769,
0.1929457743038593,
0.21529311187103078,
0.2301080850024426,
0.24040852467024915
],
"val_perplexities": [
53391.48529861793,
701.7935809523002,
507.02412104808775,
365.3876755628698,
274.3672585000841,
191.68121789674856,
153.93271460078893,
133.22850475359996
],
"elapsed_times": [
0.4335458437601725,
1.1485339283943177,
1.8589276591936748,
2.5722230354944866,
3.285143864154816,
4.4299559712409975,
5.5750153183937075,
6.614473648866018
],
"learning_rates": [
0.024,
0.024,
0.024,
0.024,
0.024,
0.024,
0.024,
0.024
]
}
}
Binary file added plots/baseline.png
77 changes: 77 additions & 0 deletions plots/doc_mask.json
@@ -0,0 +1,77 @@
{
"final_metrics": {
"val_loss": 4.84444709777832,
"val_accuracy": 0.24380373717635564,
"val_perplexity": 127.03302573095146,
"train_loss": 4.802108287811279
},
"setup_time_seconds": 78.49195218086243,
"active_training_time_seconds": 466.7198405265808,
"total_wall_time_seconds": 545.2117927074432,
"total_time_minutes": 9.086863211790721,
"actual_steps": 489,
"tokens_seen": 8011776,
"train_tokens": 8000000,
"history": {
"steps": [
0,
50,
100,
150,
200,
300,
400,
489
],
"val_losses": [
10.871560344696045,
6.515311102867127,
6.170207910537719,
5.811286811828613,
5.548059434890747,
5.207452573776245,
4.995285930633545,
4.84444709777832
],
"val_accuracies": [
0.008136907669760626,
0.14623900830483635,
0.1556185881778212,
0.1805276013678554,
0.1974218368343918,
0.2186675622862726,
0.23282730825598436,
0.24380373717635564
],
"val_perplexities": [
52657.30975598555,
675.4040489781731,
478.28553635564975,
334.04870704765545,
256.7388537185275,
182.62823316781476,
147.71517563947796,
127.03302573095146
],
"elapsed_times": [
1.4798673907915751,
2.206452254454295,
2.927223034699758,
3.6552630066871643,
4.378844745953878,
5.54952247540156,
6.71799882253011,
7.778662610054016
],
"learning_rates": [
0.024,
0.024,
0.024,
0.024,
0.024,
0.024,
0.024,
0.024
]
}
}
Binary file added plots/doc_mask.png
77 changes: 77 additions & 0 deletions plots/flex-attention.json
@@ -0,0 +1,77 @@
{
"final_metrics": {
"val_loss": 4.847935905456543,
"val_accuracy": 0.2435326086956522,
"val_perplexity": 127.47699353532882,
"train_loss": 4.811601161956787
},
"setup_time_seconds": 14.949411869049072,
"active_training_time_seconds": 403.03150844573975,
"total_wall_time_seconds": 417.9809203147888,
"total_time_minutes": 6.966348671913147,
"actual_steps": 489,
"tokens_seen": 8011776,
"train_tokens": 8000000,
"history": {
"steps": [
0,
50,
100,
150,
200,
300,
400,
489
],
"val_losses": [
10.871569671630859,
6.514704399108886,
6.1700096559524535,
5.807818622589111,
5.547788681983948,
5.20478675365448,
4.9944330596923825,
4.847935905456543
],
"val_accuracies": [
0.008136907669760626,
0.14628358573522227,
0.1555410356619443,
0.18063385442110405,
0.19776624328285294,
0.21926538837322912,
0.2330013434294089,
0.2435326086956522
],
"val_perplexities": [
52657.800889571496,
674.994403082745,
478.1907234538577,
332.8921696210005,
256.6693503371409,
182.14202750539155,
147.58924736652187,
127.47699353532882
],
"elapsed_times": [
0.36853692134221394,
1.109649940331777,
1.8349166552225749,
2.562107467651367,
3.2909557660420736,
4.46957155863444,
5.64816772143046,
6.717189677556356
],
"learning_rates": [
0.024,
0.024,
0.024,
0.024,
0.024,
0.024,
0.024,
0.024
]
}
}
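
To read the three result files side by side, a small sketch (the paths and JSON keys are the ones added in this PR; the loop itself is illustrative):

import json

for name in ("baseline", "flex-attention", "doc_mask"):
    with open(f"plots/{name}.json") as f:
        m = json.load(f)["final_metrics"]
    print(f"{name:16s} val_loss={m['val_loss']:.3f}  "
          f"val_ppl={m['val_perplexity']:.1f}  val_acc={m['val_accuracy']:.3f}")

On these runs, doc_mask ends at val_loss 4.844 versus 4.848 for flex-attention alone and 4.892 for the baseline.
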
Binary file added plots/flex-attention.png