diff --git a/models/layers.py b/models/layers.py
index 2b903501..de2a437d 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -3,7 +3,10 @@
 import torch.nn.functional as F
 from torchtune.modules import RotaryPositionalEmbeddings
 from .components import SquaredReLUFeedForward
-
+from torch.nn.attention.flex_attention import (
+    flex_attention,
+    BlockMask,
+)
 
 class Rotary(nn.Module):
     def __init__(self, dim: int, max_seq_len: int):
@@ -34,7 +37,10 @@ def __init__(
         self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
         self.num_key_value_groups = self.n_heads // self.n_kv_heads
         self.d_k = d_model // n_heads
-
+
+        # Sparse attention gate
+        self.atten_gate = nn.Linear(12, n_heads, bias=False)
+        nn.init.zeros_(self.atten_gate.weight)
         # ============ MERGED QKVO PROJECTION ============
         # Instead of 4 separate Linear layers, use single merged projection
         q_size = d_model
@@ -62,7 +68,7 @@ def __init__(
         self.rotary = Rotary(self.d_k, max_seq_len)
         self.dropout = dropout
 
-    def forward(self, x):
+    def forward(self, x, mask: BlockMask):
         batch_size, seq_len = x.size(0), x.size(1)
 
         # ============ MERGED QKV PROJECTION ============
@@ -91,15 +97,19 @@ def forward(self, x):
         Q, K, V = Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
 
         # Compute attention
-        attn_output = F.scaled_dot_product_attention(
-            Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
-        )
+        attn_output = flex_attention(Q, K, V, block_mask=mask)
 
         # Reshape output
         attn_output = attn_output.transpose(1, 2).reshape(
+            batch_size, seq_len, self.n_heads, self.d_k
+        )
+        # Sparse attention
+        attn_output = attn_output * F.sigmoid(
+            self.atten_gate(x[..., : self.atten_gate.in_features])
+        ).view(batch_size, seq_len, self.n_heads, 1)
+        attn_output = attn_output.contiguous().reshape(
             batch_size, seq_len, self.d_model
         )
-
         # ============ MERGED O PROJECTION ============
         # Use the last part of qkvo_proj for output projection
         return F.linear(attn_output, self.qkvo_proj[self.qkv_size:])
@@ -127,9 +137,9 @@ def __init__(
         self.norm2 = nn.RMSNorm(d_model)
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x):
+    def forward(self, x, mask: BlockMask):
         # Self-attention
-        attn_out = self.attention(self.norm1(x))
+        attn_out = self.attention(self.norm1(x), mask)
         x = x + self.dropout(attn_out)
 
         # Feed-forward
diff --git a/models/llm.py b/models/llm.py
index 9e887372..e8210653 100644
--- a/models/llm.py
+++ b/models/llm.py
@@ -4,14 +4,16 @@
 from typing import Optional
 from configs.llm_config import BlueberryConfig
 from models.layers import TransformerBlock
+from torch.nn.attention.flex_attention import and_masks, create_block_mask
 
 
 class MinimalLLM(nn.Module):
     """Minimal dense LLM"""
 
-    def __init__(self, config: BlueberryConfig):
+    def __init__(self, config: BlueberryConfig, eos_token: int | None = None):
         super().__init__()
         self.config = config
+        self.eos_token = eos_token
 
         # Token embeddings
         self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
@@ -50,14 +52,31 @@ def _init_weights(self, module):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
+    def gen_mask(self, x):
+        doc = (x == self.eos_token).cumsum(dim=1)
+
+        def document_mask(b, h, q_idx, kv_idx):
+            return doc[b, q_idx] == doc[b, kv_idx]
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        return and_masks(causal_mask, document_mask)
+
     def forward(self, x):
+        batch_size, seq_len = x.size(0), x.size(1)
+        mask_mod = self.gen_mask(x)
+        mask = create_block_mask(
+            mask_mod, batch_size, None, seq_len, seq_len, _compile=True
+        )
+
         # Token embeddings
         x = self.token_embedding(x) * math.sqrt(self.config.d_model)
         x = self.position_dropout(x)
 
         # Pass through transformer blocks
         for block in self.transformer_blocks:
-            x = block(x)
+            x = block(x, mask)
 
         # Output projection
         x = self.norm(x)
diff --git a/plots/baseline.json b/plots/baseline.json
new file mode 100644
index 00000000..c0b8b72a
--- /dev/null
+++ b/plots/baseline.json
@@ -0,0 +1,77 @@
+{
+  "final_metrics": {
+    "val_loss": 4.892065734863281,
+    "val_accuracy": 0.24040852467024915,
+    "val_perplexity": 133.22850475359996,
+    "train_loss": 4.8399858474731445
+  },
+  "setup_time_seconds": 12.699106931686401,
+  "active_training_time_seconds": 396.8685128688812,
+  "total_wall_time_seconds": 409.5676198005676,
+  "total_time_minutes": 6.826126996676127,
+  "actual_steps": 489,
+  "tokens_seen": 8011776,
+  "train_tokens": 8000000,
+  "history": {
+    "steps": [
+      0,
+      50,
+      100,
+      150,
+      200,
+      300,
+      400,
+      489
+    ],
+    "val_losses": [
+      10.885406560897827,
+      6.553639316558838,
+      6.228558578491211,
+      5.9009589147567745,
+      5.614467568397522,
+      5.255833668708801,
+      5.036515588760376,
+      4.892065734863281
+    ],
+    "val_accuracies": [
+      0.008136907669760626,
+      0.14678248656570592,
+      0.1532230092818759,
+      0.1738275525158769,
+      0.1929457743038593,
+      0.21529311187103078,
+      0.2301080850024426,
+      0.24040852467024915
+    ],
+    "val_perplexities": [
+      53391.48529861793,
+      701.7935809523002,
+      507.02412104808775,
+      365.3876755628698,
+      274.3672585000841,
+      191.68121789674856,
+      153.93271460078893,
+      133.22850475359996
+    ],
+    "elapsed_times": [
+      0.4335458437601725,
+      1.1485339283943177,
+      1.8589276591936748,
+      2.5722230354944866,
+      3.285143864154816,
+      4.4299559712409975,
+      5.5750153183937075,
+      6.614473648866018
+    ],
+    "learning_rates": [
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024
+    ]
+  }
+}
\ No newline at end of file
diff --git a/plots/baseline.png b/plots/baseline.png
new file mode 100644
index 00000000..f58bd6a8
Binary files /dev/null and b/plots/baseline.png differ
diff --git a/plots/doc_mask.json b/plots/doc_mask.json
new file mode 100644
index 00000000..4c7e3639
--- /dev/null
+++ b/plots/doc_mask.json
@@ -0,0 +1,77 @@
+{
+  "final_metrics": {
+    "val_loss": 4.84444709777832,
+    "val_accuracy": 0.24380373717635564,
+    "val_perplexity": 127.03302573095146,
+    "train_loss": 4.802108287811279
+  },
+  "setup_time_seconds": 78.49195218086243,
+  "active_training_time_seconds": 466.7198405265808,
+  "total_wall_time_seconds": 545.2117927074432,
+  "total_time_minutes": 9.086863211790721,
+  "actual_steps": 489,
+  "tokens_seen": 8011776,
+  "train_tokens": 8000000,
+  "history": {
+    "steps": [
+      0,
+      50,
+      100,
+      150,
+      200,
+      300,
+      400,
+      489
+    ],
+    "val_losses": [
+      10.871560344696045,
+      6.515311102867127,
+      6.170207910537719,
+      5.811286811828613,
+      5.548059434890747,
+      5.207452573776245,
+      4.995285930633545,
+      4.84444709777832
+    ],
+    "val_accuracies": [
+      0.008136907669760626,
+      0.14623900830483635,
+      0.1556185881778212,
+      0.1805276013678554,
+      0.1974218368343918,
+      0.2186675622862726,
+      0.23282730825598436,
+      0.24380373717635564
+    ],
+    "val_perplexities": [
+      52657.30975598555,
+      675.4040489781731,
+      478.28553635564975,
+      334.04870704765545,
+      256.7388537185275,
+      182.62823316781476,
+      147.71517563947796,
+      127.03302573095146
+    ],
+    "elapsed_times": [
+      1.4798673907915751,
+      2.206452254454295,
+      2.927223034699758,
+      3.6552630066871643,
+      4.378844745953878,
+      5.54952247540156,
+      6.71799882253011,
+      7.778662610054016
+    ],
+    "learning_rates": [
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024
+    ]
+  }
+}
\ No newline at end of file
diff --git a/plots/doc_mask.png b/plots/doc_mask.png
new file mode 100644
index 00000000..ffb94384
Binary files /dev/null and b/plots/doc_mask.png differ
diff --git a/plots/flex-attention.json b/plots/flex-attention.json
new file mode 100644
index 00000000..26df2f3e
--- /dev/null
+++ b/plots/flex-attention.json
@@ -0,0 +1,77 @@
+{
+  "final_metrics": {
+    "val_loss": 4.847935905456543,
+    "val_accuracy": 0.2435326086956522,
+    "val_perplexity": 127.47699353532882,
+    "train_loss": 4.811601161956787
+  },
+  "setup_time_seconds": 14.949411869049072,
+  "active_training_time_seconds": 403.03150844573975,
+  "total_wall_time_seconds": 417.9809203147888,
+  "total_time_minutes": 6.966348671913147,
+  "actual_steps": 489,
+  "tokens_seen": 8011776,
+  "train_tokens": 8000000,
+  "history": {
+    "steps": [
+      0,
+      50,
+      100,
+      150,
+      200,
+      300,
+      400,
+      489
+    ],
+    "val_losses": [
+      10.871569671630859,
+      6.514704399108886,
+      6.1700096559524535,
+      5.807818622589111,
+      5.547788681983948,
+      5.20478675365448,
+      4.9944330596923825,
+      4.847935905456543
+    ],
+    "val_accuracies": [
+      0.008136907669760626,
+      0.14628358573522227,
+      0.1555410356619443,
+      0.18063385442110405,
+      0.19776624328285294,
+      0.21926538837322912,
+      0.2330013434294089,
+      0.2435326086956522
+    ],
+    "val_perplexities": [
+      52657.800889571496,
+      674.994403082745,
+      478.1907234538577,
+      332.8921696210005,
+      256.6693503371409,
+      182.14202750539155,
+      147.58924736652187,
+      127.47699353532882
+    ],
+    "elapsed_times": [
+      0.36853692134221394,
+      1.109649940331777,
+      1.8349166552225749,
+      2.562107467651367,
+      3.2909557660420736,
+      4.46957155863444,
+      5.64816772143046,
+      6.717189677556356
+    ],
+    "learning_rates": [
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024
+    ]
+  }
+}
\ No newline at end of file
diff --git a/plots/flex-attention.png b/plots/flex-attention.png
new file mode 100644
index 00000000..d23f4c33
Binary files /dev/null and b/plots/flex-attention.png differ
diff --git a/plots/sparse-attention.json b/plots/sparse-attention.json
new file mode 100644
index 00000000..91aaeb15
--- /dev/null
+++ b/plots/sparse-attention.json
@@ -0,0 +1,77 @@
+{
+  "final_metrics": {
+    "val_loss": 4.851117086410523,
+    "val_accuracy": 0.24334147532975087,
+    "val_perplexity": 127.88316663175688,
+    "train_loss": 4.811080455780029
+  },
+  "setup_time_seconds": 13.218045234680176,
+  "active_training_time_seconds": 399.6598746776581,
+  "total_wall_time_seconds": 412.87791991233826,
+  "total_time_minutes": 6.881298665205637,
+  "actual_steps": 489,
+  "tokens_seen": 8011776,
+  "train_tokens": 8000000,
+  "history": {
+    "steps": [
+      0,
+      50,
+      100,
+      150,
+      200,
+      300,
+      400,
+      489
+    ],
+    "val_losses": [
+      10.871553955078125,
+      6.513376259803772,
+      6.171892399787903,
+      5.809717321395874,
+      5.546902008056641,
+      5.207251658439636,
+      4.99858512878418,
+      4.851117086410523
+    ],
+    "val_accuracies": [
+      0.008136907669760626,
+      0.1463873961895457,
+      0.1553767708842208,
+      0.18037799218368344,
+      0.19762884709330727,
+      0.21875488519785052,
+      0.2327076209086468,
+      0.24334147532975087
+    ],
+    "val_perplexities": [
+      52656.97329697048,
+      674.0985115513492,
+      479.0918821499028,
+      333.5248320138406,
+      256.4418691819663,
+      182.59154404070117,
+      148.20332207690265,
+      127.88316663175688
+    ],
+    "elapsed_times": [
+      0.456688125928243,
+      1.1772284309069316,
+      1.8894020279248556,
+      2.605119510491689,
+      3.3206779837608336,
+      4.469060146808625,
+      5.617923351128896,
+      6.660996170838674
+    ],
+    "learning_rates": [
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024,
+      0.024
+    ]
+  }
+}
\ No newline at end of file
diff --git a/plots/sparse-attention.png b/plots/sparse-attention.png
new file mode 100644
index 00000000..c6736536
Binary files /dev/null and b/plots/sparse-attention.png differ
diff --git a/train_llm.py b/train_llm.py
index 323bf623..6afd3b00 100644
--- a/train_llm.py
+++ b/train_llm.py
@@ -357,6 +357,7 @@ def main():
         val_loader,
         output_dir=output_dir,
         load_weights_path=args.load_checkpoint,
+        eos_token=tokenizer.eos_token_id,
     )
diff --git a/training/trainer.py b/training/trainer.py
index 159d023d..3d6d2651 100644
--- a/training/trainer.py
+++ b/training/trainer.py
@@ -432,6 +432,7 @@ def train_minimal_llm(
     output_dir: Optional[str] = None,
     load_weights_path: Optional[str] = None,
     compare_baseline: bool = False,
+    eos_token: int | None = None,
 ):
     print(f"\nšŸš€ Training dense model")
     setup_start = time.time()
@@ -441,7 +442,7 @@ def train_minimal_llm(
     # 1. Initialize model with fixed seed
     # ============================================
     set_seed(42)
-    model = MinimalLLM(config)
+    model = MinimalLLM(config, eos_token)
     model = model.to(device)
 
     # Load pretrained weights if specified
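
Usage sketch (illustrative, not part of the diff above): a minimal standalone version of the document-masked causal attention pattern these changes introduce, assuming a PyTorch build (>= 2.5) where torch.nn.attention.flex_attention is available on the target device. The eos_id value, tensor shapes, and random inputs below are made up for the example.

import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
)

eos_id = 2                               # hypothetical EOS token id
B, H, S, D = 2, 4, 128, 64               # illustrative batch / heads / seq len / head dim
device = "cuda" if torch.cuda.is_available() else "cpu"

tokens = torch.randint(3, 100, (B, S), device=device)
tokens[:, 63] = eos_id                   # pretend each row packs two documents

# Document id per position: increments at every EOS token (same idea as gen_mask above)
doc = (tokens == eos_id).cumsum(dim=1)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def same_doc(b, h, q_idx, kv_idx):
    return doc[b, q_idx] == doc[b, kv_idx]

# Combine both predicates into one block-sparse mask
block_mask = create_block_mask(
    and_masks(causal, same_doc), B, None, S, S, device=device
)

q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # (B, H, S, D)

Combining the causal and same-document predicates with and_masks keeps the usual autoregressive structure while preventing tokens from attending across packed document boundaries, which is the effect the doc_mask run in plots/doc_mask.json measures against plots/baseline.json.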