From 90174b96fae7e3e17d4703c1d73159c7f3b27478 Mon Sep 17 00:00:00 2001 From: tomiock Date: Tue, 7 Oct 2025 13:38:18 +0000 Subject: [PATCH 01/18] [WIP] state_dict loaded but not generation --- torchtitan/generate.py | 1 + torchtitan/models/llama3/model/args.py | 2 +- torchtitan/train.py | 2 +- torchtitan/vlr/smolvlm/__init__.py | 5 +- .../vlr/smolvlm/datasets/mm_datasets.py | 3 +- torchtitan/vlr/smolvlm/model/args.py | 2 +- torchtitan/vlr/smolvlm/model/model.py | 9 ++- torchtitan/vlr/smolvlm/model/siglip2.py | 4 +- .../vlr/smolvlm/model/state_dict_adapter.py | 62 +++++++++++++++---- .../train_configs/llama_siglip_256.toml | 15 ++--- 10 files changed, 75 insertions(+), 30 deletions(-) diff --git a/torchtitan/generate.py b/torchtitan/generate.py index 475e9467..f2e7f2a3 100644 --- a/torchtitan/generate.py +++ b/torchtitan/generate.py @@ -262,6 +262,7 @@ def generate( # Decode output generated_ids = output_ids[0, input_ids.shape[1]:] + print(generated_ids.v) generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) return generated_text diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 3273605d..f71454e5 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -29,7 +29,7 @@ class TransformerModelArgs(BaseModelArgs): norm_eps: float = 1e-5 rope_theta: float = 10000 - ffn_dim: int = 8192 + ffn_dim: int = 1536 max_seq_len: int = 131072 # If `True`, then each transformer block init uses its layer ID, and if diff --git a/torchtitan/train.py b/torchtitan/train.py index e909964e..57e775ff 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -551,8 +551,8 @@ def train_step( def train(self): job_config = self.job_config - self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}") + self.checkpointer.load(step=job_config.checkpoint.load_step) leaf_folder = ( "" diff --git a/torchtitan/vlr/smolvlm/__init__.py b/torchtitan/vlr/smolvlm/__init__.py index a4147920..a0714c48 100644 --- a/torchtitan/vlr/smolvlm/__init__.py +++ b/torchtitan/vlr/smolvlm/__init__.py @@ -16,6 +16,7 @@ # from .infra.pipeline import pipeline_llama from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs from .model.model import Llama3Siglip2Transformer +from .model.state_dict_adapter import SmolVLMStateDictAdapter __all__ = [ "parallelize_vlm", @@ -35,7 +36,7 @@ ), "256M": Siglip2ModelArgs( dim=768, - ffn_dim=2304, + ffn_dim=3072, n_layers=12, n_heads=12, ) @@ -87,6 +88,6 @@ build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, build_validator_fn=build_validator, - # state_dict_adapter=Llama3StateDictAdapter, + state_dict_adapter=SmolVLMStateDictAdapter, ) ) diff --git a/torchtitan/vlr/smolvlm/datasets/mm_datasets.py b/torchtitan/vlr/smolvlm/datasets/mm_datasets.py index 952d5005..840de827 100644 --- a/torchtitan/vlr/smolvlm/datasets/mm_datasets.py +++ b/torchtitan/vlr/smolvlm/datasets/mm_datasets.py @@ -440,5 +440,6 @@ def build_mm_dataloader( ) for sample in dataset: - print(sample) + #print(sample) + print(sample['input_ids'].v) exit() diff --git a/torchtitan/vlr/smolvlm/model/args.py b/torchtitan/vlr/smolvlm/model/args.py index 4386619c..8357ad29 100644 --- a/torchtitan/vlr/smolvlm/model/args.py +++ b/torchtitan/vlr/smolvlm/model/args.py @@ -22,7 +22,7 @@ class Siglip2ModelArgs: patch_size: int = 16 image_size: int = 512 - scale_factor: int = 2 + scale_factor: int = 4 layer_norm_eps: float = 1e-6 use_flex_attn: bool = True diff --git a/torchtitan/vlr/smolvlm/model/model.py b/torchtitan/vlr/smolvlm/model/model.py index 042b2f0a..08ca6776 100644 --- a/torchtitan/vlr/smolvlm/model/model.py +++ b/torchtitan/vlr/smolvlm/model/model.py @@ -14,13 +14,16 @@ from .args import Llama3Siglip2ModelArgs, Siglip2ModelArgs from .siglip2 import VisionTransformer +import lovely_tensors as lt +lt.monkey_patch() + class SmolVLMSimpleMLP(nn.Module): def __init__(self, config): super().__init__() # TODO: scale_factor to config - input_size = config.encoder.dim * (config.encoder.scale_factor**2) + input_size = 12288 output_size = config.dim - self.proj = nn.Linear(input_size, output_size, bias=False) + self.proj = nn.Linear(12288, 576, bias=False) def init_weights(self): nn.init.trunc_normal_(self.proj.weight, mean=0.0, std=0.02) @@ -35,7 +38,7 @@ def __init__(self, config): self.scale_factor = config.encoder.scale_factor self.modality_projection = SmolVLMSimpleMLP(config) - def pixel_shuffle(self, x, scale_factor=2): + def pixel_shuffle(self, x, scale_factor=4): bsz, seq, embed_dim = x.size() height = width = int(seq**0.5) x = x.view(bsz, height, width, embed_dim) diff --git a/torchtitan/vlr/smolvlm/model/siglip2.py b/torchtitan/vlr/smolvlm/model/siglip2.py index e94ec5f0..f66fe290 100644 --- a/torchtitan/vlr/smolvlm/model/siglip2.py +++ b/torchtitan/vlr/smolvlm/model/siglip2.py @@ -137,8 +137,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class FeedForward(nn.Module): def __init__(self, args: Siglip2ModelArgs): super().__init__() - self.fc1 = nn.Linear(args.dim, args.ffn_dim) - self.fc2 = nn.Linear(args.ffn_dim, args.dim) + self.fc1 = nn.Linear(args.dim, args.ffn_dim, bias=True) + self.fc2 = nn.Linear(args.ffn_dim, args.dim, bias=True) self.act_fn = PytorchGELUTanh() def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/torchtitan/vlr/smolvlm/model/state_dict_adapter.py b/torchtitan/vlr/smolvlm/model/state_dict_adapter.py index c994a869..fa353c94 100644 --- a/torchtitan/vlr/smolvlm/model/state_dict_adapter.py +++ b/torchtitan/vlr/smolvlm/model/state_dict_adapter.py @@ -26,19 +26,57 @@ def __init__( self.model_args = model_args self.hf_assets_path = hf_assets_path self.from_hf_map = { - "model.embed_tokens.weight": "tok_embeddings.weight", - "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", - "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", - "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", - "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - "model.layers.{}.self_attn.rotary_emb.inv_freq": None, - "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", - "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", - "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", - "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", - "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", - "model.norm.weight": "norm.weight", "lm_head.weight": "output.weight", + + "model.text_model.embed_tokens.weight": "tok_embeddings.weight", # check + + "model.text_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", # check + "model.text_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", # check + "model.text_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", # check + "model.text_model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", # check + + #"model.layers.{}.self_attn.rotary_emb.inv_freq": None, + + "model.text_model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.gate_proj.weight", # check + "model.text_model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.up_proj.weight", # check + "model.text_model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.down_proj.weight", # check + + "model.text_model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", # check + "model.text_model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", # check + + "model.text_model.norm.weight": "norm.weight", # check + + "model.vision_model.embeddings.patch_embedding.weight": "encoder.embeddings.patch_embedding.weight", + "model.vision_model.embeddings.patch_embedding.bias": "encoder.embeddings.patch_embedding.bias", + + "model.vision_model.embeddings.position_embedding.weight": "encoder.embeddings.position_embedding.weight", + + "model.vision_model.post_layernorm.weight": "encoder.post_layernorm.weight", + "model.vision_model.post_layernorm.bias": "encoder.post_layernorm.bias", + + "model.vision_model.encoder.layers.{}.layer_norm1.weight": "encoder.layers.{}.layer_norm1.weight", + "model.vision_model.encoder.layers.{}.layer_norm1.bias": "encoder.layers.{}.layer_norm1.bias", + "model.vision_model.encoder.layers.{}.layer_norm2.weight": "encoder.layers.{}.layer_norm2.weight", + "model.vision_model.encoder.layers.{}.layer_norm2.bias": "encoder.layers.{}.layer_norm2.bias", + + "model.vision_model.encoder.layers.{}.mlp.fc1.weight": "encoder.layers.{}.mlp.fc1.weight", + "model.vision_model.encoder.layers.{}.mlp.fc1.bias": "encoder.layers.{}.mlp.fc1.bias", + "model.vision_model.encoder.layers.{}.mlp.fc2.weight": "encoder.layers.{}.mlp.fc2.weight", + "model.vision_model.encoder.layers.{}.mlp.fc2.bias": "encoder.layers.{}.mlp.fc2.bias", + + "model.vision_model.encoder.layers.{}.self_attn.k_proj.weight": "encoder.layers.{}.self_attn.k_proj.weight", + "model.vision_model.encoder.layers.{}.self_attn.k_proj.bias": "encoder.layers.{}.self_attn.k_proj.bias", + + "model.vision_model.encoder.layers.{}.self_attn.out_proj.weight": "encoder.layers.{}.self_attn.out_proj.weight", + "model.vision_model.encoder.layers.{}.self_attn.out_proj.bias": "encoder.layers.{}.self_attn.out_proj.bias", + + "model.vision_model.encoder.layers.{}.self_attn.q_proj.weight": "encoder.layers.{}.self_attn.q_proj.weight", + "model.vision_model.encoder.layers.{}.self_attn.q_proj.bias": "encoder.layers.{}.self_attn.q_proj.bias", + + "model.vision_model.encoder.layers.{}.self_attn.v_proj.weight": "encoder.layers.{}.self_attn.v_proj.weight", + "model.vision_model.encoder.layers.{}.self_attn.v_proj.bias": "encoder.layers.{}.self_attn.v_proj.bias", + + "model.connector.modality_projection.proj.weight": "projector.modality_projection.proj.weight", } # HuggingFace permutation function (exact copy from their conversion script) diff --git a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml index b3afdd29..1630b1f0 100644 --- a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml +++ b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml @@ -4,7 +4,7 @@ custom_args_module = "torchtitan.vlr.smolvlm.assets.job_config" [job] -dump_folder = "./outputs_large" +dump_folder = "./outputs" description = "Llama 3 Siglip2 VLM training" print_args = false @@ -27,7 +27,7 @@ name = "llama3-siglip2" flavor = "256M" # test folder with tokenizer.json, for debug purpose only # hf_assets_path = "torchtitan/experiments/vlm/assets/tokenizer" -hf_assets_path = "./assets/hf/SmolLM2-360M-Instruct" +hf_assets_path = "./assets/hf/SmolVLM2-256M-Video-Instruct" # converters = ["float8"] [optimizer] @@ -42,8 +42,8 @@ decay_type = "cosine" min_lr_factor = 0.0 [training] -local_batch_size = 9 -seq_len = 2048 +local_batch_size = 2 +seq_len = 1048 # packing_buffer_size = 100 max_norm = 1.0 # grad norm clipping steps = 13100 @@ -72,10 +72,11 @@ enable = true folder = "checkpoint" interval = 50 last_save_model_only = false -#initial_load_in_hf = true -#last_save_in_hf = true -export_dtype = "bfloat16" +initial_load_in_hf = true +last_save_in_hf = true +export_dtype = "float32" async_mode = "async" # ["disabled", "async", "async_with_pinned_mem"] +exclude_from_loading = ["dataloader", "optimizer", "train_state"] [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] From a6f0cfc39a37245e43591cfc78cc7ef5911dff70 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Wed, 8 Oct 2025 10:19:37 -0700 Subject: [PATCH 02/18] [DSV3] Offload dequantization process to DCP QuantizedHFReader (#1804) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Benchmarking
Step | time | log -- | -- | -- to_hf() | 0.1103s | [trainer0\|0]:[titan] 2025-10-03 17:07:45,697 - root - INFO - Completed to_hf conversion, generated 189 keys, duration: 0.1103s Split local GroupedExperts DTensor to individual experts’ weight | 0.008 s per layer per matrix (total 58 MoE Layers * 3 weight matrices per layer) | [trainer0\|0]:[titan] 2025-10-03 17:07:45,697 - root - INFO - Completed _get_local_experts_weights for layer 6, abstract_key: model.layers.{}.mlp.experts.{}.up_proj.weight, duration: 0.0082s dcp.load()Threads count=4 | 193.20s | [trainer0\|0]:[titan] 2025-10-03 17:10:58,899 - root - INFO - dcp.load with HuggingFaceStorageReader completed in 193.20 seconds from_hf() | 0.48s | [trainer0\|0]:[titan] 2025-10-03 17:10:59,378 - root - INFO - Completed from_hf conversion, processed 189 keys, duration: 0.4787s Concatenate individual experts weight into GroupedExperts weight | 0.01s per layer per matrix (total 58 MoE Layers * 3 weight matrices) | [trainer0\|0]:[titan] 2025-10-03 17:10:59,120 - root - INFO - Completed _concatenate_expert_weights_dtensor for layer 5, abstract_key: layers.{}.moe.experts.w2, duration: 0.0142s Total | 193.87s | [trainer0\|0]:[titan] 2025-10-03 17:10:59,458 - root - INFO - Finished loading the checkpoint in 193.87 seconds.
## End-to-End verification for 671B model Parallelsim: FSDP=32, PP=8, 1F1B, EP=32 Screenshot 2025-10-06 at 8 32 37 PM Screenshot 2025-10-06 at 8 32 54 PM --- torchtitan/components/checkpoint.py | 24 ++++-- torchtitan/config/job_config.py | 8 ++ torchtitan/models/deepseek_v3/__init__.py | 2 +- .../models/deepseek_v3/model/quantization.py | 73 ---------------- .../deepseek_v3/model/state_dict_adapter.py | 86 ++++++------------- .../train_configs/deepseek_v3_671b.toml | 2 +- torchtitan/protocols/state_dict_adapter.py | 29 ++++++- 7 files changed, 82 insertions(+), 142 deletions(-) delete mode 100644 torchtitan/models/deepseek_v3/model/quantization.py diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 1b25aa3f..e2e643db 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -19,10 +19,7 @@ import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn -from torch.distributed.checkpoint import ( - HuggingFaceStorageReader, - HuggingFaceStorageWriter, -) +from torch.distributed.checkpoint import HuggingFaceStorageWriter from torch.distributed.checkpoint._consolidate_hf_safetensors import ( consolidate_safetensors_files_on_every_rank, ) @@ -249,6 +246,9 @@ def load_state_dict(state_dict): self.initial_load_model_only = checkpoint_config.initial_load_model_only self.initial_load_in_hf = checkpoint_config.initial_load_in_hf self.initial_load_path = checkpoint_config.initial_load_path + self.initial_load_in_hf_quantized = ( + checkpoint_config.initial_load_in_hf_quantized + ) self.last_save_model_only = checkpoint_config.last_save_model_only self.last_save_in_hf = checkpoint_config.last_save_in_hf if self.last_save_in_hf: @@ -418,6 +418,7 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, + from_quantized: bool, ) -> None: """Load the checkpoint with dcp. Args: @@ -432,10 +433,13 @@ def dcp_load( self.sd_adapter is not None ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." hf_state_dict = self.sd_adapter.to_hf(state_dict) + hf_storage_reader = self.sd_adapter.get_hf_storage_reader( + checkpoint_id, from_quantized + ) dcp.load( hf_state_dict, - storage_reader=HuggingFaceStorageReader(path=checkpoint_id), + storage_reader=hf_storage_reader, ) state_dict = self.sd_adapter.from_hf(hf_state_dict) @@ -544,13 +548,21 @@ def load(self, step: int = -1) -> bool: model_only = False from_hf = False + from_quantized = False if not os.path.exists(self.folder): model_only = self.initial_load_model_only from_hf = self.initial_load_in_hf + from_quantized = self.initial_load_in_hf_quantized if from_hf: assert ( model_only ), "Only model can be loaded when loading from HF's safetensors checkpoint." + + if from_quantized: + assert ( + from_hf + ), "Quantized checkpoint can only be loaded from HuggingFace format." + if self.initial_load_path: checkpoint_id = self.initial_load_path if not os.path.isdir(checkpoint_id): @@ -602,6 +614,7 @@ def load(self, step: int = -1) -> bool: states, checkpoint_id=checkpoint_id, from_hf=from_hf, + from_quantized=from_quantized, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( @@ -679,6 +692,7 @@ def _ft_load(self) -> None: checkpoint_id=checkpoint_id, # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader. from_hf=False, + from_quantized=False, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index eb477941..2f200ef1 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -453,6 +453,14 @@ class Checkpoint: non-tensors. The default value is False. """ + initial_load_in_hf_quantized: bool = False + """ + Enable loading of HuggingFace's safetensors format with quantized state dict keys. The option + is only used when `initial_load_path` and `initial_load_path_in_hf` is specified. This will load + checkpoints in HF's model definition and dequantize on model weights if necessary. To support + this parameter, the model need to define proper HuggingFaceStorageReader to perform dequantize. + """ + last_save_model_only: bool = True """ When last_save_model_only=True, only the model will be saved at the end of training, diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index a290ea7e..5125a790 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -134,7 +134,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=61, + n_layers=4, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( diff --git a/torchtitan/models/deepseek_v3/model/quantization.py b/torchtitan/models/deepseek_v3/model/quantization.py deleted file mode 100644 index a8ac6003..00000000 --- a/torchtitan/models/deepseek_v3/model/quantization.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torchtitan.tools.logging import logger - -# Fixed block size of 128x128 as specified in the algorithm -BLOCK_SIZE = 128 - - -def calculate_scale_shape( - weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE -) -> torch.Size: - # Calculate the scale tensor shape - orig_shape = weight.shape - - # Calculate number of blocks needed - block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE - block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = torch.Size((block_rows, block_cols)) - - return expected_scale_shape - - -def dequantize_from_fp8( - weight: torch.Tensor, - scale_inv: torch.Tensor, - dtype=torch.bfloat16, - BLOCK_SIZE: int = BLOCK_SIZE, -) -> torch.Tensor: - # Convert to float32 for computation - float_weight = weight.to(torch.float32) - # Get original dimensions - orig_shape = weight.shape - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE) - block_rows, block_cols = expected_scale_shape - if scale_inv.shape != expected_scale_shape: - logger.warning( - f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}" - ) - - # NOTE: When processing large models on-the-fly, misalignment between block boundaries - # and DTensor local shape partitioning can lead to silent numerical inaccuracies. - dequantized = float_weight.detach().clone().to(dtype=dtype) - - # Apply scaling factors to each block - for i in range(block_rows): - row_start = i * BLOCK_SIZE - row_end = min(row_start + BLOCK_SIZE, orig_shape[0]) - - for j in range(block_cols): - col_start = j * BLOCK_SIZE - col_end = min(col_start + BLOCK_SIZE, orig_shape[1]) - - # Get the block - block = float_weight[row_start:row_end, col_start:col_end] - - scale = scale_inv[i, j] - block = block * scale - - # Explicitly convert block to dtype - block_converted = block.to(dtype=torch.float32) - # Store the dequantized block - dequantized[row_start:row_end, col_start:col_end] = block_converted - - return dequantized diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index b366910f..11d54ffb 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -8,13 +8,14 @@ import re from typing import Any +import torch +from torch.distributed.checkpoint import HuggingFaceStorageReader + from torch.distributed.tensor import DTensor from torchtitan.models.utils import MoEStateDictAdapter from .args import DeepSeekV3ModelArgs -from .quantization import calculate_scale_shape, dequantize_from_fp8 - class DeepSeekV3StateDictAdapter(MoEStateDictAdapter): """ @@ -70,60 +71,33 @@ def __init__( } ) - def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: """ - Dequantize the weights from float8 to float32. + Override default get_hf_storage_reader function to return QuantizedHFStorageReader. """ + if from_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) - scale_inv_keys = [] - for key, weight in state_dict.items(): - if key.endswith(".weight") and key + "_scale_inv" in state_dict: - scale_inv = state_dict[key + "_scale_inv"] - dequantized_weight = dequantize_from_fp8( - weight, scale_inv, dtype=torch.float32 - ) - # update the weight and remove the scale_inv tensor - state_dict[key] = dequantized_weight - scale_inv_keys.append(key + "_scale_inv") - - for key in scale_inv_keys: - state_dict.pop(key) - - return state_dict - - def _add_quantization_scale_inv_tensors( - self, state_dict: dict[str, Any] - ) -> dict[str, Any]: - """ - Add quantization scale tensors the state_dict. - """ - non_quantized_keys = [ - "input_layernorm.weight", - "post_attention_layernorm.weight", - "norm.weight", - "lm_head.weight", - "embed_tokens.weight", - "mlp.gate.weight", - ] - - weight_scale_inv_state_dict = {} - for key, value in state_dict.items(): - if key.endswith(".weight") and not any( - non_quantized_key in key for non_quantized_key in non_quantized_keys - ): - expected_scale_shape = calculate_scale_shape(value) - # add weight_scale_inv to the state_dict - weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( - expected_scale_shape, dtype=torch.float32 - ) - - state_dict.update(weight_scale_inv_state_dict) - return state_dict + # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. + # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + BLOCK_SIZE = 128 + return QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. Convert between the HF shape and the torchtitan shape. - 2. Split the GroupedExperts' weight into separate expert's wegiht. + 2. Split the GroupedExperts' weight into separate expert's weight. """ to_hf_map = {v: k for k, v in self.from_hf_map.items()} @@ -172,24 +146,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: new_key = to_hf_map[key] hf_state_dict[new_key] = value - # Prepare for dequantization - hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( - hf_state_dict - ) - return hf_state_dict_with_scale_inv + return hf_state_dict def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. When loading from HF checkpoint, dequantize the weights from float8 to float32. 2. Convert between the HF shape and the torchtitan shape. - 3. Concate separate expert's wegiht into GroupedExperts' weight. + 3. Concat separate expert's weight into GroupedExperts' weight. """ - # dequantize the tensor in state_dict and remove the scale_inv tensor - - hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} - expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} for key, value in hf_state_dict.items(): @@ -215,7 +181,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: layer_num, value.device_mesh, ) - else: # keep this path to be compatibile with offline conversion + else: # keep this path to be compatible with offline conversion stacked_value = self._concatenate_expert_weights( expert_weights_by_layer, titan_abstract_key, diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index c6dee817..9d8625a2 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable=true +enable = true components = ["loss"] # ["model", "loss"] [quantize.linear.float8] diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 5b441e9b..e22692bd 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -5,13 +5,14 @@ # LICENSE file in the root directory of this source tree. import json -import logging import os import re from abc import ABC, abstractmethod from typing import Any -logger = logging.getLogger() +from torch.distributed.checkpoint import HuggingFaceStorageReader + +from torchtitan.tools.logging import logger from .model import BaseModelArgs @@ -58,6 +59,21 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ pass + @abstractmethod + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + """Returns hf storage reader to read HF checkpoint + + Args: + path: the path to read HF checkpoint + + Returns: + The HuggingFace storage reader to read from HF checkpoint + + """ + pass + class StateDictAdapter(BaseStateDictAdapter): """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping""" @@ -86,3 +102,12 @@ def __init__( self.fqn_to_index_mapping[hf_key] = int(indx) else: self.fqn_to_index_mapping = None + + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + if from_quantized: + logger.warning( + "Loading from quantized checkpoint format is not supported for this model." + ) + return HuggingFaceStorageReader(path) From 41eff5325b74a4a97c3849f2882ba823c2dbdd32 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 8 Oct 2025 10:45:40 -0700 Subject: [PATCH 03/18] Disable FlexAttention max-autotune when deterministic is used (#1808) With max-autotune, FlexAttention is not deterministic even if torch.use_deterministic_algorithms is True. When deterministic mode is set, we should also remove the usage of `max-autotune`. --- torchtitan/distributed/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index a2f1feb3..7bec252b 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -106,6 +106,14 @@ def set_determinism( # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + # Ensure flex_attention is compiled without max-autotune. This is needed to ensure + # reproducibility, since the autotune results may not be deterministic. + from torch.nn.attention.flex_attention import flex_attention + + from torchtitan.models.attention import FlexAttention + + FlexAttention.flex_attn = torch.compile(flex_attention) + if not world_mesh: if seed is not None: torch.manual_seed(seed) From 21739fdcfed53ba8a8ec2beaad4ffd69ecdb55c8 Mon Sep 17 00:00:00 2001 From: Tushar Jain <8455015+tushar00jain@users.noreply.github.com> Date: Wed, 8 Oct 2025 15:11:07 -0400 Subject: [PATCH 04/18] enhance profiler config (#1809) Summary: allow users to specify the profiler schedule --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1809). * #1811 * #1810 * #1812 * __->__ #1809 Co-authored-by: Tushar Jain --- torchtitan/config/job_config.py | 14 ++++++++++++++ torchtitan/tools/profiling.py | 10 +++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 2f200ef1..b7cfcb11 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -34,6 +34,20 @@ class Profiling: profile_freq: int = 10 """How often to collect profile traces, in iterations""" + profiler_active: int = 1 + """ + The steps profiler is active for. + + This is used to configure torch.profile.schedule. + """ + + profiler_warmup: int = 3 + """ + The number of warmup steps before the active step in each profiling cycle. + + This is used to configure torch.profile.schedule. + """ + enable_memory_snapshot: bool = False """Whether to dump memory snapshot""" diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 0e851d33..f398dba9 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -14,9 +14,6 @@ from torchtitan.config import Profiling as ProfilingConfig from torchtitan.tools.logging import logger -# the number of warmup steps before the active step in each profiling cycle -WARMUP = 3 - # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 @@ -34,7 +31,11 @@ def maybe_enable_profiling( if enable_profiling: trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder) - profile_freq = profiling_config.profile_freq + profile_freq, warmup, active = ( + profiling_config.profile_freq, + profiling_config.profiler_warmup, + profiling_config.profiler_active, + ) rank = torch.distributed.get_rank() @@ -58,7 +59,6 @@ def trace_handler(prof): if not os.path.exists(trace_dir): os.makedirs(trace_dir, exist_ok=True) - warmup, active = WARMUP, 1 wait = profile_freq - (active + warmup) assert ( wait >= 0 From f014f31c250817709210a72049527f34bc366893 Mon Sep 17 00:00:00 2001 From: Ruisi Zhang Date: Wed, 8 Oct 2025 18:10:13 -0700 Subject: [PATCH 05/18] [simplefsdp] fix simplefsdp gradient_divide_factor (#1793) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit this PR is a followup of SimpleFSDP+EP [PR](https://github.com/pytorch/torchtitan/pull/1529). Here, we add a `gradient_divide_factor` following FSDP2 to ensure modules wrapped by (FSDP+EP) has the correct gradient reduction value. - The original FSDP2 implementation is in this [PR](https://github.com/pytorch/torchtitan/pull/1551). - The `gradient_divide_factor` logic is [here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688) We have two ways of handling `gradient_divide_factor` in `reduce_scatter`: 1. The first one is to use `ReduceOp.PREMUL_SUM` to handle the `gradient_divide_factor`. However, DTensor's `_reduce_shard_value` only accepts `reduce_op` as a str input ([here](https://github.com/pytorch/pytorch/blob/8f705d019a64b1ca882e043b3eb98559273a9e59/torch/distributed/tensor/placement_types.py#L177-L210)). To make` _reduce_shard_value` work correctly with ReduceOp.PREMUL_SUM, we need to update the DTensor `_reduce_shard_tensor` and `torch.distributed._functional_collectives.reduce_scatter_tensor` so that it can pass the factor associated with ReduceOp.PREMUL_SUM as an input. 2. Another way is to simulate `ReduceOp.PREMUL_SUM` with `ReduceOp.SUM`. The logic is in this [Diff](https://www.internalfb.com/diff/D76546536). It does a `div_` over gradient before performing `ReduceOp.SUM`. Currently I'm following 2 since it is requires less change to `_functional_collectives`. After enabling `reduction_divide_factor`, we will see FSDP(=2) + EP (=4) have identical loss: Screenshot 2025-10-08 at 5 27 24 PM --- .../simple_fsdp/deepseek_v3/parallelize.py | 12 ++++-- .../experiments/simple_fsdp/simple_fsdp.py | 43 ++++++++++++++++++- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index e76370e5..df916054 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -125,6 +125,13 @@ def parallelize_deepseekv3( ): experts_shard_dim = 1 + # when EP is enable, the routed experts' gradient reduction is done over + # dp_mod_ep_mesh instead of whole dp_mesh. + # we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh + # to be consistent with data. + # TODO (ruisizhang123): update the logic following the link below instead + # of using a reduction_divide_factor + # https://github.com/pytorch/torchtitan/pull/1803#discussion_r2415190883 transformer_block.moe.experts = data_parallel( transformer_block.moe.experts, dp_mod_ep_mesh, @@ -132,11 +139,8 @@ def parallelize_deepseekv3( ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, shard_dim=experts_shard_dim, + reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) - # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp - # transformer_block.moe.experts.set_gradient_divide_factor( - # parallel_dims.fsdp_gradient_divide_factor, - # ) model = data_parallel( model, diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 8cb2a447..9ca74601 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -49,6 +49,37 @@ class MixedPrecisionPolicy: reduce_dtype: Optional[torch.dtype] = None +class _ScaledPartial(Partial): + # A subclass of Partial placement that allows user to perform reduction with a custom + # factor (reduction_divide_factor) other than the default world size. + def __init__( + self, + reduction_divide_factor: float, + ): + self.reduction_divide_factor = reduction_divide_factor + super().__init__(reduce_op="sum") + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # for all_reduce in DDP + tensor.div_(self.reduction_divide_factor) + reduced = super()._reduce_value(tensor, mesh, mesh_dim) + return reduced + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # for reduce_scatter in FSDP + tensor.div_(self.reduction_divide_factor) + reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec) + return reduced + + def _distribute_dtensor( tensor: DTensor, device_mesh: DeviceMesh, @@ -192,18 +223,24 @@ def __init__( mode, regional_ac, mp_policy, + reduction_divide_factor, ): super().__init__() self.device_mesh = device_mesh self.param_sharding = param_sharding self.mode = mode self.compute_placements = [Replicate()] * self.device_mesh.ndim - self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim + self.grad_placements = [ + _ScaledPartial( + reduction_divide_factor=reduction_divide_factor, + ) + if reduction_divide_factor is not None + else Partial(reduce_op="avg") + ] * self.device_mesh.ndim self.regional_ac = regional_ac mp_policy = mp_policy or MixedPrecisionPolicy() self.param_dtype = mp_policy.param_dtype self.reduce_dtype = mp_policy.reduce_dtype - self.ep_mesh_name, self.tp_mesh_name = "ep", "tp" def replicate_compute(self, x): # data parallel runtime replicate parameters and do local compute @@ -286,6 +323,7 @@ def data_parallel( ac_mode: str = "none", mp_policy: Optional[MixedPrecisionPolicy] = None, shard_dim: int = 0, + reduction_divide_factor: Optional[float] = None, ): if mode == "replicate": param_sharding = (Replicate(),) @@ -348,6 +386,7 @@ def data_parallel( mode, regional_ac, mp_policy=mp_policy, + reduction_divide_factor=reduction_divide_factor, ), ) return model From 5bbda42c7ce51b6dc4abae7a29b78fc9aab493cc Mon Sep 17 00:00:00 2001 From: Shuhua Yu <18108279+shuhuayu@users.noreply.github.com> Date: Wed, 8 Oct 2025 20:54:21 -0700 Subject: [PATCH 06/18] [Llama] Add scaled RoPE support for Llama 3 and 4 (#1839) Llama 3.1 models use scaled RoPE by default, and Llama 4 17B x 16E uses scaled RoPE while 17B x 128E does not. 1. Verified forward parity between Titan Llama 3.1 8B and HuggingFace Llama 3.1 8B. The KL divergence of outputs from the same sample inputs is small. ![llama 3 8b forward parity small](https://github.com/user-attachments/assets/891df89b-006f-4ed0-a68a-36e939d2169b) For comparison, before adding scaled RoPE support, the forward parity check on the Llama 3.1 8B model incurred a slightly larger KL divergence on sample inputs. ![llama 3 8b forward parity without scaled rope](https://github.com/user-attachments/assets/9a68357a-34d4-497f-977f-27cc548d8f62) 2. Verified training of Llama 3.1 8B with tensor parallel degree = 4. ![llama 3-1 8b training tp=4](https://github.com/user-attachments/assets/a8b1ab10-0da0-4d02-afbb-a775716beaa3) 3. Verified training of Llama 4 debug model with scaled RoPE. ![llama 4 debug model training](https://github.com/user-attachments/assets/1fbf8939-31a5-475f-987c-d5bcf6d2376b) --- torchtitan/experiments/llama4/__init__.py | 6 ++- torchtitan/experiments/llama4/model/args.py | 9 ++++ torchtitan/experiments/llama4/model/model.py | 44 ++++++++++++++++++-- torchtitan/models/llama3/model/args.py | 11 ++++- torchtitan/models/llama3/model/model.py | 43 +++++++++++++++++-- 5 files changed, 105 insertions(+), 8 deletions(-) diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index 71a2eecc..f759dc39 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -14,7 +14,7 @@ from torchtitan.protocols.train_spec import TrainSpec from .infra.parallelize import parallelize_llama -from .model.args import TransformerModelArgs +from .model.args import RoPEScalingArgs, TransformerModelArgs from .model.model import Transformer from .model.state_dict_adapter import Llama4StateDictAdapter @@ -32,6 +32,7 @@ n_heads=16, vocab_size=2048, rope_theta=500000, + rope_scaling_args=RoPEScalingArgs(), ), "17bx16e": TransformerModelArgs( dim=5120, @@ -41,6 +42,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, + rope_scaling_args=RoPEScalingArgs(), max_seq_len=10485760, moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, @@ -61,6 +63,7 @@ n_heads=16, vocab_size=2048, rope_theta=500000, + rope_scaling_args=RoPEScalingArgs(), every_n_layers_nope=4, fixed_attn_block_size=256, use_flex_attn=True, @@ -74,6 +77,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, + rope_scaling_args=RoPEScalingArgs(), max_seq_len=10485760, moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index e34d4d3c..faeb60aa 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -18,6 +18,14 @@ from torchtitan.tools.utils import has_cuda_capability +@dataclass +class RoPEScalingArgs: + scaling_factor: float = 16.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 1.0 + original_max_position_embeddings: int = 8192 + + @dataclass class TransformerModelArgs(BaseModelArgs): dim: int = 4096 @@ -29,6 +37,7 @@ class TransformerModelArgs(BaseModelArgs): ffn_dim_multiplier: float | None = None norm_eps: float = 1e-5 rope_theta: float = 10000 + rope_scaling_args: RoPEScalingArgs | None = None max_seq_len: int = 1048576 # If `True`, then each transformer block init uses its layer ID, and if diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index c88286e5..871c2072 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math import torch import torch.nn.functional as F @@ -13,10 +14,15 @@ from torchtitan.models.moe import MoE from torchtitan.protocols import ModelProtocol -from .args import TransformerModelArgs +from .args import RoPEScalingArgs, TransformerModelArgs -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + scaling_args: RoPEScalingArgs | None = None, +) -> torch.Tensor: """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -28,11 +34,42 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te dim (int): Dimension of the frequency tensor. end (int): End index for precomputing frequencies. theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - + scaling_args (RoPEScalingArgs | None): RoPE scaling arguments. Defaults to None. + scaling_factor (float): RoPE scaling multiplier; larger values + stretch positions to support longer contexts. Defaults to 16.0. + low_freq_factor (float): Extra scaling applied to the low-frequency + (long-wavelength) RoPE bands. Defaults to 1.0. + high_freq_factor (float): Extra scaling applied to the high-frequency + (short-wavelength) RoPE bands. Defaults to 1.0. + original_max_position_embeddings (int): Maximum position embeddings + for original model. Defaults to 8192. Returns: torch.Tensor: Precomputed frequency tensor with complex exponentials. """ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + # apply rope scaling + if scaling_args is not None: + scaling_factor = scaling_args.scaling_factor + low_freq_factor = scaling_args.low_freq_factor + high_freq_factor = scaling_args.high_freq_factor + original_max_position_embeddings = scaling_args.original_max_position_embeddings + wavelen = 2 * math.pi / freqs + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by scaling factor + freqs = torch.where(wavelen > low_freq_wavelen, freqs / scaling_factor, freqs) + # wavelen in between: linear interpolation of the scaled freqs and the original freqs + smooth_factor = ( + original_max_position_embeddings / wavelen - low_freq_factor + ) / (high_freq_factor - low_freq_factor) + smoothed_freqs = ( + 1 - smooth_factor + ) * freqs / scaling_factor + smooth_factor * freqs + is_medium_freqs = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + freqs = torch.where(is_medium_freqs, smoothed_freqs, freqs) + t = torch.arange(end, device=freqs.device) freqs = torch.outer(t, freqs).float() freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 @@ -445,6 +482,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor: # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, + self.model_args.rope_scaling_args, ) def forward( diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index d3d21163..2bdafa69 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -7,7 +7,7 @@ # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from dataclasses import dataclass +from dataclasses import dataclass, field from torch import nn @@ -17,6 +17,14 @@ from torchtitan.tools.logging import logger +@dataclass +class RoPEScalingArgs: + scaling_factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + original_max_position_embeddings: int = 8192 + + @dataclass class TransformerModelArgs(BaseModelArgs): dim: int = 4096 @@ -28,6 +36,7 @@ class TransformerModelArgs(BaseModelArgs): ffn_dim_multiplier: float | None = None norm_eps: float = 1e-5 rope_theta: float = 10000 + rope_scaling_args: RoPEScalingArgs = field(default_factory=RoPEScalingArgs) max_seq_len: int = 131072 # If `True`, then each transformer block init uses its layer ID, and if diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 7d713045..753ffae0 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -6,6 +6,7 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. +import math import torch import torch.nn.functional as F @@ -14,10 +15,15 @@ from torchtitan.models.attention import build_attention from torchtitan.protocols.train_spec import ModelProtocol -from .args import TransformerModelArgs +from .args import RoPEScalingArgs, TransformerModelArgs -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + scaling_args: RoPEScalingArgs = RoPEScalingArgs(), +) -> torch.Tensor: """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -29,11 +35,41 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te dim (int): Dimension of the frequency tensor. end (int): End index for precomputing frequencies. theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. - + scaling_args (RoPEScalingArgs | None): RoPE scaling arguments. Defaults to None. + scaling_factor (float): RoPE scaling multiplier; larger values + stretch positions to support longer contexts. Defaults to 8.0. + low_freq_factor (float): Extra scaling applied to the low-frequency + (long-wavelength) RoPE bands. Defaults to 1.0. + high_freq_factor (float): Extra scaling applied to the high-frequency + (short-wavelength) RoPE bands. Defaults to 4.0. + original_max_position_embeddings (int): Maximum position embeddings + for original model. Defaults to 8192. Returns: torch.Tensor: Precomputed frequency tensor with complex exponentials. """ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + # apply rope scaling + scaling_factor = scaling_args.scaling_factor + low_freq_factor = scaling_args.low_freq_factor + high_freq_factor = scaling_args.high_freq_factor + original_max_position_embeddings = scaling_args.original_max_position_embeddings + wavelen = 2 * math.pi / freqs + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by scaling factor + freqs = torch.where(wavelen > low_freq_wavelen, freqs / scaling_factor, freqs) + # wavelen in between: linear interpolation of the scaled freqs and the original freqs + smooth_factor = (original_max_position_embeddings / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_freqs = ( + 1 - smooth_factor + ) * freqs / scaling_factor + smooth_factor * freqs + is_medium_freqs = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + freqs = torch.where(is_medium_freqs, smoothed_freqs, freqs) + t = torch.arange(end, device=freqs.device) freqs = torch.outer(t, freqs).float() freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 @@ -389,6 +425,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor: # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, + self.model_args.rope_scaling_args, ) def forward( From 44e9218418ce61b53a2d1b2f64c1a3a392b0e6b8 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 9 Oct 2025 10:35:24 -0700 Subject: [PATCH 07/18] Fix docstring formatting (#1843) The forge's doc build is failing with some formatting issues that seem to come from the torchtitan docstrings: ``` docstring of torchtitan.config.job_config.Parallelism.fsdp_reshard_after_forward:7: ERROR: Unexpected indentation. docstring of torchtitan.config.job_config.Parallelism.fsdp_reshard_after_forward:8: WARNING: Block quote ends without a blank line; unexpected unindent. docstring of torchtitan.config.job_config.Parallelism.expert_parallel_degree:4: ERROR: Unexpected indentation. docstring of torchtitan.config.job_config.Parallelism.expert_parallel_degree:7: WARNING: Block quote ends without a blank line; unexpected unindent. docstring of torchtitan.config.job_config.Parallelism.expert_parallel_degree:11: WARNING: Bullet list ends without a blank line; unexpected unindent. docstring of torchtitan.config.job_config.Checkpoint.async_mode:5: ERROR: Unexpected indentation. ``` Failing [job](https://github.com/meta-pytorch/forge/actions/runs/18360538773/job/52303073438?pr=336#step:11:73). This PR fixes those minor formatting issues. --- torchtitan/config/job_config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index b7cfcb11..aa6a1076 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -291,9 +291,11 @@ class Parallelism: within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward, trading off memory and communication. See torch's `fully_shard` API for more documentation on `reshard_after_forward`. + The supported policies include "default", "always" and "never": + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal - scenarios. + scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. """ @@ -393,15 +395,21 @@ class Parallelism: expert_parallel_degree: int = 1 """ Expert parallelism degree. 1 means disabled. No effect for non-MoE models. + Currently, it is supported with the following constraints: + - when etp = tp: + - cp <= ep <= dp_shard * cp - ep % cp == 0 - dp_shard * cp % ep == 0 + - when etp = 1: + - cp * tp <= ep <= dp_shard * cp * tp - ep % (cp * tp) == 0 - dp_shard * cp * tp % ep == 0 + Note that this is still an experimental feature. Some constraints will be relaxed soon when we have more flexible DeviceMesh support. """ @@ -503,6 +511,7 @@ class Checkpoint: async_mode: Literal["disabled", "async", "async_with_pinned_mem"] = "disabled" """ Which async checkpoint mode to use. Currently there are 3 different modes. + - "disabled": synchronized checkpointing will be used. - "async": torch.distributed.checkpoint.async_save will be used. - "async_with_pinned_mem": this option utilizes a dedicated pinned memory space and creates a From b6ccf2279fc8acf2ff05da9a0ec8a6287e943652 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Thu, 9 Oct 2025 10:37:43 -0700 Subject: [PATCH 08/18] Fix num of layers for deepseek-v3 (#1845) Fix the number of layer issue introduced by #1804 --- torchtitan/models/deepseek_v3/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 5125a790..a290ea7e 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -134,7 +134,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=4, + n_layers=61, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( From b276387321c4fb1ebf40e918526887b151cd5b9a Mon Sep 17 00:00:00 2001 From: tohskai Date: Thu, 9 Oct 2025 22:15:55 +0200 Subject: [PATCH 09/18] Add support for AC budget API (#1731) Inspired by the blogpost: https://pytorch.org/blog/activation-checkpointing-techniques/ --- torchtitan/config/job_config.py | 20 ++++++++++- .../distributed/activation_checkpoint.py | 34 +++++++++++++------ .../experiments/llama4/infra/parallelize.py | 1 + .../experiments/qwen3/infra/parallelize.py | 1 + .../simple_fsdp/llama3/parallelize.py | 1 + .../models/deepseek_v3/infra/parallelize.py | 1 + torchtitan/models/llama3/infra/parallelize.py | 1 + 7 files changed, 48 insertions(+), 11 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index aa6a1076..a003859a 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -569,7 +569,7 @@ class Checkpoint: @dataclass class ActivationCheckpoint: - mode: Literal["selective", "full", "none"] = "selective" + mode: Literal["selective", "full", "memory_budget", "none"] = "selective" """Type of activation checkpointing to use""" selective_ac_option: str = "2" @@ -598,6 +598,24 @@ class ActivationCheckpoint: rematerialized. """ + memory_budget: float = 0.5 + """ + When mode is set to "memory_budget", this value determines how much + partitioner in the compiler should trade off compute for memory. + 0.0 corresponds to the activation memory from applying + activation checkpointing to the full compiled region, and 1.0 corresponds to + the activation memory from the default runtime-optimized strategy. Read here: + https://pytorch.org/blog/activation-checkpointing-techniques/ + """ + + visualize_memory_budget_pareto: bool = False + """ + This dumps out a SVG visualization of the expected runtime vs. activation + memory tradeoffs for all memory budget values from 0 to 1 in increments of + 0.05 in {--job.dump_folder}/memory_budget_pareto folder. See an example here: + https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 + """ + @dataclass class Compile: diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 227c2ca2..57809c45 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -7,6 +7,7 @@ # This file provides the util functions to apply activation checkpointing to the model. # Technically, this is not a part of distributed, but distributed module is the best place to put it. +import os from collections import defaultdict import torch @@ -279,6 +280,7 @@ def apply_ac( model_compile_enabled: bool = False, use_flex_attn: bool = False, op_sac_save_list: set[torch._ops.OpOverload] | None = None, + base_folder: str = "", ) -> None: """Apply activation checkpointing to the model. @@ -297,15 +299,27 @@ def apply_ac( None """ - for layer_id, transformer_block in model.layers.named_children(): - transformer_block = _apply_ac_to_transformer_block( - transformer_block, - ac_config, - base_fqn=f"layers.{layer_id}", - model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, - op_sac_save_list=op_sac_save_list, - ) - model.layers.register_module(layer_id, transformer_block) + if ac_config.mode == "memory_budget": + assert model_compile_enabled, "Memory budget mode requires model to be compiled" + if ac_config.visualize_memory_budget_pareto: + pareto_dir = os.path.join(base_folder, "memory_budget_pareto") + if not os.path.exists(pareto_dir): + os.makedirs(pareto_dir, exist_ok=True) + torch._functorch.config.memory_budget_pareto_dir = pareto_dir + torch._functorch.config.visualize_memory_budget_pareto = True + + torch._functorch.config.activation_memory_budget = ac_config.memory_budget + logger.info(f"Selected {ac_config.memory_budget} budget option") + else: + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = _apply_ac_to_transformer_block( + transformer_block, + ac_config, + base_fqn=f"layers.{layer_id}", + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + op_sac_save_list=op_sac_save_list, + ) + model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index dba6d69e..18b2253b 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -120,6 +120,7 @@ def parallelize_llama( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py index 93f4caea..27406bf7 100644 --- a/torchtitan/experiments/qwen3/infra/parallelize.py +++ b/torchtitan/experiments/qwen3/infra/parallelize.py @@ -114,6 +114,7 @@ def parallelize_qwen3( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index cf3f1dd4..3e2775b7 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -85,6 +85,7 @@ def parallelize_llama( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # apply data parallel diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index c7cd45f4..1c73cef7 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -113,6 +113,7 @@ def parallelize_deepseekv3( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) if model_compile_enabled: diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 9f98eaf2..89066e86 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -102,6 +102,7 @@ def parallelize_llama( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP From 98d904f2c563ea05d9801ce2eb96b634161efc01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ph=C3=BAc=20H=2E=20L=C3=AA=20Kh=E1=BA=AFc?= Date: Fri, 10 Oct 2025 04:39:28 +0700 Subject: [PATCH 10/18] [VLM] Add token-imbalance loss (#1803) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In VLM interleaved training, with native resolution and aspect ratio, the number of tokens participating in loss computation differ per rank. Naive FSDP gradient averaging across data ranks can causes tokens on ranks with fewer valid tokens to contribute more to the loss than on other ranks. This PR address this via loss balancing, which incur an additional comm in the loss computation. In practice, I haven't notice any impacts from this comm. #### Quick sanity check Let have a sum loss of all tokens on each rank i, with $N_i$ number of tokens $L_i = \sum_{j=1}^{N_i}\ell_{ij}$ and its gradient $g_i = \sum_{j=1}^{N_i}\nabla\ell_{ij}$ If we multiply the *loss* on each rank by a constant factor **c** (the same for all ranks), then after `backward()`: $$ \tilde g_i = c \cdot g_i . $$ FSDP will *average* these gradients across ranks: $$ g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i =\frac{c}{R}\sum_{i=1}^{R} g_i . $$ We want this to equal the **global‑sample average**: $$ g_{\text{true}} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla \ell_{ij} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i . $$ Thus for FSDP gradient to be correct, we need $$ \frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad c=\frac{R}{N_{\text{total}}}. $$ So the *right* scaling factor is $R/N_{\text{total}}$, which mean divide the per-rank sum loss with $N_{\text{total}}/R$, which is **average number of tokens per rank**. Intuitively, this is the same as default cross-entropy loss, but instead of diving sum loss on a rank by the number of tokens **on that rank**, we now divide by the **average number of tokens across all rank** P/s: sorry this PR is based on #1802 but I couldn't choose that as the base branch. Maybe it will be easier to review once that PR is merged. --- torchtitan/components/loss.py | 3 +- torchtitan/experiments/vlm/__init__.py | 1 - torchtitan/experiments/vlm/infra/loss.py | 113 +++++++++++++++++++++++ torchtitan/train.py | 4 +- 4 files changed, 118 insertions(+), 3 deletions(-) create mode 100644 torchtitan/experiments/vlm/infra/loss.py diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index 2b14b9a8..30366113 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -23,7 +23,8 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor ) -def build_cross_entropy_loss(job_config: JobConfig): +def build_cross_entropy_loss(job_config: JobConfig, **kwargs): + del kwargs # delete any unused arguments loss_fn = cross_entropy_loss if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 7d62a8ed..051f66eb 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -6,7 +6,6 @@ from dataclasses import asdict, replace -from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py new file mode 100644 index 00000000..bba51f28 --- /dev/null +++ b/torchtitan/experiments/vlm/infra/loss.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch import distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.components.ft.manager import FTManager +from torchtitan.config.job_config import JobConfig +from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.tools.logging import logger + + +IGNORE_INDEX = -100 # Pytorch's default for F.cross_entropy + + +# WARNING: currently this does not take into account gradient accumulation +# and the gradient can still be biased toward grad accum step with less valid tokens +# See: https://github.com/pytorch/torchtitan/issues/1842 +def token_imbalance_ce_loss( + pred: torch.Tensor, + labels: torch.Tensor, + token_mesh: DeviceMesh, + ft_pg: dist.ProcessGroup | None, +) -> torch.Tensor: + """ + Cross‑entropy loss that is *robust* to varying numbers of valid tokens across ranks. + + In a typical distributed training setup (data parallel + sequence parallel), + each rank computes the loss over **only its local tokens** and returns an + *average* over those tokens: + + Afterwards, when Fully‑Sharded Data Parallel (FSDP) averages the gradients + across all ranks, the resulting update is equivalent to a **global sample + average** *only if every rank contains the same number of tokens*. + In practice that assumption is violated for many workloads: + - Sequences are padded to a fixed length -> some ranks see fewer real tokens. + - SFT finetuning where user's queries tokens are masked out. + - Vision encoders often injects a large number of “ignored” + tokens as context that are not trained with text tokens' loss. + + This function fixes the issue by **scaling the sum-of-loss** with the *average* + number of non‑ignored tokens per rank, computed via an all-reduce over + `token_mesh`. The returned scalar therefore represents the loss that would + be obtained if every token in the entire distributed batch contributed with + equal weight to the global gradient, regardless of how many padded or + ignored tokens each rank contains. + + Parameters + ---------- + pred : torch.Tensor + labels : torch.Tensor + token_mesh : DeviceMesh + A device mesh that contains all ranks participating in this training step's + loss computation. The function performs an ``all_reduce`` (mean) over the + `num_tokens` tensor of a rank across this mesh. + ft_pg: dist.ProcessGroup | None + Optional pg for Fault Tolerance training. + + Returns + ------- + torch.Tensor + A scalar loss tensor, ready for ``backward()`` and FSDP all-reduce mean + + Notes + ----- + * The function internally uses :func:`torch.nn.functional.cross_entropy` + with ``reduction="sum"`` so that each token contributes exactly once to + the numerator. The denominator is the **average** number of valid tokens + per rank, not the local count. + * If a rank contains no valid tokens (i.e., all labels are ``IGNORE_INDEX``), + its contribution to the sum is zero and its `num_tokens` becomes zero. + In that case the mean across ranks will still be well‑defined as long as + at least one rank has non‑zero token count. + """ + sum_loss = torch.nn.functional.cross_entropy( + pred.flatten(0, 1).float(), + labels.flatten(0, 1), + reduction="sum", + ignore_index=IGNORE_INDEX, + ) + num_tokens = (labels != IGNORE_INDEX).sum() + avg_num_tokens_per_rank = funcol.all_reduce( + num_tokens, reduceOp=c10d.ReduceOp.AVG.name, group=token_mesh + ) + if ft_pg is not None: + avg_num_tokens_per_rank = funcol.all_reduce( + avg_num_tokens_per_rank, reduceOp=c10d.ReduceOp.AVG.name, group=ft_pg + ) + return sum_loss / avg_num_tokens_per_rank + + +def build_token_imbalance_ce_loss( + job_config: JobConfig, parallel_dims: ParallelDims, ft_manager: FTManager, **kwargs +): + del kwargs # delete any unused arguments + # NOTE: The device mesh where the input tokens w/ shape BSD can be sliced: + # DP split the batch dim B + # CP split the sequence dim S + token_mesh = parallel_dims.world_mesh["dp_cp"] + ft_pg = ft_manager.loss_sync_pg + loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg) + if job_config.compile.enable and "loss" in job_config.compile.components: + logger.info("Compiling the loss function with torch.compile") + loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend) + return loss_fn diff --git a/torchtitan/train.py b/torchtitan/train.py index 69e5c03e..ffd8b77a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -197,7 +197,9 @@ def __init__(self, job_config: JobConfig): init_device = device_type buffer_device = None - self.loss_fn = self.train_spec.build_loss_fn(job_config) + self.loss_fn = self.train_spec.build_loss_fn( + job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager + ) # verify batch sizes global_batch_size = job_config.training.global_batch_size From aa000a3c42e8bb37e51f26eb3e3e024b37ccc479 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Thu, 9 Oct 2025 18:22:41 -0700 Subject: [PATCH 11/18] refactor TrainSpec to remove the name field (#1850) --- docs/extension.md | 2 +- scripts/estimate/estimation.py | 2 +- tests/unit_tests/test_train_spec.py | 6 ++---- torchtitan/experiments/deepseek_v3/__init__.py | 4 ++-- torchtitan/experiments/flux/__init__.py | 1 - torchtitan/experiments/forge/engine.py | 2 +- torchtitan/experiments/forge/example_train.py | 4 ++-- torchtitan/experiments/forge/train_spec.py | 10 ++++------ torchtitan/experiments/llama4/__init__.py | 3 ++- torchtitan/experiments/multimodal/__init__.py | 4 ++-- torchtitan/experiments/qwen3/__init__.py | 1 - torchtitan/experiments/vlm/__init__.py | 2 +- .../experiments/vlm/train_configs/debug_model.toml | 2 +- torchtitan/models/README.md | 2 +- torchtitan/models/deepseek_v3/__init__.py | 1 - torchtitan/models/llama3/__init__.py | 1 - torchtitan/models/llama3_ft/__init__.py | 3 +-- torchtitan/protocols/train_spec.py | 9 ++++----- torchtitan/train.py | 6 +++--- 19 files changed, 28 insertions(+), 37 deletions(-) diff --git a/docs/extension.md b/docs/extension.md index f529b05b..98902528 100644 --- a/docs/extension.md +++ b/docs/extension.md @@ -14,7 +14,7 @@ The extension points and protocols mentioned in this note are subject to change. The coarse level abstraction tries to hit a balance between flexible component swapping and a straightforward train script ([train.py](../torchtitan/train.py)). Note that among all training components, currently [`CheckpointManager`](../torchtitan/components/checkpoint.py) and [`FTManager`](../torchtitan/components/ft.py) are not configurable since we do not expect them to be customized, but we are open to requests. -To register a `TrainSpec`, please follow the example of [Llama 3.1](../torchtitan/models/llama3/__init__.py) to `register_train_spec`. Please make sure the registration code is called before training initialization. In torchtitan, it is performed during [module import](../torchtitan/__init__.py). +To register a `TrainSpec`, please use the `register_train_spec` API, and make sure registration happens before `get_train_spec` is called during training initialization. In torchtitan, `get_train_spec` will dynamically look for models in `torchtitan/models` or `torchtitan/experiments`. ### `ModelConverter` diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 8103ae0b..b1f45c40 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -95,7 +95,7 @@ def estimate_memory(job_config: JobConfig): else contextlib.nullcontext() ): logger.info( - f"Building {train_spec.name} {job_config.model.flavor} with {model_args}" + f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): model = train_spec.model_cls(model_args) diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 57167304..fb326a47 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -76,7 +76,6 @@ class TestTrainSpec: def test_register_train_spec(self): fake_config = {"fake": BaseModelArgs()} spec = TrainSpec( - name="fake", model_cls=FakeModel, model_args=fake_config, parallelize_fn=parallelize_llama, @@ -87,7 +86,7 @@ def test_register_train_spec(self): build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, ) - register_train_spec(spec) + register_train_spec("fake", spec) new_spec = get_train_spec("fake") assert new_spec == spec @@ -98,7 +97,6 @@ def test_optim_hook(self): fake_config = {"fake": BaseModelArgs()} spec = TrainSpec( - name="fake2", model_cls=FakeModel, model_args=fake_config, parallelize_fn=parallelize_llama, @@ -109,7 +107,7 @@ def test_optim_hook(self): build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, ) - register_train_spec(spec) + register_train_spec("fake2", spec) new_spec = get_train_spec("fake2") model = new_spec.model_cls(BaseModelArgs()) diff --git a/torchtitan/experiments/deepseek_v3/__init__.py b/torchtitan/experiments/deepseek_v3/__init__.py index f93d0d80..f5829dab 100644 --- a/torchtitan/experiments/deepseek_v3/__init__.py +++ b/torchtitan/experiments/deepseek_v3/__init__.py @@ -40,8 +40,8 @@ register_train_spec( + "deepseek3", TrainSpec( - name="deepseek3", model_cls=DeepseekForCausalLM, model_args=deepseek_configs, parallelize_fn=parallelize_deepseek, @@ -51,5 +51,5 @@ build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=get_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, - ) + ), ) diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py index 89c3f68b..2d648f51 100644 --- a/torchtitan/experiments/flux/__init__.py +++ b/torchtitan/experiments/flux/__init__.py @@ -109,7 +109,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="flux", model_cls=FluxModel, model_args=flux_configs, parallelize_fn=parallelize_flux, diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 3d0c52c0..f8b14129 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -167,7 +167,7 @@ def __init__(self, job_config: ForgeJobConfig): if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: raise RuntimeError( - f"Pipeline Parallel is enabled but {self.train_spec.name} " + f"Pipeline Parallel is enabled but {job_config.model.name} " f"does not support pipelining" ) diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 7bd1531d..8feb547b 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -66,7 +66,7 @@ def __init__(self, job_config: JobConfig): model_args = self.model_args logger.info( - f"Built {self.train_spec.name} {job_config.model.flavor} with {model_args}" + f"Built {job_config.model.name} {job_config.model.flavor} with {model_args}" ) # metrics logging @@ -78,7 +78,7 @@ def __init__(self, job_config: JobConfig): self.metrics_processor.num_flops_per_token = self.num_flops_per_token logger.info( - f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} " f"{color.red}size: {self.model_param_count:,} total parameters{color.reset}" ) diff --git a/torchtitan/experiments/forge/train_spec.py b/torchtitan/experiments/forge/train_spec.py index b7b1d605..f9ad1d65 100644 --- a/torchtitan/experiments/forge/train_spec.py +++ b/torchtitan/experiments/forge/train_spec.py @@ -21,7 +21,6 @@ @dataclass class ForgeTrainSpec: - name: str model_cls: type[ModelProtocol] model_args: Mapping[str, BaseModelArgs] parallelize_fn: ParallelizeFunction @@ -39,7 +38,6 @@ def _transform_train_spec(original_spec: TrainSpec): """Transform the original train spec to ForgeTrainSpec format.""" # Create a new TrainSpec with only the fields we need in forge return ForgeTrainSpec( - name=original_spec.name, model_cls=original_spec.model_cls, model_args=original_spec.model_args, parallelize_fn=original_spec.parallelize_fn, @@ -51,13 +49,13 @@ def _transform_train_spec(original_spec: TrainSpec): ) -def register_train_spec(train_spec: ForgeTrainSpec) -> None: +def register_train_spec(name: str, train_spec: ForgeTrainSpec) -> None: global _extra_train_specs - if train_spec.name in _extra_train_specs: - raise ValueError(f"ForgeTrainSpec {train_spec.name} is already registered.") + if name in _extra_train_specs: + raise ValueError(f"ForgeTrainSpec {name} is already registered.") # user can define a ForgeTrainSpec from outside of torchtitan - _extra_train_specs[train_spec.name] = train_spec + _extra_train_specs[name] = train_spec def get_train_spec(name: str) -> ForgeTrainSpec: diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index f759dc39..325cd6ac 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -8,6 +8,7 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.models.llama3 import pipeline_llama from torchtitan.models.moe import MoEArgs @@ -103,7 +104,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="llama4", model_cls=Transformer, model_args=llama4_configs, parallelize_fn=parallelize_llama, @@ -113,5 +113,6 @@ def get_train_spec() -> TrainSpec: build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, state_dict_adapter=Llama4StateDictAdapter, ) diff --git a/torchtitan/experiments/multimodal/__init__.py b/torchtitan/experiments/multimodal/__init__.py index bbb37d5c..b35bc165 100644 --- a/torchtitan/experiments/multimodal/__init__.py +++ b/torchtitan/experiments/multimodal/__init__.py @@ -22,8 +22,8 @@ } register_train_spec( + "llama4_multimodal", TrainSpec( - name="llama4_multimodal", model_cls=MultimodalDecoder, model_args=llama4_mm_configs, parallelize_fn=parallelize_llama, @@ -33,5 +33,5 @@ build_dataloader_fn=build_mm_dataloader, build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, - ) + ), ) diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py index b468ff96..32ba652f 100644 --- a/torchtitan/experiments/qwen3/__init__.py +++ b/torchtitan/experiments/qwen3/__init__.py @@ -180,7 +180,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="qwen3", model_cls=Qwen3Model, model_args=qwen3_configs, # Change from dict to Mapping parallelize_fn=parallelize_qwen3, diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 051f66eb..7fd59564 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -6,6 +6,7 @@ from dataclasses import asdict, replace +from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer @@ -41,7 +42,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="llama3-siglip2", model_cls=Llama3Siglip2Transformer, model_args=llama3_siglip2_configs, parallelize_fn=parallelize_vlm, diff --git a/torchtitan/experiments/vlm/train_configs/debug_model.toml b/torchtitan/experiments/vlm/train_configs/debug_model.toml index c4f97463..91b7c0c3 100644 --- a/torchtitan/experiments/vlm/train_configs/debug_model.toml +++ b/torchtitan/experiments/vlm/train_configs/debug_model.toml @@ -23,7 +23,7 @@ save_tb_folder = "tb" enable_wandb = false [model] -name = "llama3-siglip2" +name = "vlm" flavor = "debugmodel" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "tests/assets/tokenizer" diff --git a/torchtitan/models/README.md b/torchtitan/models/README.md index 467031ce..456fe14b 100644 --- a/torchtitan/models/README.md +++ b/torchtitan/models/README.md @@ -40,7 +40,7 @@ The folder should be organized as follows - `__init__.py` - A dictionary of the actual model configurations, of the type `[str: ModelArgs]`. - Define `get_train_spec` to return a [`TrainSpec`](/torchtitan/protocols/train_spec.py), consisting a tuple of - - model name, model class, model args + - model class, model args - Model name should be the same as the folder name, which should be added to `torchtitan/models/__init__.py` or ``torchtitan/experiments/__init__.py``. - parallelizing function, pipelining function - builder functions for optimizer, lr scheduler, data loader, tokenizer, and loss function diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index a290ea7e..4e8d500b 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -161,7 +161,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="deepseek_v3", model_cls=DeepSeekV3Model, model_args=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 33e0d66a..2c0572a4 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -72,7 +72,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="llama3", model_cls=Transformer, model_args=llama3_configs, parallelize_fn=parallelize_llama, diff --git a/torchtitan/models/llama3_ft/__init__.py b/torchtitan/models/llama3_ft/__init__.py index 1dad5e72..f6337eeb 100644 --- a/torchtitan/models/llama3_ft/__init__.py +++ b/torchtitan/models/llama3_ft/__init__.py @@ -33,12 +33,10 @@ def get_train_spec() -> TrainSpec: return FaultTolerantTrainSpec( - name="llama3_ft", model_cls=Transformer, model_args=llama3_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, - fragment_fn=fragment_llm, build_optimizers_fn=build_optimizers, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, @@ -46,4 +44,5 @@ def get_train_spec() -> TrainSpec: build_loss_fn=build_cross_entropy_loss, build_validator_fn=build_validator, state_dict_adapter=Llama3StateDictAdapter, + fragment_fn=fragment_llm, ) diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 71d2a98a..22bfa7df 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -42,7 +42,6 @@ @dataclass class TrainSpec: - name: str model_cls: type[ModelProtocol] model_args: Mapping[str, BaseModelArgs] parallelize_fn: ParallelizeFunction @@ -60,13 +59,13 @@ class TrainSpec: _extra_train_specs: dict[str, TrainSpec] = {} -def register_train_spec(train_spec: TrainSpec) -> None: +def register_train_spec(name: str, train_spec: TrainSpec) -> None: global _extra_train_specs - if train_spec.name in _extra_train_specs: - raise ValueError(f"TrainSpec {train_spec.name} is already registered.") + if name in _extra_train_specs: + raise ValueError(f"TrainSpec {name} is already registered.") # user can define a TrainSpec from outside of torchtitan - _extra_train_specs[train_spec.name] = train_spec + _extra_train_specs[name] = train_spec def get_train_spec(name: str) -> TrainSpec: diff --git a/torchtitan/train.py b/torchtitan/train.py index ffd8b77a..287828d8 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -152,7 +152,7 @@ def __init__(self, job_config: JobConfig): self.model_args = model_args logger.info( - f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" + f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}" ) with ( torch.device("meta"), @@ -182,7 +182,7 @@ def __init__(self, job_config: JobConfig): ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) logger.info( - f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} " f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) @@ -229,7 +229,7 @@ def __init__(self, job_config: JobConfig): if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: raise RuntimeError( - f"Pipeline Parallel is enabled but {self.train_spec.name} " + f"Pipeline Parallel is enabled but {job_config.model.name} " f"does not support pipelining" ) From abbb47aa6129d4ef55bc63905786ef2ca4483bb0 Mon Sep 17 00:00:00 2001 From: Shuhua Yu <18108279+shuhuayu@users.noreply.github.com> Date: Thu, 9 Oct 2025 21:47:03 -0700 Subject: [PATCH 12/18] [VLM] Update config import from Llama 3 to support scaled RoPE args (#1849) A test run on vlm debugmodel: image --- torchtitan/experiments/vlm/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 7fd59564..86c3faa5 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import asdict, replace +from dataclasses import fields +from typing import Any from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers @@ -27,9 +28,14 @@ ] +def _get_dict(obj) -> dict[str, Any]: + """Convert dataclass to dict, preserving nested dataclasses (unlike asdict).""" + return {field.name: getattr(obj, field.name) for field in fields(obj)} + + llama3_siglip2_configs = { "debugmodel": Llama3Siglip2ModelArgs( - **asdict(replace(llama3_configs["debugmodel"], vocab_size=2048)), + **_get_dict(llama3_configs["debugmodel"]), encoder=Siglip2ModelArgs( dim=128, ffn_dim=256, From 6a3a9da9564d82a1120c7639ef6236bb4cffa049 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 9 Oct 2025 23:03:57 -0700 Subject: [PATCH 13/18] Refactor attention and make attention mask an argument to the model (#1776) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #1797 * __->__ #1776 **Status** 1. Change all models, including the experimental ones. 2. E2E loss verification. 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a separate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is https://github.com/pytorch/torchtitan/issues/1723. https://github.com/pytorch/pytorch/pull/164111/ proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix https://github.com/pytorch/torchtitan/issues/1723 with https://github.com/pytorch/pytorch/pull/164111/ and this PR. 3. Now SDPA and FlexAttention are wrapped in two different classes. ~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.~~ See the discussion in https://github.com/pytorch/torchtitan/issues/1723. **Verification** *llama3* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" ``` *llama3 flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` *llama4* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *llama4 irope* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *deepseek* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ``` *deepseek flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` --- torchtitan/distributed/utils.py | 12 +- torchtitan/experiments/forge/example_train.py | 24 +- .../experiments/llama4/infra/parallelize.py | 4 +- torchtitan/experiments/llama4/model/model.py | 69 +++- torchtitan/experiments/qwen3/model/model.py | 56 ++- torchtitan/experiments/vlm/__init__.py | 2 +- torchtitan/experiments/vlm/model/model.py | 28 +- torchtitan/experiments/vlm/model/siglip2.py | 59 ++- torchtitan/models/attention.py | 338 ++++++++---------- .../models/deepseek_v3/infra/parallelize.py | 24 +- torchtitan/models/deepseek_v3/model/model.py | 66 +++- torchtitan/models/llama3/infra/parallelize.py | 4 +- torchtitan/models/llama3/model/model.py | 59 ++- torchtitan/protocols/model.py | 17 + torchtitan/train.py | 22 +- 15 files changed, 508 insertions(+), 276 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 7bec252b..c2ec7bd7 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -110,9 +110,9 @@ def set_determinism( # reproducibility, since the autotune results may not be deterministic. from torch.nn.attention.flex_attention import flex_attention - from torchtitan.models.attention import FlexAttention + from torchtitan.models.attention import FlexAttentionWrapper - FlexAttention.flex_attn = torch.compile(flex_attention) + FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) if not world_mesh: if seed is not None: @@ -207,14 +207,6 @@ def context(cp_context: Generator[None, None, None] | None = None): torch._dynamo.utils.maybe_enable_compiled_autograd(True) ) - if cp_context is not None: - from torch.nn.attention import SDPBackend - - from torchtitan.models.attention import ScaledDotProductAttention - - if SDPBackend.MATH in ScaledDotProductAttention.backends: - ScaledDotProductAttention.backends.remove(SDPBackend.MATH) - stack.enter_context(cp_context) yield diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 8feb547b..d3a7d39b 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -157,15 +157,14 @@ def forward_backward_step( model_parts = self.model_parts parallel_dims = self.parallel_dims - # apply context parallelism if cp is enabled - # ensure CP handles the separate freqs_cis buffer for each pp stage inputs = input_dict["input"] - # Create the FlexAttention mask according to the input + extra_args = {} + if getattr(self.model_args, "use_flex_attn", False): - cp_mesh = ( - parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None + extra_args["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, ) - init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh) optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( @@ -187,11 +186,18 @@ def forward_backward_step( ) if self.pp_has_first_stage: self.pp_schedule.step( - inputs, target=targets, losses=losses, input_batch=inputs + inputs, + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) else: self.pp_schedule.step( - target=targets, losses=losses, input_batch=inputs + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) # accumulate losses across pipeline microbatches @@ -209,7 +215,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs) + pred = model_parts[0](inputs, **extra_args) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 18b2253b..c0607318 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -239,8 +239,8 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), + input_layouts=(Shard(1), None, None), + desired_input_layouts=(Replicate(), None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index 871c2072..93ff4e89 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -9,10 +9,20 @@ import torch import torch.nn.functional as F from torch import nn - -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + get_fixed_block_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import MoE -from torchtitan.protocols import ModelProtocol +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol from .args import RoPEScalingArgs, TransformerModelArgs @@ -192,9 +202,11 @@ def __init__( # values of these two variables. self.use_rope = use_rope - self.sdpa = build_attention( - model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size - ) + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -205,6 +217,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. @@ -239,7 +252,13 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv) + if self.use_flex_attn: + assert isinstance(attention_masks, dict), attention_masks + attention_mask = attention_masks["rope" if self.use_rope else "nope"] + output = self.inner_attention(xq, xk, xv, block_mask=attention_mask) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -372,6 +391,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -384,7 +404,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis) + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) if self.moe_enabled: out = h + self.moe(self.ffn_norm(h)) else: @@ -485,9 +505,40 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_scaling_args, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + B = input_batch.shape[0] + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + + rope_mask_mod = and_masks( + *mask_mods, + get_fixed_block_mask_mod(self.model_args.fixed_attn_block_size), + ) + nope_mask_mod = and_masks(*mask_mods) + + seqlen = input_batch.shape[1] + return { + "rope": create_attention_mask(rope_mask_mod, B, None, seqlen, seqlen), + "nope": create_attention_mask(nope_mask_mod, B, None, seqlen, seqlen), + } + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -511,7 +562,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/qwen3/model/model.py b/torchtitan/experiments/qwen3/model/model.py index f2a77e99..0fff490b 100644 --- a/torchtitan/experiments/qwen3/model/model.py +++ b/torchtitan/experiments/qwen3/model/model.py @@ -10,13 +10,23 @@ import torch import torch.nn.functional as F from torch import nn - -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import MoE +from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol from .args import Qwen3ModelArgs + # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py def precompute_rope_cache( dim: int, max_seq_len: int, base: float = 1_000_000.0 @@ -133,6 +143,7 @@ def __init__(self, model_args: Qwen3ModelArgs): self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.head_dim self.scaling = self.head_dim**-0.5 + self.use_flex_attn = getattr(model_args, "use_flex_attn", False) # RMSNorm added here to the here to include the q-k norm # This is one of the main differences between Llama3 and Qwen3 @@ -155,7 +166,11 @@ def __init__(self, model_args: Qwen3ModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -170,6 +185,7 @@ def forward( self, x: torch.Tensor, rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. @@ -210,7 +226,12 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv, scale=self.scaling) + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -308,6 +329,7 @@ def forward( self, x: torch.Tensor, rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -320,7 +342,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - x = x + self.attention(self.attention_norm(x), rope_cache) + x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) @@ -423,9 +445,31 @@ def _precompute_rope_cache(self) -> torch.Tensor: self.model_args.rope_theta, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -449,7 +493,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.rope_cache) + h = layer(h, self.rope_cache, attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 86c3faa5..19452ac1 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -35,7 +35,7 @@ def _get_dict(obj) -> dict[str, Any]: llama3_siglip2_configs = { "debugmodel": Llama3Siglip2ModelArgs( - **_get_dict(llama3_configs["debugmodel"]), + **_get_dict(llama3_configs["debugmodel_flex_attn"]), encoder=Siglip2ModelArgs( dim=128, ffn_dim=256, diff --git a/torchtitan/experiments/vlm/model/model.py b/torchtitan/experiments/vlm/model/model.py index 71c8a739..712cd805 100644 --- a/torchtitan/experiments/vlm/model/model.py +++ b/torchtitan/experiments/vlm/model/model.py @@ -7,8 +7,11 @@ import einops as E import torch from torch import nn +from torch.nn.attention.flex_attention import BlockMask +from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.llama3 import Transformer as Llama3 +from torchtitan.protocols.model import AttentionMasksType from ..datasets.mm_datasets import SpecialTokens @@ -71,28 +74,49 @@ def init_weights(self, buffer_device=None): if self.projector is not None: self.projector.init_weights() + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + masks = super().get_attention_masks(input_batch, tokenizer, extra_inputs) + assert isinstance(masks, BlockMask) + if self.encoder is not None: + encoder_masks = self.encoder.get_attention_masks( + input_batch, tokenizer, extra_inputs + ) + assert isinstance(encoder_masks, BlockMask) + return {"llama3_masks": masks, "encoder_masks": encoder_masks} + def forward( self, tokens: torch.Tensor, pixel_values: torch.Tensor, grid_thw: torch.Tensor, special_tokens: SpecialTokens, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens if self.encoder is not None: + assert ( + attention_masks is not None + ), "encoder only allows FlexAttention, so the llama3 must use FlexAttention as well." grid_hw = grid_thw[:, :, 1:] # Siglip2 only support image hw pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all") - i_NLD = self.encoder(pixel_values, pixel_masks, grid_hw) + i_NLD = self.encoder( + pixel_values, pixel_masks, grid_hw, attention_masks["encoder_masks"] + ) i_NLD = self.projector(i_NLD) h_BSD = _scatter_img_tokens( h_BSD, tokens, i_NLD, pixel_masks, special_tokens.img_id ) for layer in self.layers.values(): - h_BSD = layer(h_BSD, self.freqs_cis) + h_BSD = layer(h_BSD, self.freqs_cis, attention_masks["llama3_masks"]) h_BSD = self.norm(h_BSD) if self.norm else h_BSD output = self.output(h_BSD) if self.output else h_BSD diff --git a/torchtitan/experiments/vlm/model/siglip2.py b/torchtitan/experiments/vlm/model/siglip2.py index a1183f7c..69278350 100644 --- a/torchtitan/experiments/vlm/model/siglip2.py +++ b/torchtitan/experiments/vlm/model/siglip2.py @@ -8,8 +8,16 @@ import torch import torch.nn.functional as F from torch import nn +from torch.nn.attention.flex_attention import and_masks, BlockMask -from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, +) +from torchtitan.protocols.model import AttentionMasksType from .args import Siglip2ModelArgs @@ -125,11 +133,9 @@ def __init__(self, args: Siglip2ModelArgs): self.v_proj = nn.Linear(self.dim, self.dim) self.out_proj = nn.Linear(self.dim, self.dim) - self.attn = build_attention( - use_flex_attn=True, attn_mask_type=args.attn_mask_type - ) + self.inner_attention = FlexAttentionWrapper() - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, attention_masks: AttentionMasksType): xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) # Use self.head_dim instead of `n_heads` to infer the actual @@ -139,7 +145,8 @@ def forward(self, x: torch.Tensor): xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim) xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim) - output = self.attn(xq, xk, xv) + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) output = E.rearrange(output, "b h l d -> b l (h d)").contiguous() return self.out_proj(output) @@ -174,8 +181,10 @@ def __init__(self, args: Siglip2ModelArgs): self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) self.mlp = FeedForward(args) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.self_attn(self.layer_norm1(x)) + def forward( + self, x: torch.Tensor, attention_masks: AttentionMasksType + ) -> torch.Tensor: + x = x + self.self_attn(self.layer_norm1(x), attention_masks) x = x + self.mlp(self.layer_norm2(x)) return x @@ -198,18 +207,46 @@ def __init__(self, args: Siglip2ModelArgs): ) self.post_layernorm = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + + # TODO: this is duplicated in the main model forward. + # TODO: is this really required? Can we call this `get_attention_masks` + # inside the main model forward? At that time PP should already split the + # grid_thw correctly. + grid_hw = extra_inputs["grid_thw"][:, :, 1:] # Siglip2 only support image hw + pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all") + + mask_mods = [get_causal_mask_mod()] + match self.args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = pixel_masks.shape[0] + mask_mods.append(get_document_mask_mod(pixel_masks, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, pixel_masks.shape[1], pixel_masks.shape[1] + ) + def forward( self, pixel_values_NLD: torch.FloatTensor, pixel_masks_NL: torch.BoolTensor, grid_hw: torch.LongTensor, + attention_masks: AttentionMasksType, ): - init_attention_mask(pixel_masks_NL, eos_id=self.eos_id) - h = self.embeddings(pixel_values_NLD, grid_hw) for layer in self.layers.values(): - h = layer(h) + h = layer(h, attention_masks) h = self.post_layernorm(h) return h diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 277d64be..bf963a5b 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -6,238 +6,182 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from typing import Callable, ClassVar +import functools +from collections.abc import Callable +from typing import ClassVar import torch import torch.nn.functional as F from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( _mask_mod_signature, + AuxOutput, BlockMask, create_block_mask, flex_attention, ) -from torchtitan.tools.utils import has_cuda_capability -# FlexAttention mask type. For each mask type, we initialize it at most once per -# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to -# track the initialized mask. -FLEX_ATTN_MASK_T = tuple[str, int | None] +__all__ = [ + "FlexAttentionWrapper", + "ScaledDotProductAttentionWrapper", + "get_causal_mask_mod", + "get_document_mask_mod", + "get_fixed_block_mask_mod", + "create_attention_mask", +] -class FlexAttention(torch.nn.Module): - """FlexAttention module that uses torch.nn.attention.flex_attention. +class FlexAttentionWrapper(torch.nn.Module): + """Wrapper around `flex_attention` to make it torch.compile and CP compatible. - This module is a wrapper around torch.nn.attention.flex_attention. This module - implements certain common attention types, such as causal and block_causal. + This wrapper serves two purposes: + 1) Invoke `torch.compile` with a valid mode "max-autotune-no-cudagraphs" to + achieve good performance. + 2) Being a wrapper allows us to apply _ContextParallel to it. - Args: - attn_mask_type (str): The type of attention mask. Currently, we support - "causal" and "block_causal". "causal" means the lower triangle of the - attention matrix is masked. "block_causal" means the attention matrix - is divided into blocks, where block boundary is defined by EOS token, - and the lower triangle of each block is masked. - fixed_block_size (int | None): The block size to be used to perform attention. - If specified, each sequence will be further divided to blocks, where each - block has the maximum size of ``fixed_block_size``. A query will only attend - to the keys within the same block. + Note: + The forward function must have q, k, v as the first three arguments, and + block_mask as a keyword argument to be compatible with _ContextParallel. """ - # We registered flex_attention related attributes as class variables as we - # need to amortize the cost of compilation. - flex_attn: ClassVar[Callable] = torch.compile( + _compiled_flex_attn: ClassVar[Callable] = torch.compile( flex_attention, mode="max-autotune-no-cudagraphs" ) - compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask) - used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set() - # Attention mask type to the created BlockMask. - # This allows us to keep track the created block masks for each - # new batch. We will use this to update the block mask when a - # new batch is created. This also allows user to create different - # block masks for different layers. - block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {} - - # Instance variables. - attn_mask_type: str - - def __init__( - self, attn_mask_type: str, fixed_block_size: int | None = None - ) -> None: - super().__init__() - if attn_mask_type not in ["causal", "block_causal"]: - raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") - self.attn_mask_type = attn_mask_type - self.fixed_block_size = fixed_block_size - - FlexAttention.used_attn_mask_types.add(self.mask_key) - - @property - def mask_key(self) -> FLEX_ATTN_MASK_T: - return (self.attn_mask_type, self.fixed_block_size) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *, + block_mask: BlockMask, scale: float | None = None, - ) -> torch.Tensor: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) - - @staticmethod - def _get_causal_mask_mod() -> _mask_mod_signature: - def causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - return q_idx >= kv_idx - - return causal_mask - - @staticmethod - def _get_block_causal_mask_mod( - batch: torch.Tensor, eos_id: int - ) -> _mask_mod_signature: - # batch is [b, s, h, d] shape - mask = batch == eos_id - mask[:, -1] = True - acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1) - seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32) - seq_idx[:, 1:] = acc_mask[:, :-1] - - def block_causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx) - - return block_causal_mask - - @staticmethod - def _fixed_block_mask_mod( - mask_mod: _mask_mod_signature, fixed_block_size: int - ) -> _mask_mod_signature: - """ - Given an arbitrary mask_mod, divide the input sequence to blocks - and only allow attention within the same block. - - Args: - mask_mod: The mask mod to apply to the documents - fixed_block_size: The number of tokens in each block. - """ - - # Credit to @drisspg. - def blocked_mask_mod( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - # Get the block index of the query and key - q_block = q_idx // fixed_block_size - kv_block = kv_idx // fixed_block_size - # Only allow attention within the same block - same_block = q_block == kv_block - # Apply the original mask mod - inner_mask = mask_mod( - b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size - ) - - return same_block & inner_mask - - blocked_mask_mod.__name__ = ( - f"blocked_mask_mod_{mask_mod.__name__}_fixed_block_size_{fixed_block_size}" + ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]: + # 1. _compiled_flex_attn has to be a class variable, otherwise there will + # be multiple compiled flex_attention instances, which can be slow. + # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in + # as the first argument, which will cause an error. + # `FlexAttentionWrapper._compiled_flex_attn` is correct. + return FlexAttentionWrapper._compiled_flex_attn( + q, k, v, block_mask=block_mask, scale=scale ) - return blocked_mask_mod - - @staticmethod - @torch.no_grad() - def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: - # batch is [b, s, h, d] shape - for mask_key in FlexAttention.used_attn_mask_types: - attn_mask_type, fixed_block_size = mask_key - match attn_mask_type: - case "causal": - if FlexAttention.block_masks.get(mask_key, None) is not None: - continue - # We don't care about batch dimension -- - # all samples have the same lower triangle mask. - batch_dimension = 1 - mask_mod = FlexAttention._get_causal_mask_mod() - case "block_causal": - if eos_id is None: - raise RuntimeError( - "eos_id must be provided for block_causal mask." - ) - batch_dimension = batch.shape[0] - mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id) - case _: - raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}") - - if fixed_block_size is not None and fixed_block_size > 0: - mask_mod = FlexAttention._fixed_block_mask_mod( - mask_mod, fixed_block_size - ) - - seq_len = batch.shape[1] - block_mask = FlexAttention.compiled_create_block_mask( - mask_mod, batch_dimension, None, seq_len, seq_len - ) - FlexAttention.block_masks[mask_key] = block_mask - - -class ScaledDotProductAttention(torch.nn.Module): - backends: ClassVar[list[SDPBackend]] = [] - - def __init__(self, attn_mask_type: str) -> None: + +class ScaledDotProductAttentionWrapper(torch.nn.Module): + """Wrapper around `F.scaled_dot_product_attention` to make it CP compatible. + + This wrapper is needed because `F.scaled_dot_product_attention` is not + a torch.nn.Module, and thus cannot be applied with _ContextParallel. + We need to wrap it into a torch.nn.Module. + + Note: + The forward function must have q, k, v as the first three arguments to be + compatible with _ContextParallel. + """ + + # TODO: remove sdpa_backends after PyTorch 2.9 is released. + sdpa_backends: ClassVar[list[SDPBackend]] = [] + + def __init__(self) -> None: super().__init__() - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - - ScaledDotProductAttention._init_backend() - - @classmethod - def _init_backend(cls) -> None: - if cls.backends: - return - - # Add CuDNN on B200 w/ highest priority - cls.backends = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, - ] - if has_cuda_capability(10, 0): - cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) + if not self.sdpa_backends: + self.sdpa_backends = [ + SDPBackend.CUDNN_ATTENTION, + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + ] def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *, scale: float | None = None, ) -> torch.Tensor: - assert self.backends, "SDPA Backends should not be empty." - with sdpa_kernel(self.backends, set_priority=True): - return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale) - - -def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None -): - if use_flex_attn: - return FlexAttention(attn_mask_type, fixed_block_size) - else: - if fixed_block_size is not None: - raise ValueError( - "TorchTitan with SDPA currently does not support fixed_block_size." - ) - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - return ScaledDotProductAttention(attn_mask_type) - - -def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: - FlexAttention.init_attention_mask(batch, eos_id) + with sdpa_kernel(self.sdpa_backends, set_priority=True): + return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True) + + +# We cannot do inner function/closure because we won't be able to cache it -- +# if we an inner function, a new closure will be created every time +# `get_causal_mask_mod` is called. +def _causal_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor +) -> torch.Tensor: + """Causal mask that prevents attention to future tokens.""" + return q_idx >= kv_idx + + +def get_causal_mask_mod() -> _mask_mod_signature: + """Returns a causal mask modifier for flex attention. + + Returns: + A mask modifier function that implements causal masking. + """ + return _causal_mask + + +def get_document_mask_mod(batch: torch.Tensor, eos_id: int) -> _mask_mod_signature: + """Creates a document mask that prevents attention across document boundaries. + + Args: + batch: Input batch tensor with shape [b, s, h, d] + eos_id: End-of-sequence token ID that marks document boundaries + + Returns: + A mask modifier function that implements document-level masking. + """ + # batch is [b, s, h, d] shape + eos_mask = batch == eos_id + eos_mask[:, -1] = True + cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1) + sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32) + sequence_indices[:, 1:] = cumulative_mask[:, :-1] + + def document_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] + + return document_mask + + +def get_fixed_block_mask_mod(fixed_block_size: int) -> _mask_mod_signature: + """ + Divide the input sequence into blocks and only allow attention within the same block. + + Args: + fixed_block_size: The number of tokens in each block. + + Returns: + A mask modifier function that implements block-wise attention masking. + """ + + # Credit to @drisspg. + def blocked_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + # Get the block index of the query and key + q_block = q_idx // fixed_block_size + kv_block = kv_idx // fixed_block_size + # Only allow attention within the same block + return q_block == kv_block + + blocked_mask_mod.__name__ = f"blocked_mask_mod_fixed_block_size_{fixed_block_size}" + + return blocked_mask_mod + + +_compiled_create_block_mask = torch.compile(create_block_mask) + + +@functools.lru_cache(4) +def create_attention_mask(*args, **kwargs): + """Create an attention mask using compiled create_block_mask. + + This function is cached to avoid recreating BlockMasks for the same + argumens. + """ + return _compiled_create_block_mask(*args, **kwargs) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 1c73cef7..fc79e5ba 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -84,6 +84,7 @@ def parallelize_deepseekv3( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, + use_flex_attn=use_flex_attn, ) maybe_enable_async_tp(job_config, world_mesh["tp"]) @@ -181,6 +182,7 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, + use_flex_attn: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -210,6 +212,18 @@ def apply_non_moe_tp( PrepareModuleInput, ) + if use_flex_attn: + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) + else: + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. @@ -218,8 +232,8 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), + input_layouts=(Shard(1), Replicate(), None), + desired_input_layouts=(Replicate(), Replicate(), None), ), # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor # so that the intermedidate results k is generated as a DTensor and its gradient is @@ -228,11 +242,7 @@ def apply_non_moe_tp( "attention.wkv_b": colwise_parallel(use_local_output=False), "attention.kv_norm": NoParallel(use_local_output=False), # NOTE: use_local_output=True so that the inputs to FlexAttention are plain Tensors - "attention.sdpa": prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ), + "attention.inner_attention": attention_kernel_plan, "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index dc612faf..d5bc9b10 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -5,13 +5,22 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Tuple import torch from torch import nn -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import FeedForward, MoE +from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs @@ -58,7 +67,7 @@ def find_correction_dim( def find_correction_range( low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """ Computes the range of correction dimensions for rotary positional embeddings. @@ -70,7 +79,7 @@ def find_correction_range( max_seq_len (int): Maximum sequence length. Returns: - Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. """ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) @@ -175,12 +184,17 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 self.softmax_scale = self.softmax_scale * mscale * mscale - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass for the Multi-Head Latent Attention (MLA) Layer. @@ -231,7 +245,14 @@ def forward( k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) - output = self.sdpa(q, k, v, scale=self.softmax_scale) + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention( + q, k, v, block_mask=attention_masks, scale=self.softmax_scale + ) + else: + assert attention_masks is None + output = self.inner_attention(q, k, v, scale=self.softmax_scale) # Reshape and project output output = output.transpose( @@ -284,7 +305,12 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 self.layer_id = layer_id - def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): """ Forward pass for the Transformer block. @@ -295,7 +321,7 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): Returns: torch.Tensor: Output tensor with the same shape as the input. """ - x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) else: @@ -360,9 +386,31 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: b=cutoff_factor * final_out_std, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -385,7 +433,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 89066e86..4944af56 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -207,8 +207,8 @@ def apply_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), + input_layouts=(Shard(1), None, None), + desired_input_layouts=(Replicate(), None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 753ffae0..6f10719d 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -11,8 +11,17 @@ import torch import torch.nn.functional as F from torch import nn - -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) +from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol from .args import RoPEScalingArgs, TransformerModelArgs @@ -181,7 +190,12 @@ def __init__(self, model_args: TransformerModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -192,6 +206,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. @@ -225,7 +240,16 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv) + assert ( + isinstance(attention_masks, BlockMask) or attention_masks is None + ), attention_masks + + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -321,6 +345,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -333,7 +358,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis) + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -428,9 +453,31 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_scaling_args, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -454,7 +501,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks=attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index a4f28bc8..a713bec6 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -11,9 +11,16 @@ import torch import torch.nn as nn +from torch.nn.attention.flex_attention import BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer + from torchtitan.config import JobConfig +AttentionMasksType = dict[str, BlockMask] | BlockMask + + @dataclass class BaseModelArgs: """All ModelArgs should inherit from this class. @@ -53,3 +60,13 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: buffer_device: Optional device to place buffers on during initialization. """ pass + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + raise NotImplementedError( + "This model does not support attention masking/Flex Attention." + ) diff --git a/torchtitan/train.py b/torchtitan/train.py index 287828d8..6441ff0b 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -24,7 +24,6 @@ ) from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils -from torchtitan.models.attention import init_attention_mask from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -416,12 +415,21 @@ def forward_backward_step( inputs = input_dict["input"] extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} - # Create the FlexAttention mask according to the input + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_args are. + extra_args = {} + if getattr(self.model_args, "use_flex_attn", False): - init_attention_mask(inputs, self.tokenizer.eos_id) + extra_args["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage + cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], @@ -444,13 +452,17 @@ def forward_backward_step( self.pp_schedule.step( inputs, **extra_inputs, + **extra_args, target=targets, losses=losses, input_batch=inputs, ) else: self.pp_schedule.step( - target=targets, losses=losses, input_batch=inputs + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) # accumulate losses across pipeline microbatches @@ -468,7 +480,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, **extra_inputs) + pred = model_parts[0](inputs, **extra_inputs, **extra_args) loss = self.loss_fn(pred, labels) # need to free pred before bwd to avoid peaking memory del pred From 2e32a67438e48feca8595f7a8b5558266e2d5b45 Mon Sep 17 00:00:00 2001 From: tomiock Date: Fri, 10 Oct 2025 10:52:27 +0000 Subject: [PATCH 14/18] sync --- run_generate.sh | 26 ++ run_generate_llama3.sh | 24 + torchtitan/generate.py | 67 ++- torchtitan/generate_llama3.py | 438 ++++++++++++++++++ torchtitan/generate_simple.py | 346 ++++++++++++++ torchtitan/models/attention.py | 338 ++++++-------- .../llama3/train_configs/llama3_1b.toml | 4 +- torchtitan/vlr/smolvlm/__init__.py | 2 + torchtitan/vlr/smolvlm/model/model.py | 8 + torchtitan/vlr/smolvlm/model/siglip2.py | 10 +- .../train_configs/llama_siglip_256.toml | 2 +- 11 files changed, 1024 insertions(+), 241 deletions(-) create mode 100755 run_generate.sh create mode 100755 run_generate_llama3.sh create mode 100644 torchtitan/generate_llama3.py create mode 100644 torchtitan/generate_simple.py diff --git a/run_generate.sh b/run_generate.sh new file mode 100755 index 00000000..6dee6461 --- /dev/null +++ b/run_generate.sh @@ -0,0 +1,26 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overwrites for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./run_train.sh +NGPU=${NGPU:-"1"} +export LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml"} +INFERENCE_FILE=${INFERENCE_FILE:-"torchtitan.generate"} + + +CUDA_VISIBLE_DEVICES=2 \ +NCCL_P2P_DISABLE=1 \ +TORCH_NCCL_DUMP_ON_TIMEOUT=1 \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +-m ${INFERENCE_FILE} --job.config_file ${CONFIG_FILE} "$@" \ +--checkpoint.exclude-from-loading dataloader,optimizer,lr_scheduler + diff --git a/run_generate_llama3.sh b/run_generate_llama3.sh new file mode 100755 index 00000000..c36829d2 --- /dev/null +++ b/run_generate_llama3.sh @@ -0,0 +1,24 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overwrites for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./run_train.sh +NGPU=${NGPU:-"1"} +export LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_1b.toml"} +INFERENCE_FILE=${INFERENCE_FILE:-"torchtitan.generate_llama3"} + + +CUDA_VISIBLE_DEVICES=2 \ +NCCL_P2P_DISABLE=1 \ +TORCH_NCCL_DUMP_ON_TIMEOUT=1 \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +-m ${INFERENCE_FILE} --job.config_file ${CONFIG_FILE} "$@" diff --git a/torchtitan/generate.py b/torchtitan/generate.py index f2e7f2a3..f7462743 100644 --- a/torchtitan/generate.py +++ b/torchtitan/generate.py @@ -24,6 +24,23 @@ from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger +def multinomial_sample_one(probs: torch.Tensor, rng = None): + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs/q, dim=-1, keepdim=True).to(dtype=torch.long) + +def logits_to_probs( + logits, + temperature, + top_k +): + logits = logits / max(temperature, 1e-5) + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs class Generator: """Generator class for SmolVLM model inference.""" @@ -173,6 +190,7 @@ def generate( top_p: Optional[float] = None, top_k: Optional[int] = None, do_sample: bool = True, + rng = None, ) -> str: """Generate text from messages and optional images. @@ -247,22 +265,22 @@ def generate( generate_fn = torch.compile(generate_fn, mode="reduce-overhead") # Generate tokens - with torch.cuda.amp.autocast(dtype=torch.bfloat16): + with torch.amp.autocast('cuda', dtype=torch.bfloat16): output_ids = generate_fn( model=model, input_ids=input_ids, - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, + pixel_values=None, + patch_attention_mask=None, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, do_sample=do_sample, + rng=rng, ) # Decode output generated_ids = output_ids[0, input_ids.shape[1]:] - print(generated_ids.v) generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) return generated_text @@ -278,6 +296,7 @@ def _generate_tokens( top_p: float, top_k: int, do_sample: bool, + rng = None, ) -> torch.Tensor: """Core generation loop.""" @@ -306,38 +325,8 @@ def _generate_tokens( logits = model(**input_dict) # Get next token logits - next_token_logits = logits[:, -1, :] - - # Apply temperature - if temperature > 0: - next_token_logits = next_token_logits / temperature - - # Sample or greedy decode - if do_sample: - # Apply top-k filtering - if top_k > 0: - indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] - next_token_logits[indices_to_remove] = -float('inf') - - # Apply top-p filtering - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - next_token_logits[indices_to_remove] = -float('inf') - - # Sample - probs = torch.softmax(next_token_logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - else: - # Greedy decoding - next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) # Append to generated sequence generated_ids = torch.cat([generated_ids, next_token], dim=1) @@ -411,6 +400,10 @@ def batch_generate(self, input_file: str, output_file: str): with open(input_file, 'r') as f: inputs = json.load(f) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(42) results = [] for i, item in enumerate(inputs): @@ -427,7 +420,7 @@ def batch_generate(self, input_file: str, output_file: str): images.append(image) # Generate response - response = self.generate(messages, images=images if images else None) + response = self.generate(messages, images=images if images else None, rng=rng) results.append({ 'input': item, diff --git a/torchtitan/generate_llama3.py b/torchtitan/generate_llama3.py new file mode 100644 index 00000000..083b34ce --- /dev/null +++ b/torchtitan/generate_llama3.py @@ -0,0 +1,438 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os +import time +from typing import Optional, List, Dict, Any +import numpy as np + +import torch +from torch.distributed.elastic.multiprocessing.errors import record +from transformers import AutoProcessor + +import torchtitan.protocols.train_spec as train_spec_module +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.tokenizer import HuggingFaceTokenizer +from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.protocols.model_converter import build_model_converters +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger + + +class Generator: + """Generator class for SmolVLM model inference.""" + + def __init__(self, job_config: JobConfig): + torch._C._log_api_usage_once("torchtitan.generate") + + self.job_config = job_config + + logger.info(f"Starting generation: {job_config.job.description}") + + if job_config.experimental.custom_import: + importlib.import_module(job_config.experimental.custom_import) + + if job_config.job.print_args: + logger.info(f"Running with args: {job_config.to_dict()}") + + device_module, device_type = utils.device_module, utils.device_type + self.device = torch.device(f"{device_type}:{int(os.environ.get('LOCAL_RANK', 0))}") + device_module.set_device(self.device) + + # Initialize distributed + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=False, + base_folder=job_config.job.dump_folder, + ) + + world_size = int(os.environ.get("WORLD_SIZE", 1)) + parallelism_config = job_config.parallelism + self.parallel_dims = parallel_dims = ParallelDims( + dp_shard=parallelism_config.data_parallel_shard_degree, + dp_replicate=parallelism_config.data_parallel_replicate_degree, + cp=parallelism_config.context_parallel_degree, + tp=parallelism_config.tensor_parallel_degree, + pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, + world_size=world_size, + ) + + world_mesh = parallel_dims.world_mesh + + # Set random seed + dist_utils.set_determinism( + world_mesh, + self.device, + job_config.training.seed, + deterministic=False, + ) + + self.train_spec = train_spec_module.get_train_spec(job_config.model.name) + + # Build tokenizer + self.tokenizer = ( + self.train_spec.build_tokenizer_fn(job_config) + if self.train_spec.build_tokenizer_fn is not None + else None + ) + + # Build model + model_args = self.train_spec.model_args[job_config.model.flavor] + model_args.update_from_config(job_config) + self.model_args = model_args + + logger.info( + f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" + ) + + with ( + torch.device("meta"), + utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), + ): + model = self.train_spec.model_cls(model_args) + + # Build model converters (e.g., for float8) + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) + + # Apply parallelism + if parallel_dims.pp_enabled: + raise NotImplementedError("Pipeline parallelism not supported for generation") + else: + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) + + # Move to device and initialize + init_device = self.device.type + model.to_empty(device=init_device) + with torch.no_grad(): + model.init_weights() + model.eval() + + self.model_parts = [model] + + # Setup checkpoint manager for loading + self.checkpointer = CheckpointManager( + dataloader=None, # No dataloader needed for generation + model_parts=self.model_parts, + optimizers=None, # No optimizer needed for generation + lr_schedulers=None, # No lr_scheduler needed for generation + states={}, + checkpoint_config=job_config.checkpoint, + sd_adapter=( + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) + if self.train_spec.state_dict_adapter + else None + ), + base_folder=job_config.job.dump_folder, + ft_manager=None, # No fault tolerance for generation + ) + + # Load checkpoint + self.checkpointer.load(step=job_config.checkpoint.load_step) + logger.info(f"Loaded checkpoint from step {job_config.checkpoint.load_step}") + + self.processor = AutoProcessor.from_pretrained(job_config.model.hf_assets_path) + + # Load chat template + template_path = "torchtitan/vlr/smolvlm/datasets/template.jinja" + if os.path.exists(template_path): + with open(template_path, 'r') as f: + self.chat_template = f.read() + else: + logger.warning(f"Chat template not found at {template_path}, using default") + self.chat_template = None + + # Setup generation parameters + self.max_new_tokens = getattr(job_config, 'max_new_tokens', 256) + self.temperature = getattr(job_config, 'temperature', 0.7) + self.top_p = getattr(job_config, 'top_p', 0.9) + self.top_k = getattr(job_config, 'top_k', 50) + + logger.info("Generator initialized successfully") + + @torch.no_grad() + def generate( + self, + messages: List[Dict[str, Any]], + max_new_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + do_sample: bool = True, + ) -> str: + """Generate text from messages. + + Args: + messages: List of message dictionaries with 'role' and 'content' + max_new_tokens: Maximum number of tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling parameter + top_k: Top-k sampling parameter + do_sample: Whether to use sampling or greedy decoding + + Returns: + Generated text string + """ + max_new_tokens = max_new_tokens or self.max_new_tokens + temperature = temperature or self.temperature + top_p = top_p or self.top_p + top_k = top_k or self.top_k + + model = self.model_parts[0] + model.eval() + + # Tokenize input + if self.chat_template: + input_ids = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + chat_template=self.chat_template, + add_generation_prompt=True, + return_tensors="pt", + ) + else: + # Fallback to default chat template + input_ids = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + ) + + if isinstance(input_ids, dict): + input_ids = input_ids["input_ids"] + + input_ids = input_ids.to(self.device) + + # Setup generation context (compile if enabled) + generate_fn = self._generate_tokens + if self.job_config.compile.enable and "model" in self.job_config.compile.components: + logger.info('Compiling model...') + generate_fn = torch.compile(generate_fn, mode="reduce-overhead") + + # Generate tokens + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + output_ids = generate_fn( + model=model, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + do_sample=do_sample, + ) + + # Decode output + generated_ids = output_ids[0, input_ids.shape[1]:] + generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) + + return generated_text + + def _generate_tokens( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + max_new_tokens: int, + temperature: float, + top_p: float, + top_k: int, + do_sample: bool, + ) -> torch.Tensor: + """Core generation loop.""" + + batch_size = input_ids.shape[0] + generated_ids = input_ids.clone() + + # Cache for key-value pairs (if using KV cache in the future) + past_key_values = None + + for _ in range(max_new_tokens): + # Forward pass + with torch.no_grad(): + # Prepare input dict + input_dict = { + "tokens": generated_ids, + } + + # Get model output + logits = model(**input_dict) + + # Get next token logits + next_token_logits = logits[:, -1, :] + + # Apply temperature + if temperature > 0: + next_token_logits = next_token_logits / temperature + + # Sample or greedy decode + if False: + # Apply top-k filtering + if top_k > 0: + indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] + next_token_logits[indices_to_remove] = -float('inf') + + # Apply top-p filtering + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + next_token_logits[indices_to_remove] = -float('inf') + + # Sample + probs = torch.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + # Append to generated sequence + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + # Check for EOS token + if (next_token == self.tokenizer.eos_token): + break + + return generated_ids + + def interactive_generate(self): + """Interactive generation mode for testing.""" + logger.info("Starting interactive generation mode. Type 'quit' to exit.") + + while True: + try: + user_input = input("\nEnter your prompt (or 'quit' to exit): ").strip() + + if user_input.lower() == 'quit': + break + + + # Create message format + messages = [ + { + "user": user_input, + "assistant": "" # Will be filled by generation + } + ] + + logger.info("Generating response...") + start_time = time.perf_counter() + + response = self.generate(messages) + + generation_time = time.perf_counter() - start_time + logger.info(f"Generation completed in {generation_time:.2f}s") + + print(f"\nGenerated response:\n{response}") + + except KeyboardInterrupt: + logger.info("\nInterrupted by user") + break + except Exception as e: + logger.error(f"Error during generation: {e}") + import traceback + traceback.print_exc() + + def batch_generate(self, input_file: str, output_file: str): + """Generate responses for a batch of inputs from a file. + + Args: + input_file: Path to JSON file with inputs + output_file: Path to save outputs + """ + import json + + logger.info(f"Loading inputs from {input_file}") + + with open(input_file, 'r') as f: + inputs = json.load(f) + + results = [] + for i, item in enumerate(inputs): + logger.info(f"Processing item {i+1}/{len(inputs)}") + + messages = item.get('messages', []) + + # Generate response + response = self.generate(messages) + + results.append({ + 'input': item, + 'output': response + }) + + # Save results + with open(output_file, 'w') as f: + json.dump(results, f, indent=2) + + logger.info(f"Results saved to {output_file}") + + def close(self): + """Cleanup resources.""" + if hasattr(self, 'checkpointer'): + self.checkpointer.close() + logger.info("Generator closed") + + +@record +def main(): + """Main entry point for generation.""" + init_logger() + + # Parse configuration + config_manager = ConfigManager() + config = config_manager.parse_args() + + generator = None + try: + # Initialize generator + generator = Generator(config) + + # Check for generation mode from config or command line + generation_mode = getattr(config, 'generation_mode', 'interactive') + + if generation_mode == 'interactive': + generator.interactive_generate() + elif generation_mode == 'batch': + input_file = getattr(config, 'input_file', 'inputs.json') + output_file = getattr(config, 'output_file', 'outputs.json') + generator.batch_generate(input_file, output_file) + else: + # Single generation example + messages = [ + { + "user": "What is the capital of France?", + "assistant": "" + } + ] + response = generator.generate(messages) + logger.info(f"Generated: {response}") + + except Exception as e: + logger.error(f"Error during generation: {e}") + if generator: + generator.close() + raise + else: + if generator: + generator.close() + torch.distributed.destroy_process_group() + logger.info("Process group destroyed") + + +if __name__ == "__main__": + main() diff --git a/torchtitan/generate_simple.py b/torchtitan/generate_simple.py new file mode 100644 index 00000000..f21021f9 --- /dev/null +++ b/torchtitan/generate_simple.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import time +from pathlib import Path +from typing import Optional + +import torch +import numpy as np +from PIL import Image +from transformers import AutoProcessor + +from torchtitan.config import JobConfig, ConfigManager, TORCH_DTYPE_MAP +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.components.tokenizer import HuggingFaceTokenizer +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.distributed import ParallelDims, utils as dist_utils + +# Import SmolVLM specific components +from torchtitan.vlr.smolvlm.model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs +from torchtitan.vlr.smolvlm.model.model import Llama3Siglip2Transformer +from torchtitan.vlr.smolvlm.model.state_dict_adapter import SmolVLMStateDictAdapter + + +class SimpleGenerator: + """Barebones generator for debugging using CheckpointManager.""" + + def __init__(self, job_config: JobConfig): + self.job_config = job_config + + # Setup device + device_module, device_type = utils.device_module, utils.device_type + self.device = torch.device(f"{device_type}:{int(os.environ.get('LOCAL_RANK', 0))}") + device_module.set_device(self.device) + + logger.info(f"Device: {self.device}") + + # Init distributed (needed for checkpoint loading) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if world_size > 1 or int(os.environ.get("RANK", 0)) >= 0: + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=False, + base_folder=job_config.job.dump_folder, + ) + + # Setup parallel dims (minimal - no parallelism for inference) + self.parallel_dims = ParallelDims( + dp_shard=1, + dp_replicate=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=world_size, + ) + + # Load tokenizer using model's hf_assets_path + tokenizer_path = job_config.model.hf_assets_path + self.tokenizer = HuggingFaceTokenizer(tokenizer_path) + self.tokenizer.image_id = job_config.special_tokens.img_id + + logger.info(f"Tokenizer loaded from: {tokenizer_path}") + logger.info(f"Vocab size: {len(self.tokenizer)}") + logger.info(f"Special tokens - BOS: {self.tokenizer.bos_id}, EOS: {self.tokenizer.eos_id}, PAD: {self.tokenizer.pad_id}") + logger.info(f"Image token ID: {self.tokenizer.image_id}") + + # Load image processor + processor = AutoProcessor.from_pretrained(tokenizer_path) + self.image_processor = processor.image_processor + + # Load chat template + template_path = Path("torchtitan/vlr/smolvlm/datasets/template.jinja") + if template_path.exists(): + with open(template_path, 'r') as f: + self.chat_template = f.read() + logger.info("Chat template loaded") + else: + logger.warning(f"Template not found at {template_path}") + self.chat_template = None + + # Build model + self.model_args = self._get_model_args() + self.model = self._build_model() + + # Load checkpoint using CheckpointManager + self._load_checkpoint() + + self.model.eval() + logger.info("Model loaded and ready") + + def _get_model_args(self): + """Get model args from job config.""" + from torchtitan.protocols import train_spec as train_spec_module + + train_spec = train_spec_module.get_train_spec(self.job_config.model.name) + model_args = train_spec.model_args[self.job_config.model.flavor] + model_args.update_from_config(self.job_config) + + # Override for inference + model_args.use_flex_attn = False + model_args.encoder.use_flex_attn = False + + logger.info(f"Model args: {model_args}") + return model_args + + def _build_model(self): + """Build model using torchtitan's approach.""" + logger.info(f"Building {self.job_config.model.name} {self.job_config.model.flavor}") + + dtype = TORCH_DTYPE_MAP[self.job_config.training.dtype] + + with torch.device("meta"), utils.set_default_dtype(dtype): + model = Llama3Siglip2Transformer(self.model_args) + + # Initialize on device + device_type = utils.device_type + model.to_empty(device=device_type) + with torch.no_grad(): + model.init_encoder_weights(buffer_device=device_type) + + logger.info("Model structure created") + return model + + def _load_checkpoint(self): + """Load checkpoint using CheckpointManager.""" + logger.info("Setting up CheckpointManager") + + # Create state dict adapter if available + sd_adapter = SmolVLMStateDictAdapter( + self.model_args, + self.job_config.model.hf_assets_path + ) + + # Create checkpoint manager + self.checkpointer = CheckpointManager( + dataloader=None, # Not needed for inference + model_parts=[self.model], + optimizers=None, # Not needed for inference + lr_schedulers=None, # Not needed for inference + states={}, # No training state needed + checkpoint_config=self.job_config.checkpoint, + sd_adapter=sd_adapter, + base_folder=self.job_config.job.dump_folder, + ft_manager=None, + ) + + # Load checkpoint + load_step = self.job_config.checkpoint.load_step + logger.info(f"Loading checkpoint at step: {load_step}") + self.checkpointer.load(step=load_step) + logger.info("Checkpoint loaded successfully") + + def prepare_inputs(self, prompt: str, image_path: Optional[str] = None): + """Prepare inputs - debug version.""" + + # Create messages + messages = [{"user": prompt, "assistant": ""}] + + # Apply chat template (without tokenizing first) + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + chat_template=self.chat_template, + add_generation_prompt=True, + ) + + print("\n" + "="*80) + print("FORMATTED TEXT:") + print(repr(text)) + print("="*80) + + # Tokenize + input_ids = self.tokenizer.encode(text) + print(f"\nInput tokens ({len(input_ids)}): {input_ids[:50]}...") + + # Decode to verify + decoded = self.tokenizer.decode(input_ids) + print(f"\nDecoded input:\n{repr(decoded[:200])}...") + + input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).unsqueeze(0) + + # Process image + pixel_values = None + patch_attention_mask = None + + if image_path: + image = Image.open(image_path).resize((512, 512)) + vision_inputs = self.image_processor([image]) + pixel_values = torch.tensor(np.array(vision_inputs['pixel_values'])).squeeze() + pixel_values = pixel_values.unsqueeze(0).unsqueeze(0).to(self.device, dtype=torch.bfloat16) + + patch_attention_mask = torch.tensor(vision_inputs['pixel_attention_mask']) + patch_attention_mask = patch_attention_mask.unsqueeze(0).unsqueeze(0).to(self.device) + + print(f"\nImage processed. Pixel values shape: {pixel_values.shape}") + + return input_ids, pixel_values, patch_attention_mask + + @torch.no_grad() + def generate_greedy(self, prompt: str, image_path: Optional[str] = None, max_tokens: int = 50): + """Greedy generation with detailed logging.""" + + print("\n" + "="*80) + print("STARTING GENERATION") + print("="*80) + + input_ids, pixel_values, patch_attention_mask = self.prepare_inputs(prompt, image_path) + + print(f"\nInitial input_ids shape: {input_ids.shape}") + print(f"Starting generation loop...\n") + + generated = input_ids.clone() + + for step in range(max_tokens): + # Forward pass + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): + logits = self.model( + input_ids=generated, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Get next token (greedy) + next_token_logits = logits[:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + # Decode the token + token_text = self.tokenizer.decode([next_token.item()]) + + print(f"Step {step:3d} | Token: {next_token.item():5d} | Text: {repr(token_text)}") + + # Check for EOS + if next_token.item() == self.tokenizer.eos_id: + print("\n*** EOS token generated ***") + break + + # Append + generated = torch.cat([generated, next_token], dim=1) + + # Check for repetition + if step > 5: + last_tokens = generated[0, -6:].tolist() + if len(set(last_tokens)) <= 2: + print(f"\n*** WARNING: Repetition detected in last 6 tokens: {last_tokens} ***") + + print("\n" + "="*80) + print("GENERATION COMPLETE") + print("="*80) + + # Decode full response + generated_ids = generated[0].tolist() + full_text = self.tokenizer.decode(generated_ids) + + print(f"\nGenerated tokens: {generated_ids}") + print(f"\nFull decoded text:\n{full_text}") + + # Try to extract assistant response + if "<|im_start|>assistant" in full_text: + response = full_text.split("<|im_start|>assistant")[-1] + if "<|im_end|>" in response: + response = response.split("<|im_end|>")[0] + response = response.strip() + print(f"\nExtracted assistant response:\n{response}") + return response + + return full_text + + @torch.no_grad() + def test_forward_pass(self, prompt: str, image_path: Optional[str] = None): + """Test a single forward pass with detailed output.""" + + print("\n" + "="*80) + print("TESTING FORWARD PASS") + print("="*80) + + input_ids, pixel_values, patch_attention_mask = self.prepare_inputs(prompt, image_path) + + print(f"\nInput shapes:") + print(f" input_ids: {input_ids.shape}") + if pixel_values is not None: + print(f" pixel_values: {pixel_values.shape}") + print(f" patch_attention_mask: {patch_attention_mask.shape}") + + # Forward pass + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): + logits = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + print(f"\nOutput logits shape: {logits.shape}") + print(f"Logits dtype: {logits.dtype}") + print(f"Logits range: [{logits.min().item():.2f}, {logits.max().item():.2f}]") + + # Get next token predictions + last_logits = logits[0, -1, :] + print(f"\nLast position logits stats:") + print(f" Mean: {last_logits.mean().item():.4f}") + print(f" Std: {last_logits.std().item():.4f}") + print(f" Min: {last_logits.min().item():.4f}") + print(f" Max: {last_logits.max().item():.4f}") + + # Top 10 tokens + top_logits, top_indices = torch.topk(last_logits, k=10) + print(f"\nTop 10 predicted tokens:") + for i, (logit, idx) in enumerate(zip(top_logits, top_indices)): + token_text = self.tokenizer.decode([idx.item()]) + print(f" {i+1}. Token {idx.item():5d} (logit: {logit.item():7.2f}): {repr(token_text)}") + + return logits + + +def main(): + config_manager = ConfigManager() + job_config = config_manager.parse_args() + + # Initialize logger + init_logger() + + logger.info("Job config loaded:") + logger.info(f" Model: {job_config.model.name} / {job_config.model.flavor}") + logger.info(f" HF assets path: {job_config.model.hf_assets_path}") + logger.info(f" Checkpoint folder: {job_config.checkpoint.folder}") + logger.info(f" Load step: {job_config.checkpoint.load_step}") + + # Create generator + generator = SimpleGenerator(job_config) + + # Run test or generation + if args.test_forward: + generator.test_forward_pass(args.prompt, args.image) + else: + generator.generate_greedy(args.prompt, args.image, args.max_tokens) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 277d64be..bf963a5b 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -6,238 +6,182 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from typing import Callable, ClassVar +import functools +from collections.abc import Callable +from typing import ClassVar import torch import torch.nn.functional as F from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( _mask_mod_signature, + AuxOutput, BlockMask, create_block_mask, flex_attention, ) -from torchtitan.tools.utils import has_cuda_capability -# FlexAttention mask type. For each mask type, we initialize it at most once per -# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to -# track the initialized mask. -FLEX_ATTN_MASK_T = tuple[str, int | None] +__all__ = [ + "FlexAttentionWrapper", + "ScaledDotProductAttentionWrapper", + "get_causal_mask_mod", + "get_document_mask_mod", + "get_fixed_block_mask_mod", + "create_attention_mask", +] -class FlexAttention(torch.nn.Module): - """FlexAttention module that uses torch.nn.attention.flex_attention. +class FlexAttentionWrapper(torch.nn.Module): + """Wrapper around `flex_attention` to make it torch.compile and CP compatible. - This module is a wrapper around torch.nn.attention.flex_attention. This module - implements certain common attention types, such as causal and block_causal. + This wrapper serves two purposes: + 1) Invoke `torch.compile` with a valid mode "max-autotune-no-cudagraphs" to + achieve good performance. + 2) Being a wrapper allows us to apply _ContextParallel to it. - Args: - attn_mask_type (str): The type of attention mask. Currently, we support - "causal" and "block_causal". "causal" means the lower triangle of the - attention matrix is masked. "block_causal" means the attention matrix - is divided into blocks, where block boundary is defined by EOS token, - and the lower triangle of each block is masked. - fixed_block_size (int | None): The block size to be used to perform attention. - If specified, each sequence will be further divided to blocks, where each - block has the maximum size of ``fixed_block_size``. A query will only attend - to the keys within the same block. + Note: + The forward function must have q, k, v as the first three arguments, and + block_mask as a keyword argument to be compatible with _ContextParallel. """ - # We registered flex_attention related attributes as class variables as we - # need to amortize the cost of compilation. - flex_attn: ClassVar[Callable] = torch.compile( + _compiled_flex_attn: ClassVar[Callable] = torch.compile( flex_attention, mode="max-autotune-no-cudagraphs" ) - compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask) - used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set() - # Attention mask type to the created BlockMask. - # This allows us to keep track the created block masks for each - # new batch. We will use this to update the block mask when a - # new batch is created. This also allows user to create different - # block masks for different layers. - block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {} - - # Instance variables. - attn_mask_type: str - - def __init__( - self, attn_mask_type: str, fixed_block_size: int | None = None - ) -> None: - super().__init__() - if attn_mask_type not in ["causal", "block_causal"]: - raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") - self.attn_mask_type = attn_mask_type - self.fixed_block_size = fixed_block_size - - FlexAttention.used_attn_mask_types.add(self.mask_key) - - @property - def mask_key(self) -> FLEX_ATTN_MASK_T: - return (self.attn_mask_type, self.fixed_block_size) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *, + block_mask: BlockMask, scale: float | None = None, - ) -> torch.Tensor: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) - - @staticmethod - def _get_causal_mask_mod() -> _mask_mod_signature: - def causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - return q_idx >= kv_idx - - return causal_mask - - @staticmethod - def _get_block_causal_mask_mod( - batch: torch.Tensor, eos_id: int - ) -> _mask_mod_signature: - # batch is [b, s, h, d] shape - mask = batch == eos_id - mask[:, -1] = True - acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1) - seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32) - seq_idx[:, 1:] = acc_mask[:, :-1] - - def block_causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx) - - return block_causal_mask - - @staticmethod - def _fixed_block_mask_mod( - mask_mod: _mask_mod_signature, fixed_block_size: int - ) -> _mask_mod_signature: - """ - Given an arbitrary mask_mod, divide the input sequence to blocks - and only allow attention within the same block. - - Args: - mask_mod: The mask mod to apply to the documents - fixed_block_size: The number of tokens in each block. - """ - - # Credit to @drisspg. - def blocked_mask_mod( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - # Get the block index of the query and key - q_block = q_idx // fixed_block_size - kv_block = kv_idx // fixed_block_size - # Only allow attention within the same block - same_block = q_block == kv_block - # Apply the original mask mod - inner_mask = mask_mod( - b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size - ) - - return same_block & inner_mask - - blocked_mask_mod.__name__ = ( - f"blocked_mask_mod_{mask_mod.__name__}_fixed_block_size_{fixed_block_size}" + ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]: + # 1. _compiled_flex_attn has to be a class variable, otherwise there will + # be multiple compiled flex_attention instances, which can be slow. + # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in + # as the first argument, which will cause an error. + # `FlexAttentionWrapper._compiled_flex_attn` is correct. + return FlexAttentionWrapper._compiled_flex_attn( + q, k, v, block_mask=block_mask, scale=scale ) - return blocked_mask_mod - - @staticmethod - @torch.no_grad() - def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: - # batch is [b, s, h, d] shape - for mask_key in FlexAttention.used_attn_mask_types: - attn_mask_type, fixed_block_size = mask_key - match attn_mask_type: - case "causal": - if FlexAttention.block_masks.get(mask_key, None) is not None: - continue - # We don't care about batch dimension -- - # all samples have the same lower triangle mask. - batch_dimension = 1 - mask_mod = FlexAttention._get_causal_mask_mod() - case "block_causal": - if eos_id is None: - raise RuntimeError( - "eos_id must be provided for block_causal mask." - ) - batch_dimension = batch.shape[0] - mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id) - case _: - raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}") - - if fixed_block_size is not None and fixed_block_size > 0: - mask_mod = FlexAttention._fixed_block_mask_mod( - mask_mod, fixed_block_size - ) - - seq_len = batch.shape[1] - block_mask = FlexAttention.compiled_create_block_mask( - mask_mod, batch_dimension, None, seq_len, seq_len - ) - FlexAttention.block_masks[mask_key] = block_mask - - -class ScaledDotProductAttention(torch.nn.Module): - backends: ClassVar[list[SDPBackend]] = [] - - def __init__(self, attn_mask_type: str) -> None: + +class ScaledDotProductAttentionWrapper(torch.nn.Module): + """Wrapper around `F.scaled_dot_product_attention` to make it CP compatible. + + This wrapper is needed because `F.scaled_dot_product_attention` is not + a torch.nn.Module, and thus cannot be applied with _ContextParallel. + We need to wrap it into a torch.nn.Module. + + Note: + The forward function must have q, k, v as the first three arguments to be + compatible with _ContextParallel. + """ + + # TODO: remove sdpa_backends after PyTorch 2.9 is released. + sdpa_backends: ClassVar[list[SDPBackend]] = [] + + def __init__(self) -> None: super().__init__() - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - - ScaledDotProductAttention._init_backend() - - @classmethod - def _init_backend(cls) -> None: - if cls.backends: - return - - # Add CuDNN on B200 w/ highest priority - cls.backends = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, - ] - if has_cuda_capability(10, 0): - cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) + if not self.sdpa_backends: + self.sdpa_backends = [ + SDPBackend.CUDNN_ATTENTION, + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + ] def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *, scale: float | None = None, ) -> torch.Tensor: - assert self.backends, "SDPA Backends should not be empty." - with sdpa_kernel(self.backends, set_priority=True): - return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale) - - -def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None -): - if use_flex_attn: - return FlexAttention(attn_mask_type, fixed_block_size) - else: - if fixed_block_size is not None: - raise ValueError( - "TorchTitan with SDPA currently does not support fixed_block_size." - ) - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - return ScaledDotProductAttention(attn_mask_type) - - -def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: - FlexAttention.init_attention_mask(batch, eos_id) + with sdpa_kernel(self.sdpa_backends, set_priority=True): + return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True) + + +# We cannot do inner function/closure because we won't be able to cache it -- +# if we an inner function, a new closure will be created every time +# `get_causal_mask_mod` is called. +def _causal_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor +) -> torch.Tensor: + """Causal mask that prevents attention to future tokens.""" + return q_idx >= kv_idx + + +def get_causal_mask_mod() -> _mask_mod_signature: + """Returns a causal mask modifier for flex attention. + + Returns: + A mask modifier function that implements causal masking. + """ + return _causal_mask + + +def get_document_mask_mod(batch: torch.Tensor, eos_id: int) -> _mask_mod_signature: + """Creates a document mask that prevents attention across document boundaries. + + Args: + batch: Input batch tensor with shape [b, s, h, d] + eos_id: End-of-sequence token ID that marks document boundaries + + Returns: + A mask modifier function that implements document-level masking. + """ + # batch is [b, s, h, d] shape + eos_mask = batch == eos_id + eos_mask[:, -1] = True + cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1) + sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32) + sequence_indices[:, 1:] = cumulative_mask[:, :-1] + + def document_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] + + return document_mask + + +def get_fixed_block_mask_mod(fixed_block_size: int) -> _mask_mod_signature: + """ + Divide the input sequence into blocks and only allow attention within the same block. + + Args: + fixed_block_size: The number of tokens in each block. + + Returns: + A mask modifier function that implements block-wise attention masking. + """ + + # Credit to @drisspg. + def blocked_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + # Get the block index of the query and key + q_block = q_idx // fixed_block_size + kv_block = kv_idx // fixed_block_size + # Only allow attention within the same block + return q_block == kv_block + + blocked_mask_mod.__name__ = f"blocked_mask_mod_fixed_block_size_{fixed_block_size}" + + return blocked_mask_mod + + +_compiled_create_block_mask = torch.compile(create_block_mask) + + +@functools.lru_cache(4) +def create_attention_mask(*args, **kwargs): + """Create an attention mask using compiled create_block_mask. + + This function is cached to avoid recreating BlockMasks for the same + argumens. + """ + return _compiled_create_block_mask(*args, **kwargs) diff --git a/torchtitan/models/llama3/train_configs/llama3_1b.toml b/torchtitan/models/llama3/train_configs/llama3_1b.toml index 56ea5864..4c27a807 100644 --- a/torchtitan/models/llama3/train_configs/llama3_1b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_1b.toml @@ -44,12 +44,14 @@ context_parallel_degree = 1 [checkpoint] enable = true -last_save_in_hf = true folder = "checkpoint" interval = 100 last_save_model_only = true +initial_load_in_hf = true +last_save_in_hf = true export_dtype = "float32" async_mode = "async" # ["disabled", "async", "async_with_pinned_mem"] +exclude_from_loading = ["dataloader", "optimizer", "train_state"] [compile] enable=false diff --git a/torchtitan/vlr/smolvlm/__init__.py b/torchtitan/vlr/smolvlm/__init__.py index a0714c48..5661c412 100644 --- a/torchtitan/vlr/smolvlm/__init__.py +++ b/torchtitan/vlr/smolvlm/__init__.py @@ -71,6 +71,8 @@ multiple_of=1024, rope_theta=100000, vocab_size=49280, + use_flex_attn = False, + attn_mask_type = "causal", ), } diff --git a/torchtitan/vlr/smolvlm/model/model.py b/torchtitan/vlr/smolvlm/model/model.py index 08ca6776..c681a914 100644 --- a/torchtitan/vlr/smolvlm/model/model.py +++ b/torchtitan/vlr/smolvlm/model/model.py @@ -50,6 +50,8 @@ def pixel_shuffle(self, x, scale_factor=4): return x def forward(self, image_hidden_states): + print("image hidden") + print(image_hidden_states) image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) image_hidden_states = self.modality_projection(image_hidden_states) return image_hidden_states @@ -163,9 +165,13 @@ def forward( # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages hidden_states = self.tok_embeddings(input_ids) if self.tok_embeddings else input_ids + """ if self.encoder is not None and pixel_values is not None: vision_tokens = self.get_image_features(pixel_values, patch_attention_mask) hidden_states = self._fuse_vision_text(hidden_states, vision_tokens, input_ids) + else: + "THERE are not images" + """ for layer in self.layers.values(): hidden_states = layer(hidden_states, self.freqs_cis) @@ -204,6 +210,8 @@ def forward( n_heads=9, n_kv_heads=3, ffn_dim=1536, + use_flex_attn = False, + attn_mask_type = "causal", ), } diff --git a/torchtitan/vlr/smolvlm/model/siglip2.py b/torchtitan/vlr/smolvlm/model/siglip2.py index f66fe290..c1e8d68f 100644 --- a/torchtitan/vlr/smolvlm/model/siglip2.py +++ b/torchtitan/vlr/smolvlm/model/siglip2.py @@ -107,10 +107,10 @@ def __init__(self, args: Siglip2ModelArgs): self.out_proj = nn.Linear(self.dim, self.dim) self.attn = build_attention( - use_flex_attn=False, attn_mask_type=args.attn_mask_type + use_flex_attn=False, attn_mask_type=None ) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor): xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) # Use self.head_dim instead of `n_heads` to infer the actual @@ -160,8 +160,8 @@ def __init__(self, args: Siglip2ModelArgs): self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) self.mlp = FeedForward(args) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.self_attn(self.layer_norm1(x)) + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + x = x + self.self_attn(self.layer_norm1(x), attention_mask=attention_mask) x = x + self.mlp(self.layer_norm2(x)) return x @@ -194,7 +194,7 @@ def forward( h = self.embeddings(pixel_values, patch_attention_mask) for layer in self.layers.values(): - h = layer(h) + h = layer(h, patch_attention_mask) h = self.post_layernorm(h) return h diff --git a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml index 1630b1f0..cac0c8ff 100644 --- a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml +++ b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml @@ -4,7 +4,7 @@ custom_args_module = "torchtitan.vlr.smolvlm.assets.job_config" [job] -dump_folder = "./outputs" +dump_folder = "/data/users" description = "Llama 3 Siglip2 VLM training" print_args = false From 8e1723d0b273fe56f85f7bf52abb80e7e3862a0f Mon Sep 17 00:00:00 2001 From: tomiock Date: Mon, 13 Oct 2025 09:39:26 +0000 Subject: [PATCH 15/18] some progress, inference 8B --- run_generate.sh | 3 +- run_generate_llama3.sh | 3 +- scripts/generate/test_generate.py | 28 +++++++++-- torchtitan/components/checkpoint.py | 2 +- torchtitan/generate.py | 5 +- torchtitan/generate_llama3.py | 50 +------------------ torchtitan/models/llama3/__init__.py | 2 +- torchtitan/models/llama3/model/model.py | 35 ++++++------- .../llama3/train_configs/llama3_8b.toml | 7 +-- torchtitan/protocols/train_spec.py | 4 ++ torchtitan/vlr/__init__.py | 1 + torchtitan/vlr/smolvlm/__init__.py | 11 ++-- torchtitan/vlr/smolvlm/model/args.py | 1 + torchtitan/vlr/smolvlm/model/model.py | 6 +-- torchtitan/vlr/smolvlm/model/siglip2.py | 6 +-- .../train_configs/llama_siglip_256.toml | 2 +- 16 files changed, 66 insertions(+), 100 deletions(-) create mode 100644 torchtitan/vlr/__init__.py diff --git a/run_generate.sh b/run_generate.sh index 6dee6461..481ac99b 100755 --- a/run_generate.sh +++ b/run_generate.sh @@ -22,5 +22,4 @@ TORCH_NCCL_DUMP_ON_TIMEOUT=1 \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ -m ${INFERENCE_FILE} --job.config_file ${CONFIG_FILE} "$@" \ ---checkpoint.exclude-from-loading dataloader,optimizer,lr_scheduler - +--checkpoint.exclude-from-loading dataloader,optimizer,lr_scheduler \ diff --git a/run_generate_llama3.sh b/run_generate_llama3.sh index c36829d2..69c4ca23 100755 --- a/run_generate_llama3.sh +++ b/run_generate_llama3.sh @@ -12,11 +12,10 @@ set -ex # LOG_RANK=0,1 NGPU=4 ./run_train.sh NGPU=${NGPU:-"1"} export LOG_RANK=${LOG_RANK:-0} -CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_1b.toml"} +CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.toml"} INFERENCE_FILE=${INFERENCE_FILE:-"torchtitan.generate_llama3"} -CUDA_VISIBLE_DEVICES=2 \ NCCL_P2P_DISABLE=1 \ TORCH_NCCL_DUMP_ON_TIMEOUT=1 \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 21322ba2..ea30ce0b 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -32,6 +32,9 @@ from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.utils import device_module, device_type +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.tokenizer import HuggingFaceTokenizer + # support running w/o installing as package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) @@ -143,11 +146,26 @@ def test_generate( state_dict = model.state_dict() - # Checkpoint Loading - begin = time.monotonic() - logger.info(f"Loading chkpt at: {checkpoint_path}") - dcp.load(state_dict, checkpoint_id=checkpoint_path) - logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.") + # Setup checkpoint manager for loading + checkpointer = CheckpointManager( + dataloader=None, # No dataloader needed for generation + model_parts=[model], + optimizers=None, # No optimizer needed for generation + lr_schedulers=None, # No lr_scheduler needed for generation + states={}, + checkpoint_config=config.checkpoint, + sd_adapter=( + train_spec.state_dict_adapter( + model_args, config.model.hf_assets_path + ) + ), + base_folder=config.job.dump_folder, + ft_manager=None, # No fault tolerance for generation + ) + + # Load checkpoint + checkpointer.load(step=config.checkpoint.load_step) + logger.info(f"Loaded checkpoint from step {config.checkpoint.load_step}") device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index e2e643db..3527bc77 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -339,7 +339,7 @@ def dcp_save( checkpoint_id (str): The checkpoint id to save. async_mode (AsyncMode): Whether the checkpoint is async. enable_garbage_collection (bool): Whether to enable garbage collection after save. - to_hf (bool): Whether to save in HF model definition and safetensors format. + to_hf (bool): Whether to save in HF mel definition and safetensors format. Returns: Future: The future object if the checkpoint is async, otherwise None. diff --git a/torchtitan/generate.py b/torchtitan/generate.py index f7462743..c45a31ff 100644 --- a/torchtitan/generate.py +++ b/torchtitan/generate.py @@ -92,6 +92,7 @@ def __init__(self, job_config: JobConfig): deterministic=False, ) + print(job_config.model.name) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # Build tokenizer @@ -106,10 +107,6 @@ def __init__(self, job_config: JobConfig): model_args.update_from_config(job_config) self.model_args = model_args - logger.info( - f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" - ) - with ( torch.device("meta"), utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), diff --git a/torchtitan/generate_llama3.py b/torchtitan/generate_llama3.py index 083b34ce..68378aac 100644 --- a/torchtitan/generate_llama3.py +++ b/torchtitan/generate_llama3.py @@ -44,36 +44,6 @@ def __init__(self, job_config: JobConfig): self.device = torch.device(f"{device_type}:{int(os.environ.get('LOCAL_RANK', 0))}") device_module.set_device(self.device) - # Initialize distributed - dist_utils.init_distributed( - job_config.comm, - enable_cpu_backend=False, - base_folder=job_config.job.dump_folder, - ) - - world_size = int(os.environ.get("WORLD_SIZE", 1)) - parallelism_config = job_config.parallelism - self.parallel_dims = parallel_dims = ParallelDims( - dp_shard=parallelism_config.data_parallel_shard_degree, - dp_replicate=parallelism_config.data_parallel_replicate_degree, - cp=parallelism_config.context_parallel_degree, - tp=parallelism_config.tensor_parallel_degree, - pp=parallelism_config.pipeline_parallel_degree, - ep=parallelism_config.expert_parallel_degree, - etp=parallelism_config.expert_tensor_parallel_degree, - world_size=world_size, - ) - - world_mesh = parallel_dims.world_mesh - - # Set random seed - dist_utils.set_determinism( - world_mesh, - self.device, - job_config.training.seed, - deterministic=False, - ) - self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # Build tokenizer @@ -88,31 +58,15 @@ def __init__(self, job_config: JobConfig): model_args.update_from_config(job_config) self.model_args = model_args - logger.info( - f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" - ) - with ( torch.device("meta"), utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), ): model = self.train_spec.model_cls(model_args) - # Build model converters (e.g., for float8) - model_converters = build_model_converters(job_config, parallel_dims) - model_converters.convert(model) - - # Apply parallelism - if parallel_dims.pp_enabled: - raise NotImplementedError("Pipeline parallelism not supported for generation") - else: - model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) - - # Move to device and initialize + with torch.no_grad(): init_device = self.device.type model.to_empty(device=init_device) - with torch.no_grad(): - model.init_weights() model.eval() self.model_parts = [model] @@ -141,7 +95,7 @@ def __init__(self, job_config: JobConfig): logger.info(f"Loaded checkpoint from step {job_config.checkpoint.load_step}") self.processor = AutoProcessor.from_pretrained(job_config.model.hf_assets_path) - + # Load chat template template_path = "torchtitan/vlr/smolvlm/datasets/template.jinja" if os.path.exists(template_path): diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 3b0be28e..2966dbf1 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -51,7 +51,7 @@ "8B": TransformerModelArgs( dim=4096, - ffn_dim=8192, + ffn_dim=14336, n_layers=32, n_heads=32, n_kv_heads=8, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index e4a691f6..b27e6aa8 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -278,29 +278,30 @@ class FeedForward(nn.Module): def __init__( self, dim: int, - hidden_dim: int, - multiple_of: int | None=None, - ffn_dim_multiplier: float | None=None + hidden_dim: int | None = None, + multiple_of: int | None = None, + ffn_dim_multiplier: float | None = None, + ffn_dim: int | None = None, ): super().__init__() - """ - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - """ - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + if not ffn_dim: + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + ffn_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, ffn_dim, bias=False) + self.w2 = nn.Linear(ffn_dim, dim, bias=False) + self.w3 = nn.Linear(dim, ffn_dim, bias=False) def forward(self, x): - return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return self.w2(F.silu(self.w1(x)) * self.w3(x)) def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.gate_proj.weight, mean=0.0, std=0.02) - for linear in (self.up_proj, self.down_proj): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) @@ -331,7 +332,7 @@ def __init__(self, layer_id: int, model_args: TransformerModelArgs): self.attention = Attention(model_args) self.feed_forward = FeedForward( dim=model_args.dim, - hidden_dim=model_args.ffn_dim, + ffn_dim=model_args.ffn_dim, ) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index ef86d783..3f610e84 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -43,11 +43,12 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" +initial_load_in_hf = true +load_only = true interval = 500 -last_save_model_only = true -export_dtype = "float32" +export_dtype = "bfloat16" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [compile] diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 22bfa7df..78c9d899 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -76,6 +76,7 @@ def get_train_spec(name: str) -> TrainSpec: from torchtitan.experiments import _supported_experiments from torchtitan.models import _supported_models + from torchtitan.vlr import _supported_vlr_models if name in _supported_models: module = import_module(f"torchtitan.models.{name}") @@ -83,5 +84,8 @@ def get_train_spec(name: str) -> TrainSpec: elif name in _supported_experiments: module = import_module(f"torchtitan.experiments.{name}") return module.get_train_spec() + elif name in _supported_vlr_models: + module = import_module(f"torchtitan.vlr.{name}") + return module.get_train_spec() raise ValueError(f"TrainSpec {name} is not registered.") diff --git a/torchtitan/vlr/__init__.py b/torchtitan/vlr/__init__.py new file mode 100644 index 00000000..429f313f --- /dev/null +++ b/torchtitan/vlr/__init__.py @@ -0,0 +1 @@ +_supported_vlr_models = frozenset(["smolvlm"]) diff --git a/torchtitan/vlr/smolvlm/__init__.py b/torchtitan/vlr/smolvlm/__init__.py index 5661c412..490ccaa8 100644 --- a/torchtitan/vlr/smolvlm/__init__.py +++ b/torchtitan/vlr/smolvlm/__init__.py @@ -9,7 +9,7 @@ from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.components.validate import build_validator -from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from torchtitan.protocols.train_spec import TrainSpec from .datasets.mm_datasets import build_mm_dataloader from .infra.parallelize import parallelize_vlm @@ -64,10 +64,9 @@ "256M": Llama3Siglip2ModelArgs( encoder=siglip2_configs["256M"], dim=576, + ffn_dim=1536, n_layers=30, n_heads=9, - n_kv_heads=3, - ffn_dim_multiplier=1.3, multiple_of=1024, rope_theta=100000, vocab_size=49280, @@ -77,9 +76,8 @@ } -register_train_spec( - TrainSpec( - name="llama3-siglip2", +def get_train_spec() -> TrainSpec: + return TrainSpec( model_cls=Llama3Siglip2Transformer, model_args=llama3_siglip2_configs, parallelize_fn=parallelize_vlm, @@ -92,4 +90,3 @@ build_validator_fn=build_validator, state_dict_adapter=SmolVLMStateDictAdapter, ) -) diff --git a/torchtitan/vlr/smolvlm/model/args.py b/torchtitan/vlr/smolvlm/model/args.py index 8357ad29..a4f6e48c 100644 --- a/torchtitan/vlr/smolvlm/model/args.py +++ b/torchtitan/vlr/smolvlm/model/args.py @@ -35,6 +35,7 @@ class Llama3Siglip2ModelArgs(Llama3Args): tokenizer_name: str = 'HuggingFaceTB/SmolLM2-360M-Instruct' img_token_id: int = 49190 vocab_size: int = 49280 + ffn_dim: int = 1536 def update_from_config(self, job_config: JobConfig, **kwargs) -> None: super().update_from_config(job_config, **kwargs) diff --git a/torchtitan/vlr/smolvlm/model/model.py b/torchtitan/vlr/smolvlm/model/model.py index c681a914..507fd909 100644 --- a/torchtitan/vlr/smolvlm/model/model.py +++ b/torchtitan/vlr/smolvlm/model/model.py @@ -8,7 +8,7 @@ import torch from torch import nn -from torchtitan.models.attention import init_attention_mask +from torchtitan.models.attention import ScaledDotProductAttentionWrapper from torchtitan.models.llama3 import Transformer as Llama3 from .args import Llama3Siglip2ModelArgs, Siglip2ModelArgs @@ -157,10 +157,6 @@ def forward( patch_attention_mask: torch.BoolTensor | None = None, #grid_thw: torch.Tensor | None = None, ): - if self.model_args.use_flex_attn: - init_attention_mask( - input_batch if input_batch is not None else input_ids, eos_id=self.eos_id - ) # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages hidden_states = self.tok_embeddings(input_ids) if self.tok_embeddings else input_ids diff --git a/torchtitan/vlr/smolvlm/model/siglip2.py b/torchtitan/vlr/smolvlm/model/siglip2.py index c1e8d68f..0f80a022 100644 --- a/torchtitan/vlr/smolvlm/model/siglip2.py +++ b/torchtitan/vlr/smolvlm/model/siglip2.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from torch import nn -from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.attention import ScaledDotProductAttentionWrapper from .args import Siglip2ModelArgs @@ -106,9 +106,7 @@ def __init__(self, args: Siglip2ModelArgs): self.v_proj = nn.Linear(self.dim, self.dim) self.out_proj = nn.Linear(self.dim, self.dim) - self.attn = build_attention( - use_flex_attn=False, attn_mask_type=None - ) + self.attn = ScaledDotProductAttentionWrapper() def forward(self, x: torch.Tensor, attention_mask: torch.Tensor): xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) diff --git a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml index cac0c8ff..fe85e917 100644 --- a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml +++ b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml @@ -23,7 +23,7 @@ save_tb_folder = "tb" enable_wandb = true [model] -name = "llama3-siglip2" +name = "smolvlm" flavor = "256M" # test folder with tokenizer.json, for debug purpose only # hf_assets_path = "torchtitan/experiments/vlm/assets/tokenizer" From d93e0449e5349a9bf92ef2d1296cc9e50d2def68 Mon Sep 17 00:00:00 2001 From: tomiock Date: Mon, 13 Oct 2025 11:31:34 +0200 Subject: [PATCH 16/18] new generate scripts --- torchtitan/generate.py | 427 +++++++++++------------------- torchtitan/generate_llama3.py | 474 ++++++++++++++-------------------- 2 files changed, 344 insertions(+), 557 deletions(-) diff --git a/torchtitan/generate.py b/torchtitan/generate.py index c45a31ff..45aed2e9 100644 --- a/torchtitan/generate.py +++ b/torchtitan/generate.py @@ -17,23 +17,28 @@ import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager -from torchtitan.components.tokenizer import HuggingFaceTokenizer from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger -def multinomial_sample_one(probs: torch.Tensor, rng = None): +# --- Generation utilities from scripts/generate/_generation.py --- + +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: q = torch.empty_like(probs).exponential_(1, generator=rng) - return torch.argmax(probs/q, dim=-1, keepdim=True).to(dtype=torch.long) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + def logits_to_probs( - logits, - temperature, - top_k -): + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: logits = logits / max(temperature, 1e-5) + if top_k is not None: v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) pivot = v.select(dim=-1, index=-1).unsqueeze(-1) @@ -42,36 +47,94 @@ def logits_to_probs( probs = torch.nn.functional.softmax(logits, dim=-1) return probs + +def generate_next_token( + model, + x: torch.Tensor, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, + **model_kwargs, +) -> torch.Tensor: + input_dict = { + "input_ids": x, + **model_kwargs, + } + logits = model(**input_dict) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) + return next_token + + +@torch.no_grad() +def _generate_sequence( + model, + input_ids: torch.Tensor, + *, + max_new_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + seed: Optional[int] = None, + **model_kwargs, +) -> torch.Tensor: + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + + generated_tokens = input_ids.clone() + + for _ in range(max_new_tokens): + next_token = generate_next_token( + model, + x=generated_tokens, + temperature=temperature, + top_k=top_k, + rng=rng, + **model_kwargs, + ) + + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) + + return generated_tokens + +# --- End of generation utilities --- + + class Generator: """Generator class for SmolVLM model inference.""" - + def __init__(self, job_config: JobConfig): torch._C._log_api_usage_once("torchtitan.generate") - + self.job_config = job_config - + logger.info(f"Starting generation: {job_config.job.description}") - + if job_config.experimental.custom_import: importlib.import_module(job_config.experimental.custom_import) - + if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") - + device_module, device_type = utils.device_module, utils.device_type self.device = torch.device(f"{device_type}:{int(os.environ.get('LOCAL_RANK', 0))}") device_module.set_device(self.device) - - # Initialize distributed - dist_utils.init_distributed( - job_config.comm, - enable_cpu_backend=False, - base_folder=job_config.job.dump_folder, - ) - + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if world_size > 1: + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=False, + base_folder=job_config.job.dump_folder, + ) + parallelism_config = job_config.parallelism - self.parallel_dims = parallel_dims = ParallelDims( + self.parallel_dims = ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, dp_replicate=parallelism_config.data_parallel_replicate_degree, cp=parallelism_config.context_parallel_degree, @@ -81,63 +144,49 @@ def __init__(self, job_config: JobConfig): etp=parallelism_config.expert_tensor_parallel_degree, world_size=world_size, ) - - world_mesh = parallel_dims.world_mesh - - # Set random seed + dist_utils.set_determinism( - world_mesh, + self.parallel_dims.world_mesh if world_size > 1 else None, self.device, job_config.training.seed, deterministic=False, ) - - print(job_config.model.name) + self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - - # Build tokenizer - self.tokenizer = ( - self.train_spec.build_tokenizer_fn(job_config) - if self.train_spec.build_tokenizer_fn is not None - else None - ) - - # Build model + + self.tokenizer = self.train_spec.build_tokenizer_fn(job_config) + model_args = self.train_spec.model_args[job_config.model.flavor] model_args.update_from_config(job_config) self.model_args = model_args - + with ( torch.device("meta"), utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), ): model = self.train_spec.model_cls(model_args) - - # Build model converters (e.g., for float8) - model_converters = build_model_converters(job_config, parallel_dims) + + model_converters = build_model_converters(job_config, self.parallel_dims) model_converters.convert(model) - - # Apply parallelism - if parallel_dims.pp_enabled: + + if self.parallel_dims.pp_enabled: raise NotImplementedError("Pipeline parallelism not supported for generation") else: - model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) - - # Move to device and initialize + model = self.train_spec.parallelize_fn(model, self.parallel_dims, job_config) + init_device = self.device.type model.to_empty(device=init_device) with torch.no_grad(): model.init_weights() model.eval() - + self.model_parts = [model] - - # Setup checkpoint manager for loading + self.checkpointer = CheckpointManager( - dataloader=None, # No dataloader needed for generation + dataloader=None, model_parts=self.model_parts, - optimizers=None, # No optimizer needed for generation - lr_schedulers=None, # No lr_scheduler needed for generation + optimizers=None, + lr_schedulers=None, states={}, checkpoint_config=job_config.checkpoint, sd_adapter=( @@ -148,19 +197,15 @@ def __init__(self, job_config: JobConfig): else None ), base_folder=job_config.job.dump_folder, - ft_manager=None, # No fault tolerance for generation + ft_manager=None, ) - - # Load checkpoint + self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Loaded checkpoint from step {job_config.checkpoint.load_step}") - - # Setup HF processor for image processing - #processor_path = getattr(model_args, 'tokenizer_name',) - self.processor = AutoProcessor.from_pretrained('HuggingFaceTB/SmolVLM2-256M-Video-Instruct') + + self.processor = AutoProcessor.from_pretrained(job_config.model.hf_assets_path) self.image_processor = self.processor.image_processor - - # Load chat template + template_path = "torchtitan/vlr/smolvlm/datasets/template.jinja" if os.path.exists(template_path): with open(template_path, 'r') as f: @@ -168,15 +213,13 @@ def __init__(self, job_config: JobConfig): else: logger.warning(f"Chat template not found at {template_path}, using default") self.chat_template = None - - # Setup generation parameters + self.max_new_tokens = getattr(job_config, 'max_new_tokens', 256) self.temperature = getattr(job_config, 'temperature', 0.7) - self.top_p = getattr(job_config, 'top_p', 0.9) self.top_k = getattr(job_config, 'top_k', 50) - + logger.info("Generator initialized successfully") - + @torch.no_grad() def generate( self, @@ -184,56 +227,26 @@ def generate( images: Optional[List[Image.Image]] = None, max_new_tokens: Optional[int] = None, temperature: Optional[float] = None, - top_p: Optional[float] = None, top_k: Optional[int] = None, - do_sample: bool = True, - rng = None, + seed: Optional[int] = None, ) -> str: - """Generate text from messages and optional images. - - Args: - messages: List of message dictionaries with 'role' and 'content' - images: Optional list of PIL images - max_new_tokens: Maximum number of tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - top_k: Top-k sampling parameter - do_sample: Whether to use sampling or greedy decoding - - Returns: - Generated text string - """ max_new_tokens = max_new_tokens or self.max_new_tokens temperature = temperature or self.temperature - top_p = top_p or self.top_p top_k = top_k or self.top_k - + model = self.model_parts[0] model.eval() - - # Process images if provided + pixel_values = None patch_attention_mask = None - + if images: - # Process images using HF processor - vision_inputs = self.image_processor(images) - pixel_values = torch.tensor( - np.array(vision_inputs['pixel_values']) - ).to(self.device, dtype=torch.bfloat16) - + vision_inputs = self.image_processor(images, return_tensors="pt") + pixel_values = vision_inputs['pixel_values'].to(self.device, dtype=torch.bfloat16) + if 'pixel_attention_mask' in vision_inputs: - patch_attention_mask = torch.tensor( - vision_inputs['pixel_attention_mask'] - ).to(self.device) - - # Handle batch dimension - if pixel_values.dim() == 4: - pixel_values = pixel_values.unsqueeze(0) - if patch_attention_mask is not None and patch_attention_mask.dim() == 3: - patch_attention_mask = patch_attention_mask.unsqueeze(0) - - # Tokenize input + patch_attention_mask = vision_inputs['pixel_attention_mask'].to(self.device) + if self.chat_template: input_ids = self.tokenizer.apply_chat_template( messages, @@ -243,139 +256,73 @@ def generate( return_tensors="pt", ) else: - # Fallback to default chat template input_ids = self.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", ) - + if isinstance(input_ids, dict): input_ids = input_ids["input_ids"] - + input_ids = input_ids.to(self.device) - - # Setup generation context (compile if enabled) - generate_fn = self._generate_tokens - if self.job_config.compile.enable and "model" in self.job_config.compile.components: - generate_fn = torch.compile(generate_fn, mode="reduce-overhead") - - # Generate tokens + + model_kwargs = { + "pixel_values": pixel_values, + "patch_attention_mask": patch_attention_mask, + "eos_id": self.tokenizer.eos_token_id, + } + with torch.amp.autocast('cuda', dtype=torch.bfloat16): - output_ids = generate_fn( + output_ids = _generate_sequence( model=model, input_ids=input_ids, - pixel_values=None, - patch_attention_mask=None, max_new_tokens=max_new_tokens, temperature=temperature, - top_p=top_p, top_k=top_k, - do_sample=do_sample, - rng=rng, + seed=seed, + **model_kwargs, ) - - # Decode output + generated_ids = output_ids[0, input_ids.shape[1]:] - generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) - + generated_text = self.tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True) + return generated_text - - def _generate_tokens( - self, - model: torch.nn.Module, - input_ids: torch.Tensor, - pixel_values: Optional[torch.Tensor], - patch_attention_mask: Optional[torch.Tensor], - max_new_tokens: int, - temperature: float, - top_p: float, - top_k: int, - do_sample: bool, - rng = None, - ) -> torch.Tensor: - """Core generation loop.""" - - batch_size = input_ids.shape[0] - generated_ids = input_ids.clone() - - # Cache for key-value pairs (if using KV cache in the future) - past_key_values = None - - for _ in range(max_new_tokens): - # Forward pass - with torch.no_grad(): - # Prepare input dict - input_dict = { - "input_ids": generated_ids, - "eos_id": self.tokenizer.eos_token, - } - - if pixel_values is not None: - input_dict["pixel_values"] = pixel_values - - if patch_attention_mask is not None: - input_dict["patch_attention_mask"] = patch_attention_mask - - # Get model output - logits = model(**input_dict) - - # Get next token logits - probs = logits_to_probs(logits[:, -1, :], temperature, top_k) - next_token = multinomial_sample_one(probs, rng=rng) - - # Append to generated sequence - generated_ids = torch.cat([generated_ids, next_token], dim=1) - - # Check for EOS token - if (next_token == self.tokenizer.eos_token): - break - - return generated_ids - + def interactive_generate(self): """Interactive generation mode for testing.""" logger.info("Starting interactive generation mode. Type 'quit' to exit.") - + while True: try: user_input = input("\nEnter your prompt (or 'quit' to exit): ").strip() - + if user_input.lower() == 'quit': break - - # Check if user wants to include an image + image_path = input("Enter image path (or press Enter to skip): ").strip() - + images = None if image_path and os.path.exists(image_path): image = Image.open(image_path).convert('RGB') - # Resize to expected size - image = image.resize((512, 512)) images = [image] logger.info(f"Loaded image from {image_path}") elif image_path: logger.warning(f"Image path {image_path} not found, proceeding without image") - - # Create message format - messages = [ - { - "user": user_input, - "assistant": "" # Will be filled by generation - } - ] - + + messages = [{"role": "user", "content": user_input}] + logger.info("Generating response...") start_time = time.perf_counter() - + response = self.generate(messages, images=images) - + generation_time = time.perf_counter() - start_time logger.info(f"Generation completed in {generation_time:.2f}s") - + print(f"\nGenerated response:\n{response}") - + except KeyboardInterrupt: logger.info("\nInterrupted by user") break @@ -383,53 +330,7 @@ def interactive_generate(self): logger.error(f"Error during generation: {e}") import traceback traceback.print_exc() - - def batch_generate(self, input_file: str, output_file: str): - """Generate responses for a batch of inputs from a file. - - Args: - input_file: Path to JSON file with inputs - output_file: Path to save outputs - """ - import json - - logger.info(f"Loading inputs from {input_file}") - - with open(input_file, 'r') as f: - inputs = json.load(f) - - rng = None - if seed is not None: - rng = torch.Generator(input_ids.device).manual_seed(42) - - results = [] - for i, item in enumerate(inputs): - logger.info(f"Processing item {i+1}/{len(inputs)}") - - messages = item.get('messages', []) - image_paths = item.get('images', []) - - # Load images if provided - images = [] - for path in image_paths: - if os.path.exists(path): - image = Image.open(path).convert('RGB').resize((512, 512)) - images.append(image) - - # Generate response - response = self.generate(messages, images=images if images else None, rng=rng) - - results.append({ - 'input': item, - 'output': response - }) - - # Save results - with open(output_file, 'w') as f: - json.dump(results, f, indent=2) - - logger.info(f"Results saved to {output_file}") - + def close(self): """Cleanup resources.""" if hasattr(self, 'checkpointer'): @@ -441,36 +342,15 @@ def close(self): def main(): """Main entry point for generation.""" init_logger() - - # Parse configuration + config_manager = ConfigManager() config = config_manager.parse_args() - + generator = None try: - # Initialize generator generator = Generator(config) - - # Check for generation mode from config or command line - generation_mode = getattr(config, 'generation_mode', 'interactive') - - if generation_mode == 'interactive': - generator.interactive_generate() - elif generation_mode == 'batch': - input_file = getattr(config, 'input_file', 'inputs.json') - output_file = getattr(config, 'output_file', 'outputs.json') - generator.batch_generate(input_file, output_file) - else: - # Single generation example - messages = [ - { - "user": "What is the capital of France?", - "assistant": "" - } - ] - response = generator.generate(messages) - logger.info(f"Generated: {response}") - + generator.interactive_generate() + except Exception as e: logger.error(f"Error during generation: {e}") if generator: @@ -479,7 +359,8 @@ def main(): else: if generator: generator.close() - torch.distributed.destroy_process_group() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() logger.info("Process group destroyed") diff --git a/torchtitan/generate_llama3.py b/torchtitan/generate_llama3.py index 68378aac..1b0e47ed 100644 --- a/torchtitan/generate_llama3.py +++ b/torchtitan/generate_llama3.py @@ -8,75 +8,178 @@ import os import time from typing import Optional, List, Dict, Any -import numpy as np import torch from torch.distributed.elastic.multiprocessing.errors import record -from transformers import AutoProcessor import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager -from torchtitan.components.tokenizer import HuggingFaceTokenizer from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger +# --- Generation utilities from scripts/generate/_generation.py --- + +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + + +def logits_to_probs( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def generate_next_token( + model, + x: torch.Tensor, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + # The model forward pass in torchtitan expects a `tokens` argument. + logits = model(tokens=x) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) + return next_token + + +@torch.no_grad() +def _generate_sequence( + model, + input_ids: torch.Tensor, + *, + max_new_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + seed: Optional[int] = None, +) -> torch.Tensor: + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + + generated_tokens = input_ids.clone() + + for _ in range(max_new_tokens): + next_token = generate_next_token( + model, + x=generated_tokens, + temperature=temperature, + top_k=top_k, + rng=rng, + ) + + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) + + return generated_tokens + +# --- End of generation utilities --- + class Generator: - """Generator class for SmolVLM model inference.""" - + """Generator class for Llama3 model inference.""" + def __init__(self, job_config: JobConfig): torch._C._log_api_usage_once("torchtitan.generate") - + self.job_config = job_config - + logger.info(f"Starting generation: {job_config.job.description}") - + if job_config.experimental.custom_import: importlib.import_module(job_config.experimental.custom_import) - + if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") - + device_module, device_type = utils.device_module, utils.device_type self.device = torch.device(f"{device_type}:{int(os.environ.get('LOCAL_RANK', 0))}") device_module.set_device(self.device) - - self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - - # Build tokenizer - self.tokenizer = ( - self.train_spec.build_tokenizer_fn(job_config) - if self.train_spec.build_tokenizer_fn is not None - else None + + # For generation, we usually use a single process or TP. + # We will not initialize the full distributed setup unless necessary. + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if world_size > 1: + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=False, + base_folder=job_config.job.dump_folder, + ) + + parallelism_config = job_config.parallelism + self.parallel_dims = ParallelDims( + dp_shard=parallelism_config.data_parallel_shard_degree, + dp_replicate=parallelism_config.data_parallel_replicate_degree, + cp=parallelism_config.context_parallel_degree, + tp=parallelism_config.tensor_parallel_degree, + pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, + world_size=world_size, + ) + + dist_utils.set_determinism( + self.parallel_dims.world_mesh if world_size > 1 else None, + self.device, + job_config.training.seed, + deterministic=False, ) - - # Build model + + self.train_spec = train_spec_module.get_train_spec(job_config.model.name) + + self.tokenizer = self.train_spec.build_tokenizer_fn(job_config) + model_args = self.train_spec.model_args[job_config.model.flavor] model_args.update_from_config(job_config) self.model_args = model_args - + with ( torch.device("meta"), utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), ): model = self.train_spec.model_cls(model_args) - - with torch.no_grad(): + + model_converters = build_model_converters(job_config, self.parallel_dims) + model_converters.convert(model) + + if self.parallel_dims.pp_enabled: + raise NotImplementedError("Pipeline parallelism not supported for generation") + else: + model = self.train_spec.parallelize_fn(model, self.parallel_dims, job_config) + init_device = self.device.type model.to_empty(device=init_device) + with torch.no_grad(): + model.init_weights() model.eval() - + self.model_parts = [model] - - # Setup checkpoint manager for loading + self.checkpointer = CheckpointManager( - dataloader=None, # No dataloader needed for generation + dataloader=None, model_parts=self.model_parts, - optimizers=None, # No optimizer needed for generation - lr_schedulers=None, # No lr_scheduler needed for generation + optimizers=None, + lr_schedulers=None, states={}, checkpoint_config=job_config.checkpoint, sd_adapter=( @@ -87,254 +190,57 @@ def __init__(self, job_config: JobConfig): else None ), base_folder=job_config.job.dump_folder, - ft_manager=None, # No fault tolerance for generation + ft_manager=None, ) - - # Load checkpoint + self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Loaded checkpoint from step {job_config.checkpoint.load_step}") - - self.processor = AutoProcessor.from_pretrained(job_config.model.hf_assets_path) - - # Load chat template - template_path = "torchtitan/vlr/smolvlm/datasets/template.jinja" - if os.path.exists(template_path): - with open(template_path, 'r') as f: - self.chat_template = f.read() - else: - logger.warning(f"Chat template not found at {template_path}, using default") - self.chat_template = None - - # Setup generation parameters + self.max_new_tokens = getattr(job_config, 'max_new_tokens', 256) self.temperature = getattr(job_config, 'temperature', 0.7) - self.top_p = getattr(job_config, 'top_p', 0.9) self.top_k = getattr(job_config, 'top_k', 50) - + logger.info("Generator initialized successfully") - + @torch.no_grad() def generate( self, - messages: List[Dict[str, Any]], + prompts: List[str], max_new_tokens: Optional[int] = None, temperature: Optional[float] = None, - top_p: Optional[float] = None, top_k: Optional[int] = None, - do_sample: bool = True, - ) -> str: - """Generate text from messages. - - Args: - messages: List of message dictionaries with 'role' and 'content' - max_new_tokens: Maximum number of tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - top_k: Top-k sampling parameter - do_sample: Whether to use sampling or greedy decoding - - Returns: - Generated text string - """ + seed: Optional[int] = None, + ) -> List[str]: max_new_tokens = max_new_tokens or self.max_new_tokens temperature = temperature or self.temperature - top_p = top_p or self.top_p top_k = top_k or self.top_k - + model = self.model_parts[0] model.eval() - - # Tokenize input - if self.chat_template: - input_ids = self.tokenizer.apply_chat_template( - messages, - tokenize=True, - chat_template=self.chat_template, - add_generation_prompt=True, - return_tensors="pt", - ) - else: - # Fallback to default chat template - input_ids = self.tokenizer.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=True, - return_tensors="pt", - ) - - if isinstance(input_ids, dict): - input_ids = input_ids["input_ids"] - - input_ids = input_ids.to(self.device) - - # Setup generation context (compile if enabled) - generate_fn = self._generate_tokens - if self.job_config.compile.enable and "model" in self.job_config.compile.components: - logger.info('Compiling model...') - generate_fn = torch.compile(generate_fn, mode="reduce-overhead") - - # Generate tokens - with torch.amp.autocast('cuda', dtype=torch.bfloat16): - output_ids = generate_fn( - model=model, - input_ids=input_ids, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - do_sample=do_sample, - ) - - # Decode output - generated_ids = output_ids[0, input_ids.shape[1]:] - generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) - - return generated_text - - def _generate_tokens( - self, - model: torch.nn.Module, - input_ids: torch.Tensor, - max_new_tokens: int, - temperature: float, - top_p: float, - top_k: int, - do_sample: bool, - ) -> torch.Tensor: - """Core generation loop.""" - - batch_size = input_ids.shape[0] - generated_ids = input_ids.clone() - - # Cache for key-value pairs (if using KV cache in the future) - past_key_values = None - - for _ in range(max_new_tokens): - # Forward pass - with torch.no_grad(): - # Prepare input dict - input_dict = { - "tokens": generated_ids, - } - - # Get model output - logits = model(**input_dict) - - # Get next token logits - next_token_logits = logits[:, -1, :] - - # Apply temperature - if temperature > 0: - next_token_logits = next_token_logits / temperature - - # Sample or greedy decode - if False: - # Apply top-k filtering - if top_k > 0: - indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] - next_token_logits[indices_to_remove] = -float('inf') - - # Apply top-p filtering - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - next_token_logits[indices_to_remove] = -float('inf') - - # Sample - probs = torch.softmax(next_token_logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - else: - # Greedy decoding - next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) - - # Append to generated sequence - generated_ids = torch.cat([generated_ids, next_token], dim=1) - - # Check for EOS token - if (next_token == self.tokenizer.eos_token): - break - - return generated_ids - - def interactive_generate(self): - """Interactive generation mode for testing.""" - logger.info("Starting interactive generation mode. Type 'quit' to exit.") - - while True: - try: - user_input = input("\nEnter your prompt (or 'quit' to exit): ").strip() - - if user_input.lower() == 'quit': - break - - - # Create message format - messages = [ - { - "user": user_input, - "assistant": "" # Will be filled by generation - } - ] - - logger.info("Generating response...") - start_time = time.perf_counter() - - response = self.generate(messages) - - generation_time = time.perf_counter() - start_time - logger.info(f"Generation completed in {generation_time:.2f}s") - - print(f"\nGenerated response:\n{response}") - - except KeyboardInterrupt: - logger.info("\nInterrupted by user") - break - except Exception as e: - logger.error(f"Error during generation: {e}") - import traceback - traceback.print_exc() - - def batch_generate(self, input_file: str, output_file: str): - """Generate responses for a batch of inputs from a file. - - Args: - input_file: Path to JSON file with inputs - output_file: Path to save outputs - """ - import json - - logger.info(f"Loading inputs from {input_file}") - - with open(input_file, 'r') as f: - inputs = json.load(f) - - results = [] - for i, item in enumerate(inputs): - logger.info(f"Processing item {i+1}/{len(inputs)}") - - messages = item.get('messages', []) - - # Generate response - response = self.generate(messages) - - results.append({ - 'input': item, - 'output': response - }) - - # Save results - with open(output_file, 'w') as f: - json.dump(results, f, indent=2) - - logger.info(f"Results saved to {output_file}") - + + # For simplicity, this example handles one prompt at a time. + # Batching can be added for efficiency. + generated_texts = [] + for prompt in prompts: + input_ids = self.tokenizer.encode(prompt, add_bos=True, add_eos=False) + input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device) + + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + output_ids = _generate_sequence( + model=model, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + seed=seed, + ) + + generated_ids = output_ids[0, input_ids.shape[0]:] + generated_text = self.tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True) + generated_texts.append(generated_text) + + return generated_texts + def close(self): """Cleanup resources.""" if hasattr(self, 'checkpointer'): @@ -346,36 +252,35 @@ def close(self): def main(): """Main entry point for generation.""" init_logger() - + # Parse configuration config_manager = ConfigManager() config = config_manager.parse_args() - + generator = None try: # Initialize generator generator = Generator(config) - - # Check for generation mode from config or command line - generation_mode = getattr(config, 'generation_mode', 'interactive') - - if generation_mode == 'interactive': - generator.interactive_generate() - elif generation_mode == 'batch': - input_file = getattr(config, 'input_file', 'inputs.json') - output_file = getattr(config, 'output_file', 'outputs.json') - generator.batch_generate(input_file, output_file) - else: - # Single generation example - messages = [ - { - "user": "What is the capital of France?", - "assistant": "" - } - ] - response = generator.generate(messages) - logger.info(f"Generated: {response}") - + + prompts = [ + "What is the meaning of life?", + "Translate 'hello world' to French.", + ] + + logger.info(f"Generating for prompts: {prompts}") + start_time = time.perf_counter() + + responses = generator.generate(prompts) + + generation_time = time.perf_counter() - start_time + logger.info(f"Generation completed in {generation_time:.2f}s") + + for prompt, response in zip(prompts, responses): + print("-" * 20) + print(f"Prompt: {prompt}") + print(f"Response: {response}") + print("-" * 20) + except Exception as e: logger.error(f"Error during generation: {e}") if generator: @@ -384,9 +289,10 @@ def main(): else: if generator: generator.close() - torch.distributed.destroy_process_group() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() logger.info("Process group destroyed") if __name__ == "__main__": - main() + main() \ No newline at end of file From ef4a28a2ba287646dcee5038f4010b599ef3eb5d Mon Sep 17 00:00:00 2001 From: tomiock Date: Mon, 13 Oct 2025 13:58:06 +0000 Subject: [PATCH 17/18] the problem was with the input tokens (chat template) --- torchtitan/generate.py | 72 +++++++++++++------ torchtitan/generate_llama3.py | 2 +- torchtitan/models/llama3/model/model.py | 6 +- torchtitan/vlr/smolvlm/__init__.py | 1 + .../vlr/smolvlm/datasets/template.jinja | 4 +- torchtitan/vlr/smolvlm/model/model.py | 15 +--- .../vlr/smolvlm/model/state_dict_adapter.py | 6 +- .../train_configs/llama_siglip_256.toml | 4 +- 8 files changed, 65 insertions(+), 45 deletions(-) diff --git a/torchtitan/generate.py b/torchtitan/generate.py index 45aed2e9..44a78c2d 100644 --- a/torchtitan/generate.py +++ b/torchtitan/generate.py @@ -11,6 +11,7 @@ import numpy as np import torch +torch.set_printoptions(threshold=10_000) from torch.distributed.elastic.multiprocessing.errors import record from transformers import AutoProcessor from PIL import Image @@ -71,12 +72,12 @@ def generate_next_token( def _generate_sequence( model, input_ids: torch.Tensor, - *, max_new_tokens: int, temperature: float = 1.0, + pixel_values: torch.Tensor | None = None, + patch_attention_mask: torch.BoolTensor | None = None, top_k: Optional[int] = None, seed: Optional[int] = None, - **model_kwargs, ) -> torch.Tensor: # ensure batch dimension (T,) --> (B, T) if input_ids.ndim == 1: @@ -95,7 +96,8 @@ def _generate_sequence( temperature=temperature, top_k=top_k, rng=rng, - **model_kwargs, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, ) generated_tokens = torch.cat([generated_tokens, next_token], dim=1) @@ -214,9 +216,9 @@ def __init__(self, job_config: JobConfig): logger.warning(f"Chat template not found at {template_path}, using default") self.chat_template = None - self.max_new_tokens = getattr(job_config, 'max_new_tokens', 256) - self.temperature = getattr(job_config, 'temperature', 0.7) - self.top_k = getattr(job_config, 'top_k', 50) + self.max_new_tokens = getattr(job_config, 'max_new_tokens', 16) + self.temperature = getattr(job_config, 'temperature', 0) + self.top_k = getattr(job_config, 'top_k', None) logger.info("Generator initialized successfully") @@ -247,6 +249,7 @@ def generate( if 'pixel_attention_mask' in vision_inputs: patch_attention_mask = vision_inputs['pixel_attention_mask'].to(self.device) + """ if self.chat_template: input_ids = self.tokenizer.apply_chat_template( messages, @@ -267,26 +270,56 @@ def generate( input_ids = input_ids["input_ids"] input_ids = input_ids.to(self.device) + """ + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": messages}, + #{"type": "image", "image": "/home-local/tockier/cat.jpg"} + ] + }, + ] + + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + return_dict=True, + return_tensors="pt", + ) + + print(inputs) - model_kwargs = { - "pixel_values": pixel_values, - "patch_attention_mask": patch_attention_mask, - "eos_id": self.tokenizer.eos_token_id, - } + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ).to(self.device, dtype=torch.bfloat16) + + input_ids = inputs['input_ids'] + pixel_values = inputs.get('pixel_values', None) + patch_attention_mask = inputs.get('patch_attention_mask', None) + + print(inputs) with torch.amp.autocast('cuda', dtype=torch.bfloat16): output_ids = _generate_sequence( model=model, input_ids=input_ids, - max_new_tokens=max_new_tokens, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + max_new_tokens=64, temperature=temperature, top_k=top_k, seed=seed, - **model_kwargs, ) - generated_ids = output_ids[0, input_ids.shape[1]:] - generated_text = self.tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True) + print(output_ids.v) + generated_text = self.processor.batch_decode(output_ids, skip_special_tokens=True) return generated_text @@ -311,17 +344,16 @@ def interactive_generate(self): elif image_path: logger.warning(f"Image path {image_path} not found, proceeding without image") - messages = [{"role": "user", "content": user_input}] + messages = [{ "user": user_input, "assistant": ""}] logger.info("Generating response...") start_time = time.perf_counter() - response = self.generate(messages, images=images) + response = self.generate(user_input, images=images) - generation_time = time.perf_counter() - start_time - logger.info(f"Generation completed in {generation_time:.2f}s") + print(response) - print(f"\nGenerated response:\n{response}") + generation_time = time.perf_counter() - start_time except KeyboardInterrupt: logger.info("\nInterrupted by user") diff --git a/torchtitan/generate_llama3.py b/torchtitan/generate_llama3.py index 1b0e47ed..eb08813b 100644 --- a/torchtitan/generate_llama3.py +++ b/torchtitan/generate_llama3.py @@ -295,4 +295,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index b27e6aa8..d4b703b9 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -174,11 +174,7 @@ class Attention(nn.Module): def __init__(self, model_args: TransformerModelArgs): super().__init__() self.n_heads = model_args.n_heads - self.n_kv_heads = ( - model_args.n_heads - if model_args.n_kv_heads is None - else model_args.n_kv_heads - ) + self.n_kv_heads = model_args.n_kv_heads self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.dim // model_args.n_heads diff --git a/torchtitan/vlr/smolvlm/__init__.py b/torchtitan/vlr/smolvlm/__init__.py index 490ccaa8..1759f0b5 100644 --- a/torchtitan/vlr/smolvlm/__init__.py +++ b/torchtitan/vlr/smolvlm/__init__.py @@ -67,6 +67,7 @@ ffn_dim=1536, n_layers=30, n_heads=9, + n_kv_heads=3, multiple_of=1024, rope_theta=100000, vocab_size=49280, diff --git a/torchtitan/vlr/smolvlm/datasets/template.jinja b/torchtitan/vlr/smolvlm/datasets/template.jinja index 50b69a4b..01872f5f 100644 --- a/torchtitan/vlr/smolvlm/datasets/template.jinja +++ b/torchtitan/vlr/smolvlm/datasets/template.jinja @@ -1,2 +1,2 @@ -{%- for message in messages %}{{'<|im_start|>user' + '\\n' + message['user'] + '<|im_end|>' }} -{{'<|im_start|>assistant' + '\\n' + message['assistant'] + '<|im_end|>' }}{%- endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %} \ No newline at end of file +{%- for message in messages %}{{'<|im_start|>user' + message['user'] + '' + '' }} +{%- endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %} diff --git a/torchtitan/vlr/smolvlm/model/model.py b/torchtitan/vlr/smolvlm/model/model.py index 507fd909..88ce8e73 100644 --- a/torchtitan/vlr/smolvlm/model/model.py +++ b/torchtitan/vlr/smolvlm/model/model.py @@ -8,6 +8,7 @@ import torch from torch import nn +from torchtitan.protocols.model import AttentionMasksType from torchtitan.models.attention import ScaledDotProductAttentionWrapper from torchtitan.models.llama3 import Transformer as Llama3 @@ -50,8 +51,6 @@ def pixel_shuffle(self, x, scale_factor=4): return x def forward(self, image_hidden_states): - print("image hidden") - print(image_hidden_states) image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) image_hidden_states = self.modality_projection(image_hidden_states) return image_hidden_states @@ -142,7 +141,6 @@ def get_image_features( patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() image_hidden_states = self.encoder(pixel_values, patch_attention_mask) - #image_hidden_states = image_hidden_states.last_hidden_state image_hidden_states = image_hidden_states.bfloat16() image_hidden_states = self.projector(image_hidden_states) @@ -151,26 +149,19 @@ def get_image_features( def forward( self, input_ids: torch.Tensor, - eos_id: int | None = None, - input_batch: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, patch_attention_mask: torch.BoolTensor | None = None, - #grid_thw: torch.Tensor | None = None, + attention_masks: AttentionMasksType | None = None, ): - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages hidden_states = self.tok_embeddings(input_ids) if self.tok_embeddings else input_ids - """ if self.encoder is not None and pixel_values is not None: vision_tokens = self.get_image_features(pixel_values, patch_attention_mask) hidden_states = self._fuse_vision_text(hidden_states, vision_tokens, input_ids) - else: - "THERE are not images" - """ for layer in self.layers.values(): - hidden_states = layer(hidden_states, self.freqs_cis) + hidden_states = layer(hidden_states, self.freqs_cis, attention_masks=attention_masks) hidden_states = self.norm(hidden_states) output = self.output(hidden_states) diff --git a/torchtitan/vlr/smolvlm/model/state_dict_adapter.py b/torchtitan/vlr/smolvlm/model/state_dict_adapter.py index fa353c94..32508847 100644 --- a/torchtitan/vlr/smolvlm/model/state_dict_adapter.py +++ b/torchtitan/vlr/smolvlm/model/state_dict_adapter.py @@ -37,9 +37,9 @@ def __init__( #"model.layers.{}.self_attn.rotary_emb.inv_freq": None, - "model.text_model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.gate_proj.weight", # check - "model.text_model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.up_proj.weight", # check - "model.text_model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.down_proj.weight", # check + "model.text_model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", # check + "model.text_model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", # check + "model.text_model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", # check "model.text_model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", # check "model.text_model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", # check diff --git a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml index fe85e917..c4ceb4b1 100644 --- a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml +++ b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml @@ -4,7 +4,7 @@ custom_args_module = "torchtitan.vlr.smolvlm.assets.job_config" [job] -dump_folder = "/data/users" +dump_folder = "/data/users/tockier/outputs/" description = "Llama 3 Siglip2 VLM training" print_args = false @@ -43,7 +43,7 @@ min_lr_factor = 0.0 [training] local_batch_size = 2 -seq_len = 1048 +seq_len = 4096 # packing_buffer_size = 100 max_norm = 1.0 # grad norm clipping steps = 13100 From e4ebcab1b3b343a49afe4c833b127c97e9820778 Mon Sep 17 00:00:00 2001 From: tomiock Date: Mon, 13 Oct 2025 16:47:16 +0000 Subject: [PATCH 18/18] debugging --- run_generate.sh | 1 - torchtitan/generate.py | 42 ++++++++----- torchtitan/models/attention.py | 5 +- torchtitan/models/llama3/model/model.py | 2 +- torchtitan/vlr/smolvlm/model/model.py | 22 +++++-- torchtitan/vlr/smolvlm/model/siglip2.py | 84 ++++++++++++++++++++----- 6 files changed, 116 insertions(+), 40 deletions(-) diff --git a/run_generate.sh b/run_generate.sh index 481ac99b..aa7ffcf3 100755 --- a/run_generate.sh +++ b/run_generate.sh @@ -16,7 +16,6 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/vlr/smolvlm/train_configs/llama_siglip_ INFERENCE_FILE=${INFERENCE_FILE:-"torchtitan.generate"} -CUDA_VISIBLE_DEVICES=2 \ NCCL_P2P_DISABLE=1 \ TORCH_NCCL_DUMP_ON_TIMEOUT=1 \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ diff --git a/torchtitan/generate.py b/torchtitan/generate.py index 44a78c2d..1552e24f 100644 --- a/torchtitan/generate.py +++ b/torchtitan/generate.py @@ -225,7 +225,7 @@ def __init__(self, job_config: JobConfig): @torch.no_grad() def generate( self, - messages: List[Dict[str, Any]], + messages: List[Dict[str, Any]] = None, images: Optional[List[Image.Image]] = None, max_new_tokens: Optional[int] = None, temperature: Optional[float] = None, @@ -277,11 +277,22 @@ def generate( "role": "user", "content": [ {"type": "text", "text": messages}, - #{"type": "image", "image": "/home-local/tockier/cat.jpg"} + {"type": "image", "image": images}, ] }, ] + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Caption this image"}, + {"type": "image", "image": "../cat.jpg"}, + ] + }, + ] + + inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, @@ -290,8 +301,6 @@ def generate( return_tensors="pt", ) - print(inputs) - inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, @@ -302,10 +311,19 @@ def generate( input_ids = inputs['input_ids'] pixel_values = inputs.get('pixel_values', None) - patch_attention_mask = inputs.get('patch_attention_mask', None) + patch_attention_mask = inputs.get('pixel_attention_mask', None) - print(inputs) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + logits = model( + input_ids=input_ids, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + print(logits) + + """ with torch.amp.autocast('cuda', dtype=torch.bfloat16): output_ids = _generate_sequence( model=model, @@ -321,7 +339,8 @@ def generate( print(output_ids.v) generated_text = self.processor.batch_decode(output_ids, skip_special_tokens=True) - return generated_text + print(generated_text) + """ def interactive_generate(self): """Interactive generation mode for testing.""" @@ -339,19 +358,14 @@ def interactive_generate(self): images = None if image_path and os.path.exists(image_path): image = Image.open(image_path).convert('RGB') - images = [image] logger.info(f"Loaded image from {image_path}") elif image_path: logger.warning(f"Image path {image_path} not found, proceeding without image") - messages = [{ "user": user_input, "assistant": ""}] - logger.info("Generating response...") start_time = time.perf_counter() - response = self.generate(user_input, images=images) - - print(response) + response = self.generate(user_input, images=image) generation_time = time.perf_counter() - start_time @@ -381,7 +395,7 @@ def main(): generator = None try: generator = Generator(config) - generator.interactive_generate() + generator.generate() except Exception as e: logger.error(f"Error during generation: {e}") diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index bf963a5b..92285777 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -83,7 +83,7 @@ class ScaledDotProductAttentionWrapper(torch.nn.Module): # TODO: remove sdpa_backends after PyTorch 2.9 is released. sdpa_backends: ClassVar[list[SDPBackend]] = [] - def __init__(self) -> None: + def __init__(self, is_causal) -> None: super().__init__() if not self.sdpa_backends: self.sdpa_backends = [ @@ -91,6 +91,7 @@ def __init__(self) -> None: SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, ] + self.is_causal = is_causal def forward( self, @@ -101,7 +102,7 @@ def forward( scale: float | None = None, ) -> torch.Tensor: with sdpa_kernel(self.sdpa_backends, set_priority=True): - return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True) + return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=self.is_causal) # We cannot do inner function/closure because we won't be able to cache it -- diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index d4b703b9..d45d2996 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -191,7 +191,7 @@ def __init__(self, model_args: TransformerModelArgs): if self.use_flex_attn: self.inner_attention = FlexAttentionWrapper() else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.inner_attention = ScaledDotProductAttentionWrapper(is_causal=True) def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): diff --git a/torchtitan/vlr/smolvlm/model/model.py b/torchtitan/vlr/smolvlm/model/model.py index 88ce8e73..18337170 100644 --- a/torchtitan/vlr/smolvlm/model/model.py +++ b/torchtitan/vlr/smolvlm/model/model.py @@ -111,9 +111,11 @@ def get_image_features( pixel_attention_mask ): batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values.bfloat16() # fp16 compatibility + pixel_values = pixel_values.to(dtype=torch.bfloat16) # fp16 compatibility pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + patch_size = 16 + # Remove padding images - padding images are full 0. nb_values_per_image = pixel_values.shape[1:].numel() real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image @@ -123,7 +125,7 @@ def get_image_features( real_images_inds[0] = True pixel_values = pixel_values[real_images_inds].contiguous() - + # Handle the vision attention mask if pixel_attention_mask is None: pixel_attention_mask = torch.ones( size=[pixel_values.shape[i] for i in (0, 2, 3)], @@ -135,12 +137,12 @@ def get_image_features( pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:]) pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - patch_size = 16 patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() image_hidden_states = self.encoder(pixel_values, patch_attention_mask) + print('v', image_hidden_states) image_hidden_states = image_hidden_states.bfloat16() image_hidden_states = self.projector(image_hidden_states) @@ -158,14 +160,24 @@ def forward( if self.encoder is not None and pixel_values is not None: vision_tokens = self.get_image_features(pixel_values, patch_attention_mask) + print('v2', vision_tokens) hidden_states = self._fuse_vision_text(hidden_states, vision_tokens, input_ids) + print('h', hidden_states) + + is_first_layer = True for layer in self.layers.values(): hidden_states = layer(hidden_states, self.freqs_cis, attention_masks=attention_masks) + + if is_first_layer: + print('d1', hidden_states) + is_first_layer = False + + print('d29', hidden_states) hidden_states = self.norm(hidden_states) - output = self.output(hidden_states) - return output + logits = self.output(hidden_states) + return logits if __name__ == "__main__": diff --git a/torchtitan/vlr/smolvlm/model/siglip2.py b/torchtitan/vlr/smolvlm/model/siglip2.py index 0f80a022..d606ee82 100644 --- a/torchtitan/vlr/smolvlm/model/siglip2.py +++ b/torchtitan/vlr/smolvlm/model/siglip2.py @@ -63,9 +63,18 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() + step_h = 1.0 / nb_patches_h + step_w = 1.0 / nb_patches_w + h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype) w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype) + fractional_coords_h = h_indices * step_h + fractional_coords_w = w_indices * step_w + + fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6)) + fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6)) + fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) @@ -79,6 +88,30 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B return embeddings +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + scaling: float, + dropout: float = 0.0, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + """ + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + """ + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + class Attention(nn.Module): """ Multi-head attention module. @@ -100,26 +133,31 @@ def __init__(self, args: Siglip2ModelArgs): super().__init__() self.dim = args.dim self.head_dim = args.dim // args.n_heads + self.num_heads = args.n_heads + + self.scale = self.head_dim**-.5 self.q_proj = nn.Linear(self.dim, self.dim) self.k_proj = nn.Linear(self.dim, self.dim) self.v_proj = nn.Linear(self.dim, self.dim) self.out_proj = nn.Linear(self.dim, self.dim) - self.attn = ScaledDotProductAttentionWrapper() + self.attn = ScaledDotProductAttentionWrapper(is_causal=False) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): + batch_size, seq_length, embed_dim = hidden_states.shape - def forward(self, x: torch.Tensor, attention_mask: torch.Tensor): - xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - # Use self.head_dim instead of `n_heads` to infer the actual - # local heads from sizes of xq, xk, and xv as TP may have sharded them - # after the above linear ops. - xq = E.rearrange(xq, "b l (h d) -> b h l d", d=self.head_dim) - xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim) - xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - output = self.attn(xq, xk, xv) - output = E.rearrange(output, "b h l d -> b l (h d)").contiguous() + output = eager_attention_forward(self, queries, keys, values, attention_mask, self.scale) + + output = output.reshape(batch_size, seq_length, embed_dim).contiguous() return self.out_proj(output) @@ -135,8 +173,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class FeedForward(nn.Module): def __init__(self, args: Siglip2ModelArgs): super().__init__() - self.fc1 = nn.Linear(args.dim, args.ffn_dim, bias=True) - self.fc2 = nn.Linear(args.ffn_dim, args.dim, bias=True) + self.fc1 = nn.Linear(args.dim, args.ffn_dim) + self.fc2 = nn.Linear(args.ffn_dim, args.dim) self.act_fn = PytorchGELUTanh() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -158,10 +196,22 @@ def __init__(self, args: Siglip2ModelArgs): self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) self.mlp = FeedForward(args) - def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - x = x + self.self_attn(self.layer_norm1(x), attention_mask=attention_mask) - x = x + self.mlp(self.layer_norm2(x)) - return x + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states def init_weights(self): self.layer_norm1.reset_parameters()