diff --git a/docs/extension.md b/docs/extension.md index f529b05b..98902528 100644 --- a/docs/extension.md +++ b/docs/extension.md @@ -14,7 +14,7 @@ The extension points and protocols mentioned in this note are subject to change. The coarse level abstraction tries to hit a balance between flexible component swapping and a straightforward train script ([train.py](../torchtitan/train.py)). Note that among all training components, currently [`CheckpointManager`](../torchtitan/components/checkpoint.py) and [`FTManager`](../torchtitan/components/ft.py) are not configurable since we do not expect them to be customized, but we are open to requests. -To register a `TrainSpec`, please follow the example of [Llama 3.1](../torchtitan/models/llama3/__init__.py) to `register_train_spec`. Please make sure the registration code is called before training initialization. In torchtitan, it is performed during [module import](../torchtitan/__init__.py). +To register a `TrainSpec`, please use the `register_train_spec` API, and make sure registration happens before `get_train_spec` is called during training initialization. In torchtitan, `get_train_spec` will dynamically look for models in `torchtitan/models` or `torchtitan/experiments`. ### `ModelConverter` diff --git a/run_generate.sh b/run_generate.sh new file mode 100755 index 00000000..aa7ffcf3 --- /dev/null +++ b/run_generate.sh @@ -0,0 +1,24 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overwrites for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./run_train.sh +NGPU=${NGPU:-"1"} +export LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml"} +INFERENCE_FILE=${INFERENCE_FILE:-"torchtitan.generate"} + + +NCCL_P2P_DISABLE=1 \ +TORCH_NCCL_DUMP_ON_TIMEOUT=1 \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +-m ${INFERENCE_FILE} --job.config_file ${CONFIG_FILE} "$@" \ +--checkpoint.exclude-from-loading dataloader,optimizer,lr_scheduler \ diff --git a/run_generate_llama3.sh b/run_generate_llama3.sh new file mode 100755 index 00000000..69c4ca23 --- /dev/null +++ b/run_generate_llama3.sh @@ -0,0 +1,23 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overwrites for convenience +# e.g. 
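As context for the `register_train_spec` change in `docs/extension.md` above: the spec name is now passed as the first argument to `register_train_spec` rather than stored in a `name` field on `TrainSpec`. Below is a minimal sketch of the new registration flow, modeled on `tests/unit_tests/test_train_spec.py`; `FakeModel`, the no-op parallelize function, and the `"fake_model"` name are placeholders, and the builder imports mirror the existing specs in this PR but may differ slightly.

```python
# Hedged sketch of the new registration API (name passed to register_train_spec,
# not a TrainSpec field). FakeModel and fake_parallelize are placeholders.
import torch.nn as nn

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.protocols import BaseModelArgs
from torchtitan.protocols.train_spec import (
    get_train_spec,
    register_train_spec,
    TrainSpec,
)


class FakeModel(nn.Module):  # placeholder ModelProtocol implementation
    def __init__(self, model_args: BaseModelArgs):
        super().__init__()
        self.linear = nn.Linear(8, 8)


def fake_parallelize(model, parallel_dims, job_config):  # no-op parallelize_fn
    return model


register_train_spec(
    "fake_model",  # must be unique and match job_config.model.name
    TrainSpec(
        model_cls=FakeModel,
        model_args={"debugmodel": BaseModelArgs()},
        parallelize_fn=fake_parallelize,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    ),
)

# Training later resolves the spec by name before initialization.
assert get_train_spec("fake_model").model_cls is FakeModel
```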
+# LOG_RANK=0,1 NGPU=4 ./run_train.sh +NGPU=${NGPU:-"1"} +export LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.toml"} +INFERENCE_FILE=${INFERENCE_FILE:-"torchtitan.generate_llama3"} + + +NCCL_P2P_DISABLE=1 \ +TORCH_NCCL_DUMP_ON_TIMEOUT=1 \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +-m ${INFERENCE_FILE} --job.config_file ${CONFIG_FILE} "$@" diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 8103ae0b..b1f45c40 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -95,7 +95,7 @@ def estimate_memory(job_config: JobConfig): else contextlib.nullcontext() ): logger.info( - f"Building {train_spec.name} {job_config.model.flavor} with {model_args}" + f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): model = train_spec.model_cls(model_args) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 21322ba2..ea30ce0b 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -32,6 +32,9 @@ from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.utils import device_module, device_type +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.tokenizer import HuggingFaceTokenizer + # support running w/o installing as package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) @@ -143,11 +146,26 @@ def test_generate( state_dict = model.state_dict() - # Checkpoint Loading - begin = time.monotonic() - logger.info(f"Loading chkpt at: {checkpoint_path}") - dcp.load(state_dict, checkpoint_id=checkpoint_path) - logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.") + # Setup checkpoint manager for loading + checkpointer = CheckpointManager( + dataloader=None, # No dataloader needed for generation + model_parts=[model], + optimizers=None, # No optimizer needed for generation + lr_schedulers=None, # No lr_scheduler needed for generation + states={}, + checkpoint_config=config.checkpoint, + sd_adapter=( + train_spec.state_dict_adapter( + model_args, config.model.hf_assets_path + ) + ), + base_folder=config.job.dump_folder, + ft_manager=None, # No fault tolerance for generation + ) + + # Load checkpoint + checkpointer.load(step=config.checkpoint.load_step) + logger.info(f"Loaded checkpoint from step {config.checkpoint.load_step}") device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 57167304..fb326a47 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -76,7 +76,6 @@ class TestTrainSpec: def test_register_train_spec(self): fake_config = {"fake": BaseModelArgs()} spec = TrainSpec( - name="fake", model_cls=FakeModel, model_args=fake_config, parallelize_fn=parallelize_llama, @@ -87,7 +86,7 @@ def test_register_train_spec(self): build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, ) - register_train_spec(spec) + register_train_spec("fake", spec) new_spec = get_train_spec("fake") assert new_spec == spec @@ -98,7 +97,6 @@ def test_optim_hook(self): fake_config = {"fake": BaseModelArgs()} spec = TrainSpec( - name="fake2", model_cls=FakeModel, model_args=fake_config, parallelize_fn=parallelize_llama, @@ -109,7 
+107,7 @@ def test_optim_hook(self): build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, ) - register_train_spec(spec) + register_train_spec("fake2", spec) new_spec = get_train_spec("fake2") model = new_spec.model_cls(BaseModelArgs()) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 1b25aa3f..3527bc77 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -19,10 +19,7 @@ import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn -from torch.distributed.checkpoint import ( - HuggingFaceStorageReader, - HuggingFaceStorageWriter, -) +from torch.distributed.checkpoint import HuggingFaceStorageWriter from torch.distributed.checkpoint._consolidate_hf_safetensors import ( consolidate_safetensors_files_on_every_rank, ) @@ -249,6 +246,9 @@ def load_state_dict(state_dict): self.initial_load_model_only = checkpoint_config.initial_load_model_only self.initial_load_in_hf = checkpoint_config.initial_load_in_hf self.initial_load_path = checkpoint_config.initial_load_path + self.initial_load_in_hf_quantized = ( + checkpoint_config.initial_load_in_hf_quantized + ) self.last_save_model_only = checkpoint_config.last_save_model_only self.last_save_in_hf = checkpoint_config.last_save_in_hf if self.last_save_in_hf: @@ -339,7 +339,7 @@ def dcp_save( checkpoint_id (str): The checkpoint id to save. async_mode (AsyncMode): Whether the checkpoint is async. enable_garbage_collection (bool): Whether to enable garbage collection after save. - to_hf (bool): Whether to save in HF model definition and safetensors format. + to_hf (bool): Whether to save in HF mel definition and safetensors format. Returns: Future: The future object if the checkpoint is async, otherwise None. @@ -418,6 +418,7 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, + from_quantized: bool, ) -> None: """Load the checkpoint with dcp. Args: @@ -432,10 +433,13 @@ def dcp_load( self.sd_adapter is not None ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." hf_state_dict = self.sd_adapter.to_hf(state_dict) + hf_storage_reader = self.sd_adapter.get_hf_storage_reader( + checkpoint_id, from_quantized + ) dcp.load( hf_state_dict, - storage_reader=HuggingFaceStorageReader(path=checkpoint_id), + storage_reader=hf_storage_reader, ) state_dict = self.sd_adapter.from_hf(hf_state_dict) @@ -544,13 +548,21 @@ def load(self, step: int = -1) -> bool: model_only = False from_hf = False + from_quantized = False if not os.path.exists(self.folder): model_only = self.initial_load_model_only from_hf = self.initial_load_in_hf + from_quantized = self.initial_load_in_hf_quantized if from_hf: assert ( model_only ), "Only model can be loaded when loading from HF's safetensors checkpoint." + + if from_quantized: + assert ( + from_hf + ), "Quantized checkpoint can only be loaded from HuggingFace format." + if self.initial_load_path: checkpoint_id = self.initial_load_path if not os.path.isdir(checkpoint_id): @@ -602,6 +614,7 @@ def load(self, step: int = -1) -> bool: states, checkpoint_id=checkpoint_id, from_hf=from_hf, + from_quantized=from_quantized, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( @@ -679,6 +692,7 @@ def _ft_load(self) -> None: checkpoint_id=checkpoint_id, # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader. 
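The `dcp_load` change above delegates reader construction to the state-dict adapter, so a model that wants the new quantized-HF path has to expose `get_hf_storage_reader(path, from_quantized)`. A minimal sketch of such an adapter hook, assuming the constructor signature used elsewhere in this PR (`state_dict_adapter(model_args, hf_assets_path)`); `DequantizingHFReader` is a hypothetical placeholder for a reader that dequantizes weights on load, and a real adapter also implements `to_hf`/`from_hf`.

```python
# Hedged sketch only, not torchtitan's actual adapter code.
from torch.distributed.checkpoint import HuggingFaceStorageReader


class DequantizingHFReader(HuggingFaceStorageReader):
    """Hypothetical reader; a real one would dequantize tensors while reading."""


class MyStateDictAdapter:  # illustrative subset of a state-dict adapter
    def __init__(self, model_args, hf_assets_path: str | None):
        self.model_args = model_args
        self.hf_assets_path = hf_assets_path

    def get_hf_storage_reader(self, path: str, from_quantized: bool = False):
        # Matches the new call site:
        #   sd_adapter.get_hf_storage_reader(checkpoint_id, from_quantized)
        if from_quantized:
            return DequantizingHFReader(path=path)
        return HuggingFaceStorageReader(path=path)
```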
from_hf=False, + from_quantized=False, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index 2b14b9a8..30366113 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -23,7 +23,8 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor ) -def build_cross_entropy_loss(job_config: JobConfig): +def build_cross_entropy_loss(job_config: JobConfig, **kwargs): + del kwargs # delete any unused arguments loss_fn = cross_entropy_loss if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index e64fe5c8..993d0373 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -34,6 +34,20 @@ class Profiling: profile_freq: int = 10 """How often to collect profile traces, in iterations""" + profiler_active: int = 1 + """ + The steps profiler is active for. + + This is used to configure torch.profile.schedule. + """ + + profiler_warmup: int = 3 + """ + The number of warmup steps before the active step in each profiling cycle. + + This is used to configure torch.profile.schedule. + """ + enable_memory_snapshot: bool = False """Whether to dump memory snapshot""" @@ -289,9 +303,11 @@ class Parallelism: within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward, trading off memory and communication. See torch's `fully_shard` API for more documentation on `reshard_after_forward`. + The supported policies include "default", "always" and "never": + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal - scenarios. + scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. """ @@ -399,15 +415,21 @@ class Parallelism: expert_parallel_degree: int = 1 """ Expert parallelism degree. 1 means disabled. No effect for non-MoE models. + Currently, it is supported with the following constraints: + - when etp = tp: + - cp <= ep <= dp_shard * cp - ep % cp == 0 - dp_shard * cp % ep == 0 + - when etp = 1: + - cp * tp <= ep <= dp_shard * cp * tp - ep % (cp * tp) == 0 - dp_shard * cp * tp % ep == 0 + Note that this is still an experimental feature. Some constraints will be relaxed soon when we have more flexible DeviceMesh support. """ @@ -473,6 +495,14 @@ class Checkpoint: non-tensors. The default value is False. """ + initial_load_in_hf_quantized: bool = False + """ + Enable loading of HuggingFace's safetensors format with quantized state dict keys. The option + is only used when `initial_load_path` and `initial_load_path_in_hf` is specified. This will load + checkpoints in HF's model definition and dequantize on model weights if necessary. To support + this parameter, the model need to define proper HuggingFaceStorageReader to perform dequantize. + """ + last_save_model_only: bool = True """ When last_save_model_only=True, only the model will be saved at the end of training, @@ -501,6 +531,7 @@ class Checkpoint: async_mode: Literal["disabled", "async", "async_with_pinned_mem"] = "disabled" """ Which async checkpoint mode to use. Currently there are 3 different modes. + - "disabled": synchronized checkpointing will be used. - "async": torch.distributed.checkpoint.async_save will be used. 
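The new `profiler_active` / `profiler_warmup` options state that they configure `torch.profiler.schedule`. A hedged sketch of how such a schedule could be assembled from these fields; deriving the `wait` phase from `profile_freq` is an assumption made here for illustration, and the real wiring in torchtitan's profiling utility is not part of this diff.

```python
# Hedged sketch: one profiling cycle of length profile_freq, with the last
# profiler_warmup + profiler_active steps spent warming up and recording.
import torch

profile_freq = 10      # Profiling.profile_freq
profiler_warmup = 3    # Profiling.profiler_warmup
profiler_active = 1    # Profiling.profiler_active

schedule = torch.profiler.schedule(
    wait=profile_freq - (profiler_warmup + profiler_active),  # assumed derivation
    warmup=profiler_warmup,
    active=profiler_active,
    repeat=0,
)

with torch.profiler.profile(schedule=schedule) as prof:
    for _ in range(profile_freq):
        torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in for a training step
        prof.step()
```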
- "async_with_pinned_mem": this option utilizes a dedicated pinned memory space and creates a @@ -558,7 +589,7 @@ class Checkpoint: @dataclass class ActivationCheckpoint: - mode: Literal["selective", "full", "none"] = "selective" + mode: Literal["selective", "full", "memory_budget", "none"] = "selective" """Type of activation checkpointing to use""" selective_ac_option: str = "2" @@ -587,6 +618,24 @@ class ActivationCheckpoint: rematerialized. """ + memory_budget: float = 0.5 + """ + When mode is set to "memory_budget", this value determines how much + partitioner in the compiler should trade off compute for memory. + 0.0 corresponds to the activation memory from applying + activation checkpointing to the full compiled region, and 1.0 corresponds to + the activation memory from the default runtime-optimized strategy. Read here: + https://pytorch.org/blog/activation-checkpointing-techniques/ + """ + + visualize_memory_budget_pareto: bool = False + """ + This dumps out a SVG visualization of the expected runtime vs. activation + memory tradeoffs for all memory budget values from 0 to 1 in increments of + 0.05 in {--job.dump_folder}/memory_budget_pareto folder. See an example here: + https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 + """ + @dataclass class Compile: diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 227c2ca2..57809c45 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -7,6 +7,7 @@ # This file provides the util functions to apply activation checkpointing to the model. # Technically, this is not a part of distributed, but distributed module is the best place to put it. +import os from collections import defaultdict import torch @@ -279,6 +280,7 @@ def apply_ac( model_compile_enabled: bool = False, use_flex_attn: bool = False, op_sac_save_list: set[torch._ops.OpOverload] | None = None, + base_folder: str = "", ) -> None: """Apply activation checkpointing to the model. 
@@ -297,15 +299,27 @@ def apply_ac( None """ - for layer_id, transformer_block in model.layers.named_children(): - transformer_block = _apply_ac_to_transformer_block( - transformer_block, - ac_config, - base_fqn=f"layers.{layer_id}", - model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, - op_sac_save_list=op_sac_save_list, - ) - model.layers.register_module(layer_id, transformer_block) + if ac_config.mode == "memory_budget": + assert model_compile_enabled, "Memory budget mode requires model to be compiled" + if ac_config.visualize_memory_budget_pareto: + pareto_dir = os.path.join(base_folder, "memory_budget_pareto") + if not os.path.exists(pareto_dir): + os.makedirs(pareto_dir, exist_ok=True) + torch._functorch.config.memory_budget_pareto_dir = pareto_dir + torch._functorch.config.visualize_memory_budget_pareto = True + + torch._functorch.config.activation_memory_budget = ac_config.memory_budget + logger.info(f"Selected {ac_config.memory_budget} budget option") + else: + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = _apply_ac_to_transformer_block( + transformer_block, + ac_config, + base_fqn=f"layers.{layer_id}", + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + op_sac_save_list=op_sac_save_list, + ) + model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index a2f1feb3..c2ec7bd7 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -106,6 +106,14 @@ def set_determinism( # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + # Ensure flex_attention is compiled without max-autotune. This is needed to ensure + # reproducibility, since the autotune results may not be deterministic. 
+ from torch.nn.attention.flex_attention import flex_attention + + from torchtitan.models.attention import FlexAttentionWrapper + + FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) + if not world_mesh: if seed is not None: torch.manual_seed(seed) @@ -199,14 +207,6 @@ def context(cp_context: Generator[None, None, None] | None = None): torch._dynamo.utils.maybe_enable_compiled_autograd(True) ) - if cp_context is not None: - from torch.nn.attention import SDPBackend - - from torchtitan.models.attention import ScaledDotProductAttention - - if SDPBackend.MATH in ScaledDotProductAttention.backends: - ScaledDotProductAttention.backends.remove(SDPBackend.MATH) - stack.enter_context(cp_context) yield diff --git a/torchtitan/experiments/deepseek_v3/__init__.py b/torchtitan/experiments/deepseek_v3/__init__.py index f93d0d80..f5829dab 100644 --- a/torchtitan/experiments/deepseek_v3/__init__.py +++ b/torchtitan/experiments/deepseek_v3/__init__.py @@ -40,8 +40,8 @@ register_train_spec( + "deepseek3", TrainSpec( - name="deepseek3", model_cls=DeepseekForCausalLM, model_args=deepseek_configs, parallelize_fn=parallelize_deepseek, @@ -51,5 +51,5 @@ build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=get_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, - ) + ), ) diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py index 89c3f68b..2d648f51 100644 --- a/torchtitan/experiments/flux/__init__.py +++ b/torchtitan/experiments/flux/__init__.py @@ -109,7 +109,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="flux", model_cls=FluxModel, model_args=flux_configs, parallelize_fn=parallelize_flux, diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 3d0c52c0..f8b14129 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -167,7 +167,7 @@ def __init__(self, job_config: ForgeJobConfig): if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: raise RuntimeError( - f"Pipeline Parallel is enabled but {self.train_spec.name} " + f"Pipeline Parallel is enabled but {job_config.model.name} " f"does not support pipelining" ) diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 7bd1531d..d3a7d39b 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -66,7 +66,7 @@ def __init__(self, job_config: JobConfig): model_args = self.model_args logger.info( - f"Built {self.train_spec.name} {job_config.model.flavor} with {model_args}" + f"Built {job_config.model.name} {job_config.model.flavor} with {model_args}" ) # metrics logging @@ -78,7 +78,7 @@ def __init__(self, job_config: JobConfig): self.metrics_processor.num_flops_per_token = self.num_flops_per_token logger.info( - f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} " f"{color.red}size: {self.model_param_count:,} total parameters{color.reset}" ) @@ -157,15 +157,14 @@ def forward_backward_step( model_parts = self.model_parts parallel_dims = self.parallel_dims - # apply context parallelism if cp is enabled - # ensure CP handles the separate freqs_cis buffer for each pp stage inputs = input_dict["input"] - # Create the FlexAttention mask according to the input + extra_args = {} + if getattr(self.model_args, "use_flex_attn", False): - cp_mesh = ( - parallel_dims.world_mesh["cp"] if 
parallel_dims.cp_enabled else None + extra_args["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, ) - init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh) optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( @@ -187,11 +186,18 @@ def forward_backward_step( ) if self.pp_has_first_stage: self.pp_schedule.step( - inputs, target=targets, losses=losses, input_batch=inputs + inputs, + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) else: self.pp_schedule.step( - target=targets, losses=losses, input_batch=inputs + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) # accumulate losses across pipeline microbatches @@ -209,7 +215,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs) + pred = model_parts[0](inputs, **extra_args) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred diff --git a/torchtitan/experiments/forge/train_spec.py b/torchtitan/experiments/forge/train_spec.py index b7b1d605..f9ad1d65 100644 --- a/torchtitan/experiments/forge/train_spec.py +++ b/torchtitan/experiments/forge/train_spec.py @@ -21,7 +21,6 @@ @dataclass class ForgeTrainSpec: - name: str model_cls: type[ModelProtocol] model_args: Mapping[str, BaseModelArgs] parallelize_fn: ParallelizeFunction @@ -39,7 +38,6 @@ def _transform_train_spec(original_spec: TrainSpec): """Transform the original train spec to ForgeTrainSpec format.""" # Create a new TrainSpec with only the fields we need in forge return ForgeTrainSpec( - name=original_spec.name, model_cls=original_spec.model_cls, model_args=original_spec.model_args, parallelize_fn=original_spec.parallelize_fn, @@ -51,13 +49,13 @@ def _transform_train_spec(original_spec: TrainSpec): ) -def register_train_spec(train_spec: ForgeTrainSpec) -> None: +def register_train_spec(name: str, train_spec: ForgeTrainSpec) -> None: global _extra_train_specs - if train_spec.name in _extra_train_specs: - raise ValueError(f"ForgeTrainSpec {train_spec.name} is already registered.") + if name in _extra_train_specs: + raise ValueError(f"ForgeTrainSpec {name} is already registered.") # user can define a ForgeTrainSpec from outside of torchtitan - _extra_train_specs[train_spec.name] = train_spec + _extra_train_specs[name] = train_spec def get_train_spec(name: str) -> ForgeTrainSpec: diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index 71a2eecc..325cd6ac 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -8,13 +8,14 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.models.llama3 import pipeline_llama from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import TrainSpec from .infra.parallelize import parallelize_llama -from .model.args import TransformerModelArgs +from .model.args import RoPEScalingArgs, TransformerModelArgs from .model.model import Transformer from .model.state_dict_adapter import Llama4StateDictAdapter @@ -32,6 +33,7 @@ n_heads=16, vocab_size=2048, 
rope_theta=500000, + rope_scaling_args=RoPEScalingArgs(), ), "17bx16e": TransformerModelArgs( dim=5120, @@ -41,6 +43,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, + rope_scaling_args=RoPEScalingArgs(), max_seq_len=10485760, moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, @@ -61,6 +64,7 @@ n_heads=16, vocab_size=2048, rope_theta=500000, + rope_scaling_args=RoPEScalingArgs(), every_n_layers_nope=4, fixed_attn_block_size=256, use_flex_attn=True, @@ -74,6 +78,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, + rope_scaling_args=RoPEScalingArgs(), max_seq_len=10485760, moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, @@ -99,7 +104,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="llama4", model_cls=Transformer, model_args=llama4_configs, parallelize_fn=parallelize_llama, @@ -109,5 +113,6 @@ def get_train_spec() -> TrainSpec: build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, state_dict_adapter=Llama4StateDictAdapter, ) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index dba6d69e..c0607318 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -120,6 +120,7 @@ def parallelize_llama( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP @@ -238,8 +239,8 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), + input_layouts=(Shard(1), None, None), + desired_input_layouts=(Replicate(), None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index e34d4d3c..faeb60aa 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -18,6 +18,14 @@ from torchtitan.tools.utils import has_cuda_capability +@dataclass +class RoPEScalingArgs: + scaling_factor: float = 16.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 1.0 + original_max_position_embeddings: int = 8192 + + @dataclass class TransformerModelArgs(BaseModelArgs): dim: int = 4096 @@ -29,6 +37,7 @@ class TransformerModelArgs(BaseModelArgs): ffn_dim_multiplier: float | None = None norm_eps: float = 1e-5 rope_theta: float = 10000 + rope_scaling_args: RoPEScalingArgs | None = None max_seq_len: int = 1048576 # If `True`, then each transformer block init uses its layer ID, and if diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index c88286e5..93ff4e89 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -4,19 +4,35 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import math import torch import torch.nn.functional as F from torch import nn - -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + get_fixed_block_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import MoE -from torchtitan.protocols import ModelProtocol +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol -from .args import TransformerModelArgs +from .args import RoPEScalingArgs, TransformerModelArgs -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + scaling_args: RoPEScalingArgs | None = None, +) -> torch.Tensor: """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -28,11 +44,42 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te dim (int): Dimension of the frequency tensor. end (int): End index for precomputing frequencies. theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - + scaling_args (RoPEScalingArgs | None): RoPE scaling arguments. Defaults to None. + scaling_factor (float): RoPE scaling multiplier; larger values + stretch positions to support longer contexts. Defaults to 16.0. + low_freq_factor (float): Extra scaling applied to the low-frequency + (long-wavelength) RoPE bands. Defaults to 1.0. + high_freq_factor (float): Extra scaling applied to the high-frequency + (short-wavelength) RoPE bands. Defaults to 1.0. + original_max_position_embeddings (int): Maximum position embeddings + for original model. Defaults to 8192. Returns: torch.Tensor: Precomputed frequency tensor with complex exponentials. """ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + # apply rope scaling + if scaling_args is not None: + scaling_factor = scaling_args.scaling_factor + low_freq_factor = scaling_args.low_freq_factor + high_freq_factor = scaling_args.high_freq_factor + original_max_position_embeddings = scaling_args.original_max_position_embeddings + wavelen = 2 * math.pi / freqs + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by scaling factor + freqs = torch.where(wavelen > low_freq_wavelen, freqs / scaling_factor, freqs) + # wavelen in between: linear interpolation of the scaled freqs and the original freqs + smooth_factor = ( + original_max_position_embeddings / wavelen - low_freq_factor + ) / (high_freq_factor - low_freq_factor) + smoothed_freqs = ( + 1 - smooth_factor + ) * freqs / scaling_factor + smooth_factor * freqs + is_medium_freqs = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + freqs = torch.where(is_medium_freqs, smoothed_freqs, freqs) + t = torch.arange(end, device=freqs.device) freqs = torch.outer(t, freqs).float() freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 @@ -155,9 +202,11 @@ def __init__( # values of these two variables. 
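For reference, the scaling block above follows the Llama 3.1-style frequency smoothing. A small usage sketch of the extended `precompute_freqs_cis` signature with `RoPEScalingArgs`; the dimensions and the non-default `high_freq_factor` are illustrative values only (with `low_freq_factor == high_freq_factor` the smooth-factor denominator `high_freq_factor - low_freq_factor` would be zero, so distinct values are used here).

```python
# Hedged usage sketch: precompute_freqs_cis with and without RoPE scaling.
from torchtitan.experiments.llama4.model.args import RoPEScalingArgs
from torchtitan.experiments.llama4.model.model import precompute_freqs_cis

head_dim, max_seq_len, rope_theta = 64, 4096, 500_000.0

plain = precompute_freqs_cis(head_dim, max_seq_len, rope_theta)  # no scaling
scaled = precompute_freqs_cis(
    head_dim,
    max_seq_len,
    rope_theta,
    RoPEScalingArgs(
        scaling_factor=16.0,
        low_freq_factor=1.0,
        high_freq_factor=4.0,                   # illustrative non-default value
        original_max_position_embeddings=8192,
    ),
)
assert plain.shape == scaled.shape == (max_seq_len, head_dim // 2)
```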
self.use_rope = use_rope - self.sdpa = build_attention( - model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size - ) + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -168,6 +217,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. @@ -202,7 +252,13 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv) + if self.use_flex_attn: + assert isinstance(attention_masks, dict), attention_masks + attention_mask = attention_masks["rope" if self.use_rope else "nope"] + output = self.inner_attention(xq, xk, xv, block_mask=attention_mask) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -335,6 +391,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -347,7 +404,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis) + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) if self.moe_enabled: out = h + self.moe(self.ffn_norm(h)) else: @@ -445,11 +502,43 @@ def _precompute_freqs_cis(self) -> torch.Tensor: # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, + self.model_args.rope_scaling_args, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + B = input_batch.shape[0] + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + + rope_mask_mod = and_masks( + *mask_mods, + get_fixed_block_mask_mod(self.model_args.fixed_attn_block_size), + ) + nope_mask_mod = and_masks(*mask_mods) + + seqlen = input_batch.shape[1] + return { + "rope": create_attention_mask(rope_mask_mod, B, None, seqlen, seqlen), + "nope": create_attention_mask(nope_mask_mod, B, None, seqlen, seqlen), + } + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -473,7 +562,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/multimodal/__init__.py b/torchtitan/experiments/multimodal/__init__.py index bbb37d5c..b35bc165 100644 --- a/torchtitan/experiments/multimodal/__init__.py +++ b/torchtitan/experiments/multimodal/__init__.py @@ -22,8 +22,8 @@ } register_train_spec( + "llama4_multimodal", TrainSpec( - name="llama4_multimodal", model_cls=MultimodalDecoder, model_args=llama4_mm_configs, 
parallelize_fn=parallelize_llama, @@ -33,5 +33,5 @@ build_dataloader_fn=build_mm_dataloader, build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, - ) + ), ) diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py index b468ff96..32ba652f 100644 --- a/torchtitan/experiments/qwen3/__init__.py +++ b/torchtitan/experiments/qwen3/__init__.py @@ -180,7 +180,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="qwen3", model_cls=Qwen3Model, model_args=qwen3_configs, # Change from dict to Mapping parallelize_fn=parallelize_qwen3, diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py index 93f4caea..27406bf7 100644 --- a/torchtitan/experiments/qwen3/infra/parallelize.py +++ b/torchtitan/experiments/qwen3/infra/parallelize.py @@ -114,6 +114,7 @@ def parallelize_qwen3( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP diff --git a/torchtitan/experiments/qwen3/model/model.py b/torchtitan/experiments/qwen3/model/model.py index f2a77e99..0fff490b 100644 --- a/torchtitan/experiments/qwen3/model/model.py +++ b/torchtitan/experiments/qwen3/model/model.py @@ -10,13 +10,23 @@ import torch import torch.nn.functional as F from torch import nn - -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import MoE +from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol from .args import Qwen3ModelArgs + # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py def precompute_rope_cache( dim: int, max_seq_len: int, base: float = 1_000_000.0 @@ -133,6 +143,7 @@ def __init__(self, model_args: Qwen3ModelArgs): self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.head_dim self.scaling = self.head_dim**-0.5 + self.use_flex_attn = getattr(model_args, "use_flex_attn", False) # RMSNorm added here to the here to include the q-k norm # This is one of the main differences between Llama3 and Qwen3 @@ -155,7 +166,11 @@ def __init__(self, model_args: Qwen3ModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -170,6 +185,7 @@ def forward( self, x: torch.Tensor, rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. 
@@ -210,7 +226,12 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv, scale=self.scaling) + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -308,6 +329,7 @@ def forward( self, x: torch.Tensor, rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -320,7 +342,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - x = x + self.attention(self.attention_norm(x), rope_cache) + x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) @@ -423,9 +445,31 @@ def _precompute_rope_cache(self) -> torch.Tensor: self.model_args.rope_theta, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -449,7 +493,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.rope_cache) + h = layer(h, self.rope_cache, attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index e76370e5..df916054 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -125,6 +125,13 @@ def parallelize_deepseekv3( ): experts_shard_dim = 1 + # when EP is enable, the routed experts' gradient reduction is done over + # dp_mod_ep_mesh instead of whole dp_mesh. + # we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh + # to be consistent with data. 
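Both the Llama 4 and Qwen3 `get_attention_masks` implementations above reduce to the same pattern: compose mask mods with `and_masks` and materialize a `BlockMask` via `create_attention_mask`. A small standalone sketch of that pattern for the "block_causal" case, using a toy token batch and an assumed EOS id; treat it as illustrative, since the cost of the mask build depends on the backend.

```python
# Hedged sketch of block-causal mask construction, mirroring get_attention_masks.
import torch
from torch.nn.attention.flex_attention import and_masks

from torchtitan.models.attention import (
    create_attention_mask,
    get_causal_mask_mod,
    get_document_mask_mod,
)

eos_id = 2                                   # assumed EOS token id
batch = torch.randint(3, 100, (2, 128))      # toy batch: B=2, S=128
batch[:, 63] = eos_id                        # pretend each sample packs two documents

mask_mod = and_masks(
    get_causal_mask_mod(),
    get_document_mask_mod(batch, eos_id),    # block attention across packed documents
)
# For "block_causal" the mask is per-sample, so B = batch.shape[0]; for "causal" B = 1.
block_mask = create_attention_mask(mask_mod, batch.shape[0], None, 128, 128)
```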
+ # TODO (ruisizhang123): update the logic following the link below instead + # of using a reduction_divide_factor + # https://github.com/pytorch/torchtitan/pull/1803#discussion_r2415190883 transformer_block.moe.experts = data_parallel( transformer_block.moe.experts, dp_mod_ep_mesh, @@ -132,11 +139,8 @@ def parallelize_deepseekv3( ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, shard_dim=experts_shard_dim, + reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) - # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp - # transformer_block.moe.experts.set_gradient_divide_factor( - # parallel_dims.fsdp_gradient_divide_factor, - # ) model = data_parallel( model, diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index cf3f1dd4..3e2775b7 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -85,6 +85,7 @@ def parallelize_llama( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # apply data parallel diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 8cb2a447..9ca74601 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -49,6 +49,37 @@ class MixedPrecisionPolicy: reduce_dtype: Optional[torch.dtype] = None +class _ScaledPartial(Partial): + # A subclass of Partial placement that allows user to perform reduction with a custom + # factor (reduction_divide_factor) other than the default world size. + def __init__( + self, + reduction_divide_factor: float, + ): + self.reduction_divide_factor = reduction_divide_factor + super().__init__(reduce_op="sum") + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # for all_reduce in DDP + tensor.div_(self.reduction_divide_factor) + reduced = super()._reduce_value(tensor, mesh, mesh_dim) + return reduced + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # for reduce_scatter in FSDP + tensor.div_(self.reduction_divide_factor) + reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec) + return reduced + + def _distribute_dtensor( tensor: DTensor, device_mesh: DeviceMesh, @@ -192,18 +223,24 @@ def __init__( mode, regional_ac, mp_policy, + reduction_divide_factor, ): super().__init__() self.device_mesh = device_mesh self.param_sharding = param_sharding self.mode = mode self.compute_placements = [Replicate()] * self.device_mesh.ndim - self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim + self.grad_placements = [ + _ScaledPartial( + reduction_divide_factor=reduction_divide_factor, + ) + if reduction_divide_factor is not None + else Partial(reduce_op="avg") + ] * self.device_mesh.ndim self.regional_ac = regional_ac mp_policy = mp_policy or MixedPrecisionPolicy() self.param_dtype = mp_policy.param_dtype self.reduce_dtype = mp_policy.reduce_dtype - self.ep_mesh_name, self.tp_mesh_name = "ep", "tp" def replicate_compute(self, x): # data parallel runtime replicate parameters and do local compute @@ -286,6 +323,7 @@ def data_parallel( ac_mode: str = "none", mp_policy: Optional[MixedPrecisionPolicy] = None, shard_dim: int = 0, + 
reduction_divide_factor: Optional[float] = None, ): if mode == "replicate": param_sharding = (Replicate(),) @@ -348,6 +386,7 @@ def data_parallel( mode, regional_ac, mp_policy=mp_policy, + reduction_divide_factor=reduction_divide_factor, ), ) return model diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 7d62a8ed..19452ac1 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import asdict, replace +from dataclasses import fields +from typing import Any from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers @@ -27,9 +28,14 @@ ] +def _get_dict(obj) -> dict[str, Any]: + """Convert dataclass to dict, preserving nested dataclasses (unlike asdict).""" + return {field.name: getattr(obj, field.name) for field in fields(obj)} + + llama3_siglip2_configs = { "debugmodel": Llama3Siglip2ModelArgs( - **asdict(replace(llama3_configs["debugmodel"], vocab_size=2048)), + **_get_dict(llama3_configs["debugmodel_flex_attn"]), encoder=Siglip2ModelArgs( dim=128, ffn_dim=256, @@ -42,7 +48,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="llama3-siglip2", model_cls=Llama3Siglip2Transformer, model_args=llama3_siglip2_configs, parallelize_fn=parallelize_vlm, diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py new file mode 100644 index 00000000..bba51f28 --- /dev/null +++ b/torchtitan/experiments/vlm/infra/loss.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch import distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.components.ft.manager import FTManager +from torchtitan.config.job_config import JobConfig +from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.tools.logging import logger + + +IGNORE_INDEX = -100 # Pytorch's default for F.cross_entropy + + +# WARNING: currently this does not take into account gradient accumulation +# and the gradient can still be biased toward grad accum step with less valid tokens +# See: https://github.com/pytorch/torchtitan/issues/1842 +def token_imbalance_ce_loss( + pred: torch.Tensor, + labels: torch.Tensor, + token_mesh: DeviceMesh, + ft_pg: dist.ProcessGroup | None, +) -> torch.Tensor: + """ + Cross‑entropy loss that is *robust* to varying numbers of valid tokens across ranks. + + In a typical distributed training setup (data parallel + sequence parallel), + each rank computes the loss over **only its local tokens** and returns an + *average* over those tokens: + + Afterwards, when Fully‑Sharded Data Parallel (FSDP) averages the gradients + across all ranks, the resulting update is equivalent to a **global sample + average** *only if every rank contains the same number of tokens*. + In practice that assumption is violated for many workloads: + - Sequences are padded to a fixed length -> some ranks see fewer real tokens. 
+ - SFT finetuning where user's queries tokens are masked out. + - Vision encoders often injects a large number of “ignored” + tokens as context that are not trained with text tokens' loss. + + This function fixes the issue by **scaling the sum-of-loss** with the *average* + number of non‑ignored tokens per rank, computed via an all-reduce over + `token_mesh`. The returned scalar therefore represents the loss that would + be obtained if every token in the entire distributed batch contributed with + equal weight to the global gradient, regardless of how many padded or + ignored tokens each rank contains. + + Parameters + ---------- + pred : torch.Tensor + labels : torch.Tensor + token_mesh : DeviceMesh + A device mesh that contains all ranks participating in this training step's + loss computation. The function performs an ``all_reduce`` (mean) over the + `num_tokens` tensor of a rank across this mesh. + ft_pg: dist.ProcessGroup | None + Optional pg for Fault Tolerance training. + + Returns + ------- + torch.Tensor + A scalar loss tensor, ready for ``backward()`` and FSDP all-reduce mean + + Notes + ----- + * The function internally uses :func:`torch.nn.functional.cross_entropy` + with ``reduction="sum"`` so that each token contributes exactly once to + the numerator. The denominator is the **average** number of valid tokens + per rank, not the local count. + * If a rank contains no valid tokens (i.e., all labels are ``IGNORE_INDEX``), + its contribution to the sum is zero and its `num_tokens` becomes zero. + In that case the mean across ranks will still be well‑defined as long as + at least one rank has non‑zero token count. + """ + sum_loss = torch.nn.functional.cross_entropy( + pred.flatten(0, 1).float(), + labels.flatten(0, 1), + reduction="sum", + ignore_index=IGNORE_INDEX, + ) + num_tokens = (labels != IGNORE_INDEX).sum() + avg_num_tokens_per_rank = funcol.all_reduce( + num_tokens, reduceOp=c10d.ReduceOp.AVG.name, group=token_mesh + ) + if ft_pg is not None: + avg_num_tokens_per_rank = funcol.all_reduce( + avg_num_tokens_per_rank, reduceOp=c10d.ReduceOp.AVG.name, group=ft_pg + ) + return sum_loss / avg_num_tokens_per_rank + + +def build_token_imbalance_ce_loss( + job_config: JobConfig, parallel_dims: ParallelDims, ft_manager: FTManager, **kwargs +): + del kwargs # delete any unused arguments + # NOTE: The device mesh where the input tokens w/ shape BSD can be sliced: + # DP split the batch dim B + # CP split the sequence dim S + token_mesh = parallel_dims.world_mesh["dp_cp"] + ft_pg = ft_manager.loss_sync_pg + loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg) + if job_config.compile.enable and "loss" in job_config.compile.components: + logger.info("Compiling the loss function with torch.compile") + loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend) + return loss_fn diff --git a/torchtitan/experiments/vlm/model/model.py b/torchtitan/experiments/vlm/model/model.py index 71c8a739..712cd805 100644 --- a/torchtitan/experiments/vlm/model/model.py +++ b/torchtitan/experiments/vlm/model/model.py @@ -7,8 +7,11 @@ import einops as E import torch from torch import nn +from torch.nn.attention.flex_attention import BlockMask +from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.llama3 import Transformer as Llama3 +from torchtitan.protocols.model import AttentionMasksType from ..datasets.mm_datasets import SpecialTokens @@ -71,28 +74,49 @@ def init_weights(self, buffer_device=None): if self.projector is not None: 
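As a concrete illustration of the imbalance `token_imbalance_ce_loss` corrects: with two DP ranks holding 10 and 2 valid tokens, per-rank token means followed by gradient averaging weight a token on the small rank five times more heavily than one on the large rank, while dividing each rank's summed loss by the average valid-token count recovers the global per-token mean. A toy, single-process arithmetic check (no collectives; the all-reduce AVG is replaced by a plain average):

```python
# Toy illustration of the weighting difference; pure arithmetic, no distributed setup.
import torch

# Per-token losses on two hypothetical DP ranks (10 vs. 2 valid tokens).
rank_losses = [torch.full((10,), 1.0), torch.full((2,), 5.0)]

# (a) per-rank mean + gradient averaging across ranks (the biased baseline)
biased = torch.stack([l.mean() for l in rank_losses]).mean()

# (b) per-rank sum divided by the average token count, as in token_imbalance_ce_loss
avg_tokens = sum(len(l) for l in rank_losses) / len(rank_losses)  # stand-in for all-reduce AVG
balanced = torch.stack([l.sum() / avg_tokens for l in rank_losses]).mean()

# global per-token mean, the quantity we actually want
target = torch.cat(rank_losses).mean()

print(biased.item(), balanced.item(), target.item())  # 3.0, ~1.667, ~1.667
```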
self.projector.init_weights() + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + masks = super().get_attention_masks(input_batch, tokenizer, extra_inputs) + assert isinstance(masks, BlockMask) + if self.encoder is not None: + encoder_masks = self.encoder.get_attention_masks( + input_batch, tokenizer, extra_inputs + ) + assert isinstance(encoder_masks, BlockMask) + return {"llama3_masks": masks, "encoder_masks": encoder_masks} + def forward( self, tokens: torch.Tensor, pixel_values: torch.Tensor, grid_thw: torch.Tensor, special_tokens: SpecialTokens, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens if self.encoder is not None: + assert ( + attention_masks is not None + ), "encoder only allows FlexAttention, so the llama3 must use FlexAttention as well." grid_hw = grid_thw[:, :, 1:] # Siglip2 only support image hw pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all") - i_NLD = self.encoder(pixel_values, pixel_masks, grid_hw) + i_NLD = self.encoder( + pixel_values, pixel_masks, grid_hw, attention_masks["encoder_masks"] + ) i_NLD = self.projector(i_NLD) h_BSD = _scatter_img_tokens( h_BSD, tokens, i_NLD, pixel_masks, special_tokens.img_id ) for layer in self.layers.values(): - h_BSD = layer(h_BSD, self.freqs_cis) + h_BSD = layer(h_BSD, self.freqs_cis, attention_masks["llama3_masks"]) h_BSD = self.norm(h_BSD) if self.norm else h_BSD output = self.output(h_BSD) if self.output else h_BSD diff --git a/torchtitan/experiments/vlm/model/siglip2.py b/torchtitan/experiments/vlm/model/siglip2.py index a1183f7c..69278350 100644 --- a/torchtitan/experiments/vlm/model/siglip2.py +++ b/torchtitan/experiments/vlm/model/siglip2.py @@ -8,8 +8,16 @@ import torch import torch.nn.functional as F from torch import nn +from torch.nn.attention.flex_attention import and_masks, BlockMask -from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, +) +from torchtitan.protocols.model import AttentionMasksType from .args import Siglip2ModelArgs @@ -125,11 +133,9 @@ def __init__(self, args: Siglip2ModelArgs): self.v_proj = nn.Linear(self.dim, self.dim) self.out_proj = nn.Linear(self.dim, self.dim) - self.attn = build_attention( - use_flex_attn=True, attn_mask_type=args.attn_mask_type - ) + self.inner_attention = FlexAttentionWrapper() - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, attention_masks: AttentionMasksType): xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) # Use self.head_dim instead of `n_heads` to infer the actual @@ -139,7 +145,8 @@ def forward(self, x: torch.Tensor): xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim) xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim) - output = self.attn(xq, xk, xv) + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) output = E.rearrange(output, "b h l d -> b l (h d)").contiguous() return self.out_proj(output) @@ -174,8 +181,10 @@ def __init__(self, args: Siglip2ModelArgs): 
self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) self.mlp = FeedForward(args) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.self_attn(self.layer_norm1(x)) + def forward( + self, x: torch.Tensor, attention_masks: AttentionMasksType + ) -> torch.Tensor: + x = x + self.self_attn(self.layer_norm1(x), attention_masks) x = x + self.mlp(self.layer_norm2(x)) return x @@ -198,18 +207,46 @@ def __init__(self, args: Siglip2ModelArgs): ) self.post_layernorm = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + + # TODO: this is duplicated in the main model forward. + # TODO: is this really required? Can we call this `get_attention_masks` + # inside the main model forward? At that time PP should already split the + # grid_thw correctly. + grid_hw = extra_inputs["grid_thw"][:, :, 1:] # Siglip2 only support image hw + pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all") + + mask_mods = [get_causal_mask_mod()] + match self.args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = pixel_masks.shape[0] + mask_mods.append(get_document_mask_mod(pixel_masks, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, pixel_masks.shape[1], pixel_masks.shape[1] + ) + def forward( self, pixel_values_NLD: torch.FloatTensor, pixel_masks_NL: torch.BoolTensor, grid_hw: torch.LongTensor, + attention_masks: AttentionMasksType, ): - init_attention_mask(pixel_masks_NL, eos_id=self.eos_id) - h = self.embeddings(pixel_values_NLD, grid_hw) for layer in self.layers.values(): - h = layer(h) + h = layer(h, attention_masks) h = self.post_layernorm(h) return h diff --git a/torchtitan/experiments/vlm/train_configs/debug_model.toml b/torchtitan/experiments/vlm/train_configs/debug_model.toml index c4f97463..91b7c0c3 100644 --- a/torchtitan/experiments/vlm/train_configs/debug_model.toml +++ b/torchtitan/experiments/vlm/train_configs/debug_model.toml @@ -23,7 +23,7 @@ save_tb_folder = "tb" enable_wandb = false [model] -name = "llama3-siglip2" +name = "vlm" flavor = "debugmodel" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "tests/assets/tokenizer" diff --git a/torchtitan/generate.py b/torchtitan/generate.py index 475e9467..1552e24f 100644 --- a/torchtitan/generate.py +++ b/torchtitan/generate.py @@ -11,50 +11,132 @@ import numpy as np import torch +torch.set_printoptions(threshold=10_000) from torch.distributed.elastic.multiprocessing.errors import record from transformers import AutoProcessor from PIL import Image import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager -from torchtitan.components.tokenizer import HuggingFaceTokenizer from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger +# --- Generation utilities from scripts/generate/_generation.py --- + +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1, generator=rng) + 
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + + +def logits_to_probs( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def generate_next_token( + model, + x: torch.Tensor, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, + **model_kwargs, +) -> torch.Tensor: + input_dict = { + "input_ids": x, + **model_kwargs, + } + logits = model(**input_dict) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) + return next_token + + +@torch.no_grad() +def _generate_sequence( + model, + input_ids: torch.Tensor, + max_new_tokens: int, + temperature: float = 1.0, + pixel_values: torch.Tensor | None = None, + patch_attention_mask: torch.BoolTensor | None = None, + top_k: Optional[int] = None, + seed: Optional[int] = None, +) -> torch.Tensor: + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + + generated_tokens = input_ids.clone() + + for _ in range(max_new_tokens): + next_token = generate_next_token( + model, + x=generated_tokens, + temperature=temperature, + top_k=top_k, + rng=rng, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) + + return generated_tokens + +# --- End of generation utilities --- + class Generator: """Generator class for SmolVLM model inference.""" - + def __init__(self, job_config: JobConfig): torch._C._log_api_usage_once("torchtitan.generate") - + self.job_config = job_config - + logger.info(f"Starting generation: {job_config.job.description}") - + if job_config.experimental.custom_import: importlib.import_module(job_config.experimental.custom_import) - + if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") - + device_module, device_type = utils.device_module, utils.device_type self.device = torch.device(f"{device_type}:{int(os.environ.get('LOCAL_RANK', 0))}") device_module.set_device(self.device) - - # Initialize distributed - dist_utils.init_distributed( - job_config.comm, - enable_cpu_backend=False, - base_folder=job_config.job.dump_folder, - ) - + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if world_size > 1: + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=False, + base_folder=job_config.job.dump_folder, + ) + parallelism_config = job_config.parallelism - self.parallel_dims = parallel_dims = ParallelDims( + self.parallel_dims = ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, dp_replicate=parallelism_config.data_parallel_replicate_degree, cp=parallelism_config.context_parallel_degree, @@ -64,66 +146,49 @@ def __init__(self, job_config: JobConfig): etp=parallelism_config.expert_tensor_parallel_degree, world_size=world_size, ) - - world_mesh = parallel_dims.world_mesh - - # Set random seed + dist_utils.set_determinism( - world_mesh, + self.parallel_dims.world_mesh if world_size > 1 else None, self.device, job_config.training.seed, 
deterministic=False, ) - + self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - - # Build tokenizer - self.tokenizer = ( - self.train_spec.build_tokenizer_fn(job_config) - if self.train_spec.build_tokenizer_fn is not None - else None - ) - - # Build model + + self.tokenizer = self.train_spec.build_tokenizer_fn(job_config) + model_args = self.train_spec.model_args[job_config.model.flavor] model_args.update_from_config(job_config) self.model_args = model_args - - logger.info( - f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" - ) - + with ( torch.device("meta"), utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), ): model = self.train_spec.model_cls(model_args) - - # Build model converters (e.g., for float8) - model_converters = build_model_converters(job_config, parallel_dims) + + model_converters = build_model_converters(job_config, self.parallel_dims) model_converters.convert(model) - - # Apply parallelism - if parallel_dims.pp_enabled: + + if self.parallel_dims.pp_enabled: raise NotImplementedError("Pipeline parallelism not supported for generation") else: - model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) - - # Move to device and initialize + model = self.train_spec.parallelize_fn(model, self.parallel_dims, job_config) + init_device = self.device.type model.to_empty(device=init_device) with torch.no_grad(): model.init_weights() model.eval() - + self.model_parts = [model] - - # Setup checkpoint manager for loading + self.checkpointer = CheckpointManager( - dataloader=None, # No dataloader needed for generation + dataloader=None, model_parts=self.model_parts, - optimizers=None, # No optimizer needed for generation - lr_schedulers=None, # No lr_scheduler needed for generation + optimizers=None, + lr_schedulers=None, states={}, checkpoint_config=job_config.checkpoint, sd_adapter=( @@ -134,19 +199,15 @@ def __init__(self, job_config: JobConfig): else None ), base_folder=job_config.job.dump_folder, - ft_manager=None, # No fault tolerance for generation + ft_manager=None, ) - - # Load checkpoint + self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Loaded checkpoint from step {job_config.checkpoint.load_step}") - - # Setup HF processor for image processing - #processor_path = getattr(model_args, 'tokenizer_name',) - self.processor = AutoProcessor.from_pretrained('HuggingFaceTB/SmolVLM2-256M-Video-Instruct') + + self.processor = AutoProcessor.from_pretrained(job_config.model.hf_assets_path) self.image_processor = self.processor.image_processor - - # Load chat template + template_path = "torchtitan/vlr/smolvlm/datasets/template.jinja" if os.path.exists(template_path): with open(template_path, 'r') as f: @@ -154,71 +215,41 @@ def __init__(self, job_config: JobConfig): else: logger.warning(f"Chat template not found at {template_path}, using default") self.chat_template = None - - # Setup generation parameters - self.max_new_tokens = getattr(job_config, 'max_new_tokens', 256) - self.temperature = getattr(job_config, 'temperature', 0.7) - self.top_p = getattr(job_config, 'top_p', 0.9) - self.top_k = getattr(job_config, 'top_k', 50) - + + self.max_new_tokens = getattr(job_config, 'max_new_tokens', 16) + self.temperature = getattr(job_config, 'temperature', 0) + self.top_k = getattr(job_config, 'top_k', None) + logger.info("Generator initialized successfully") - + @torch.no_grad() def generate( self, - messages: List[Dict[str, Any]], + messages: List[Dict[str, Any]] = None, images: 
Optional[List[Image.Image]] = None, max_new_tokens: Optional[int] = None, temperature: Optional[float] = None, - top_p: Optional[float] = None, top_k: Optional[int] = None, - do_sample: bool = True, + seed: Optional[int] = None, ) -> str: - """Generate text from messages and optional images. - - Args: - messages: List of message dictionaries with 'role' and 'content' - images: Optional list of PIL images - max_new_tokens: Maximum number of tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - top_k: Top-k sampling parameter - do_sample: Whether to use sampling or greedy decoding - - Returns: - Generated text string - """ max_new_tokens = max_new_tokens or self.max_new_tokens temperature = temperature or self.temperature - top_p = top_p or self.top_p top_k = top_k or self.top_k - + model = self.model_parts[0] model.eval() - - # Process images if provided + pixel_values = None patch_attention_mask = None - + if images: - # Process images using HF processor - vision_inputs = self.image_processor(images) - pixel_values = torch.tensor( - np.array(vision_inputs['pixel_values']) - ).to(self.device, dtype=torch.bfloat16) - + vision_inputs = self.image_processor(images, return_tensors="pt") + pixel_values = vision_inputs['pixel_values'].to(self.device, dtype=torch.bfloat16) + if 'pixel_attention_mask' in vision_inputs: - patch_attention_mask = torch.tensor( - vision_inputs['pixel_attention_mask'] - ).to(self.device) - - # Handle batch dimension - if pixel_values.dim() == 4: - pixel_values = pixel_values.unsqueeze(0) - if patch_attention_mask is not None and patch_attention_mask.dim() == 3: - patch_attention_mask = patch_attention_mask.unsqueeze(0) - - # Tokenize input + patch_attention_mask = vision_inputs['pixel_attention_mask'].to(self.device) + + """ if self.chat_template: input_ids = self.tokenizer.apply_chat_template( messages, @@ -228,167 +259,116 @@ def generate( return_tensors="pt", ) else: - # Fallback to default chat template input_ids = self.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", ) - + if isinstance(input_ids, dict): input_ids = input_ids["input_ids"] - + input_ids = input_ids.to(self.device) - - # Setup generation context (compile if enabled) - generate_fn = self._generate_tokens - if self.job_config.compile.enable and "model" in self.job_config.compile.components: - generate_fn = torch.compile(generate_fn, mode="reduce-overhead") - - # Generate tokens - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - output_ids = generate_fn( + """ + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": messages}, + {"type": "image", "image": images}, + ] + }, + ] + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Caption this image"}, + {"type": "image", "image": "../cat.jpg"}, + ] + }, + ] + + + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + return_dict=True, + return_tensors="pt", + ) + + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ).to(self.device, dtype=torch.bfloat16) + + input_ids = inputs['input_ids'] + pixel_values = inputs.get('pixel_values', None) + patch_attention_mask = inputs.get('pixel_attention_mask', None) + + + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + logits = model( + input_ids=input_ids, + pixel_values=pixel_values, + 
patch_attention_mask=patch_attention_mask, + ) + + print(logits) + + """ + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + output_ids = _generate_sequence( model=model, input_ids=input_ids, pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, - max_new_tokens=max_new_tokens, + max_new_tokens=64, temperature=temperature, - top_p=top_p, top_k=top_k, - do_sample=do_sample, + seed=seed, ) - - # Decode output - generated_ids = output_ids[0, input_ids.shape[1]:] - generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) - - return generated_text - - def _generate_tokens( - self, - model: torch.nn.Module, - input_ids: torch.Tensor, - pixel_values: Optional[torch.Tensor], - patch_attention_mask: Optional[torch.Tensor], - max_new_tokens: int, - temperature: float, - top_p: float, - top_k: int, - do_sample: bool, - ) -> torch.Tensor: - """Core generation loop.""" - - batch_size = input_ids.shape[0] - generated_ids = input_ids.clone() - - # Cache for key-value pairs (if using KV cache in the future) - past_key_values = None - - for _ in range(max_new_tokens): - # Forward pass - with torch.no_grad(): - # Prepare input dict - input_dict = { - "input_ids": generated_ids, - "eos_id": self.tokenizer.eos_token, - } - - if pixel_values is not None: - input_dict["pixel_values"] = pixel_values - - if patch_attention_mask is not None: - input_dict["patch_attention_mask"] = patch_attention_mask - - # Get model output - logits = model(**input_dict) - - # Get next token logits - next_token_logits = logits[:, -1, :] - - # Apply temperature - if temperature > 0: - next_token_logits = next_token_logits / temperature - - # Sample or greedy decode - if do_sample: - # Apply top-k filtering - if top_k > 0: - indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] - next_token_logits[indices_to_remove] = -float('inf') - - # Apply top-p filtering - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - next_token_logits[indices_to_remove] = -float('inf') - - # Sample - probs = torch.softmax(next_token_logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - else: - # Greedy decoding - next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) - - # Append to generated sequence - generated_ids = torch.cat([generated_ids, next_token], dim=1) - - # Check for EOS token - if (next_token == self.tokenizer.eos_token): - break - - return generated_ids - + + print(output_ids.v) + generated_text = self.processor.batch_decode(output_ids, skip_special_tokens=True) + + print(generated_text) + """ + def interactive_generate(self): """Interactive generation mode for testing.""" logger.info("Starting interactive generation mode. 
Type 'quit' to exit.") - + while True: try: user_input = input("\nEnter your prompt (or 'quit' to exit): ").strip() - + if user_input.lower() == 'quit': break - - # Check if user wants to include an image + image_path = input("Enter image path (or press Enter to skip): ").strip() - + images = None if image_path and os.path.exists(image_path): image = Image.open(image_path).convert('RGB') - # Resize to expected size - image = image.resize((512, 512)) - images = [image] logger.info(f"Loaded image from {image_path}") elif image_path: logger.warning(f"Image path {image_path} not found, proceeding without image") - - # Create message format - messages = [ - { - "user": user_input, - "assistant": "" # Will be filled by generation - } - ] - + logger.info("Generating response...") start_time = time.perf_counter() - - response = self.generate(messages, images=images) - + + response = self.generate(user_input, images=image) + generation_time = time.perf_counter() - start_time - logger.info(f"Generation completed in {generation_time:.2f}s") - - print(f"\nGenerated response:\n{response}") - + except KeyboardInterrupt: logger.info("\nInterrupted by user") break @@ -396,49 +376,7 @@ def interactive_generate(self): logger.error(f"Error during generation: {e}") import traceback traceback.print_exc() - - def batch_generate(self, input_file: str, output_file: str): - """Generate responses for a batch of inputs from a file. - - Args: - input_file: Path to JSON file with inputs - output_file: Path to save outputs - """ - import json - - logger.info(f"Loading inputs from {input_file}") - - with open(input_file, 'r') as f: - inputs = json.load(f) - - results = [] - for i, item in enumerate(inputs): - logger.info(f"Processing item {i+1}/{len(inputs)}") - - messages = item.get('messages', []) - image_paths = item.get('images', []) - - # Load images if provided - images = [] - for path in image_paths: - if os.path.exists(path): - image = Image.open(path).convert('RGB').resize((512, 512)) - images.append(image) - - # Generate response - response = self.generate(messages, images=images if images else None) - - results.append({ - 'input': item, - 'output': response - }) - - # Save results - with open(output_file, 'w') as f: - json.dump(results, f, indent=2) - - logger.info(f"Results saved to {output_file}") - + def close(self): """Cleanup resources.""" if hasattr(self, 'checkpointer'): @@ -450,36 +388,15 @@ def close(self): def main(): """Main entry point for generation.""" init_logger() - - # Parse configuration + config_manager = ConfigManager() config = config_manager.parse_args() - + generator = None try: - # Initialize generator generator = Generator(config) - - # Check for generation mode from config or command line - generation_mode = getattr(config, 'generation_mode', 'interactive') - - if generation_mode == 'interactive': - generator.interactive_generate() - elif generation_mode == 'batch': - input_file = getattr(config, 'input_file', 'inputs.json') - output_file = getattr(config, 'output_file', 'outputs.json') - generator.batch_generate(input_file, output_file) - else: - # Single generation example - messages = [ - { - "user": "What is the capital of France?", - "assistant": "" - } - ] - response = generator.generate(messages) - logger.info(f"Generated: {response}") - + generator.generate() + except Exception as e: logger.error(f"Error during generation: {e}") if generator: @@ -488,7 +405,8 @@ def main(): else: if generator: generator.close() - torch.distributed.destroy_process_group() + if 
torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() logger.info("Process group destroyed") diff --git a/torchtitan/generate_llama3.py b/torchtitan/generate_llama3.py new file mode 100644 index 00000000..eb08813b --- /dev/null +++ b/torchtitan/generate_llama3.py @@ -0,0 +1,298 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os +import time +from typing import Optional, List, Dict, Any + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +import torchtitan.protocols.train_spec as train_spec_module +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.protocols.model_converter import build_model_converters +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger + +# --- Generation utilities from scripts/generate/_generation.py --- + +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + + +def logits_to_probs( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def generate_next_token( + model, + x: torch.Tensor, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + # The model forward pass in torchtitan expects a `tokens` argument. 
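+    # Sampling note: `multinomial_sample_one` above uses the exponential-race
+    # (Gumbel-max style) trick, which samples the same distribution as
+    # `torch.multinomial(probs, 1)` while avoiding a device-side sync.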
+ logits = model(tokens=x) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) + return next_token + + +@torch.no_grad() +def _generate_sequence( + model, + input_ids: torch.Tensor, + *, + max_new_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + seed: Optional[int] = None, +) -> torch.Tensor: + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + + generated_tokens = input_ids.clone() + + for _ in range(max_new_tokens): + next_token = generate_next_token( + model, + x=generated_tokens, + temperature=temperature, + top_k=top_k, + rng=rng, + ) + + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) + + return generated_tokens + +# --- End of generation utilities --- + + +class Generator: + """Generator class for Llama3 model inference.""" + + def __init__(self, job_config: JobConfig): + torch._C._log_api_usage_once("torchtitan.generate") + + self.job_config = job_config + + logger.info(f"Starting generation: {job_config.job.description}") + + if job_config.experimental.custom_import: + importlib.import_module(job_config.experimental.custom_import) + + if job_config.job.print_args: + logger.info(f"Running with args: {job_config.to_dict()}") + + device_module, device_type = utils.device_module, utils.device_type + self.device = torch.device(f"{device_type}:{int(os.environ.get('LOCAL_RANK', 0))}") + device_module.set_device(self.device) + + # For generation, we usually use a single process or TP. + # We will not initialize the full distributed setup unless necessary. + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if world_size > 1: + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=False, + base_folder=job_config.job.dump_folder, + ) + + parallelism_config = job_config.parallelism + self.parallel_dims = ParallelDims( + dp_shard=parallelism_config.data_parallel_shard_degree, + dp_replicate=parallelism_config.data_parallel_replicate_degree, + cp=parallelism_config.context_parallel_degree, + tp=parallelism_config.tensor_parallel_degree, + pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, + world_size=world_size, + ) + + dist_utils.set_determinism( + self.parallel_dims.world_mesh if world_size > 1 else None, + self.device, + job_config.training.seed, + deterministic=False, + ) + + self.train_spec = train_spec_module.get_train_spec(job_config.model.name) + + self.tokenizer = self.train_spec.build_tokenizer_fn(job_config) + + model_args = self.train_spec.model_args[job_config.model.flavor] + model_args.update_from_config(job_config) + self.model_args = model_args + + with ( + torch.device("meta"), + utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), + ): + model = self.train_spec.model_cls(model_args) + + model_converters = build_model_converters(job_config, self.parallel_dims) + model_converters.convert(model) + + if self.parallel_dims.pp_enabled: + raise NotImplementedError("Pipeline parallelism not supported for generation") + else: + model = self.train_spec.parallelize_fn(model, self.parallel_dims, job_config) + + init_device = self.device.type + model.to_empty(device=init_device) + with torch.no_grad(): + model.init_weights() + model.eval() + + self.model_parts = [model] + + self.checkpointer = 
CheckpointManager( + dataloader=None, + model_parts=self.model_parts, + optimizers=None, + lr_schedulers=None, + states={}, + checkpoint_config=job_config.checkpoint, + sd_adapter=( + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) + if self.train_spec.state_dict_adapter + else None + ), + base_folder=job_config.job.dump_folder, + ft_manager=None, + ) + + self.checkpointer.load(step=job_config.checkpoint.load_step) + logger.info(f"Loaded checkpoint from step {job_config.checkpoint.load_step}") + + self.max_new_tokens = getattr(job_config, 'max_new_tokens', 256) + self.temperature = getattr(job_config, 'temperature', 0.7) + self.top_k = getattr(job_config, 'top_k', 50) + + logger.info("Generator initialized successfully") + + @torch.no_grad() + def generate( + self, + prompts: List[str], + max_new_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + seed: Optional[int] = None, + ) -> List[str]: + max_new_tokens = max_new_tokens or self.max_new_tokens + temperature = temperature or self.temperature + top_k = top_k or self.top_k + + model = self.model_parts[0] + model.eval() + + # For simplicity, this example handles one prompt at a time. + # Batching can be added for efficiency. + generated_texts = [] + for prompt in prompts: + input_ids = self.tokenizer.encode(prompt, add_bos=True, add_eos=False) + input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device) + + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + output_ids = _generate_sequence( + model=model, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + seed=seed, + ) + + generated_ids = output_ids[0, input_ids.shape[0]:] + generated_text = self.tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True) + generated_texts.append(generated_text) + + return generated_texts + + def close(self): + """Cleanup resources.""" + if hasattr(self, 'checkpointer'): + self.checkpointer.close() + logger.info("Generator closed") + + +@record +def main(): + """Main entry point for generation.""" + init_logger() + + # Parse configuration + config_manager = ConfigManager() + config = config_manager.parse_args() + + generator = None + try: + # Initialize generator + generator = Generator(config) + + prompts = [ + "What is the meaning of life?", + "Translate 'hello world' to French.", + ] + + logger.info(f"Generating for prompts: {prompts}") + start_time = time.perf_counter() + + responses = generator.generate(prompts) + + generation_time = time.perf_counter() - start_time + logger.info(f"Generation completed in {generation_time:.2f}s") + + for prompt, response in zip(prompts, responses): + print("-" * 20) + print(f"Prompt: {prompt}") + print(f"Response: {response}") + print("-" * 20) + + except Exception as e: + logger.error(f"Error during generation: {e}") + if generator: + generator.close() + raise + else: + if generator: + generator.close() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + logger.info("Process group destroyed") + + +if __name__ == "__main__": + main() diff --git a/torchtitan/generate_simple.py b/torchtitan/generate_simple.py new file mode 100644 index 00000000..f21021f9 --- /dev/null +++ b/torchtitan/generate_simple.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import time +from pathlib import Path +from typing import Optional + +import torch +import numpy as np +from PIL import Image +from transformers import AutoProcessor + +from torchtitan.config import JobConfig, ConfigManager, TORCH_DTYPE_MAP +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.components.tokenizer import HuggingFaceTokenizer +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.distributed import ParallelDims, utils as dist_utils + +# Import SmolVLM specific components +from torchtitan.vlr.smolvlm.model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs +from torchtitan.vlr.smolvlm.model.model import Llama3Siglip2Transformer +from torchtitan.vlr.smolvlm.model.state_dict_adapter import SmolVLMStateDictAdapter + + +class SimpleGenerator: + """Barebones generator for debugging using CheckpointManager.""" + + def __init__(self, job_config: JobConfig): + self.job_config = job_config + + # Setup device + device_module, device_type = utils.device_module, utils.device_type + self.device = torch.device(f"{device_type}:{int(os.environ.get('LOCAL_RANK', 0))}") + device_module.set_device(self.device) + + logger.info(f"Device: {self.device}") + + # Init distributed (needed for checkpoint loading) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if world_size > 1 or int(os.environ.get("RANK", 0)) >= 0: + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=False, + base_folder=job_config.job.dump_folder, + ) + + # Setup parallel dims (minimal - no parallelism for inference) + self.parallel_dims = ParallelDims( + dp_shard=1, + dp_replicate=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=world_size, + ) + + # Load tokenizer using model's hf_assets_path + tokenizer_path = job_config.model.hf_assets_path + self.tokenizer = HuggingFaceTokenizer(tokenizer_path) + self.tokenizer.image_id = job_config.special_tokens.img_id + + logger.info(f"Tokenizer loaded from: {tokenizer_path}") + logger.info(f"Vocab size: {len(self.tokenizer)}") + logger.info(f"Special tokens - BOS: {self.tokenizer.bos_id}, EOS: {self.tokenizer.eos_id}, PAD: {self.tokenizer.pad_id}") + logger.info(f"Image token ID: {self.tokenizer.image_id}") + + # Load image processor + processor = AutoProcessor.from_pretrained(tokenizer_path) + self.image_processor = processor.image_processor + + # Load chat template + template_path = Path("torchtitan/vlr/smolvlm/datasets/template.jinja") + if template_path.exists(): + with open(template_path, 'r') as f: + self.chat_template = f.read() + logger.info("Chat template loaded") + else: + logger.warning(f"Template not found at {template_path}") + self.chat_template = None + + # Build model + self.model_args = self._get_model_args() + self.model = self._build_model() + + # Load checkpoint using CheckpointManager + self._load_checkpoint() + + self.model.eval() + logger.info("Model loaded and ready") + + def _get_model_args(self): + """Get model args from job config.""" + from torchtitan.protocols import train_spec as train_spec_module + + train_spec = train_spec_module.get_train_spec(self.job_config.model.name) + model_args = train_spec.model_args[self.job_config.model.flavor] + model_args.update_from_config(self.job_config) + + # Override for inference + model_args.use_flex_attn = False + 
model_args.encoder.use_flex_attn = False + + logger.info(f"Model args: {model_args}") + return model_args + + def _build_model(self): + """Build model using torchtitan's approach.""" + logger.info(f"Building {self.job_config.model.name} {self.job_config.model.flavor}") + + dtype = TORCH_DTYPE_MAP[self.job_config.training.dtype] + + with torch.device("meta"), utils.set_default_dtype(dtype): + model = Llama3Siglip2Transformer(self.model_args) + + # Initialize on device + device_type = utils.device_type + model.to_empty(device=device_type) + with torch.no_grad(): + model.init_encoder_weights(buffer_device=device_type) + + logger.info("Model structure created") + return model + + def _load_checkpoint(self): + """Load checkpoint using CheckpointManager.""" + logger.info("Setting up CheckpointManager") + + # Create state dict adapter if available + sd_adapter = SmolVLMStateDictAdapter( + self.model_args, + self.job_config.model.hf_assets_path + ) + + # Create checkpoint manager + self.checkpointer = CheckpointManager( + dataloader=None, # Not needed for inference + model_parts=[self.model], + optimizers=None, # Not needed for inference + lr_schedulers=None, # Not needed for inference + states={}, # No training state needed + checkpoint_config=self.job_config.checkpoint, + sd_adapter=sd_adapter, + base_folder=self.job_config.job.dump_folder, + ft_manager=None, + ) + + # Load checkpoint + load_step = self.job_config.checkpoint.load_step + logger.info(f"Loading checkpoint at step: {load_step}") + self.checkpointer.load(step=load_step) + logger.info("Checkpoint loaded successfully") + + def prepare_inputs(self, prompt: str, image_path: Optional[str] = None): + """Prepare inputs - debug version.""" + + # Create messages + messages = [{"user": prompt, "assistant": ""}] + + # Apply chat template (without tokenizing first) + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + chat_template=self.chat_template, + add_generation_prompt=True, + ) + + print("\n" + "="*80) + print("FORMATTED TEXT:") + print(repr(text)) + print("="*80) + + # Tokenize + input_ids = self.tokenizer.encode(text) + print(f"\nInput tokens ({len(input_ids)}): {input_ids[:50]}...") + + # Decode to verify + decoded = self.tokenizer.decode(input_ids) + print(f"\nDecoded input:\n{repr(decoded[:200])}...") + + input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).unsqueeze(0) + + # Process image + pixel_values = None + patch_attention_mask = None + + if image_path: + image = Image.open(image_path).resize((512, 512)) + vision_inputs = self.image_processor([image]) + pixel_values = torch.tensor(np.array(vision_inputs['pixel_values'])).squeeze() + pixel_values = pixel_values.unsqueeze(0).unsqueeze(0).to(self.device, dtype=torch.bfloat16) + + patch_attention_mask = torch.tensor(vision_inputs['pixel_attention_mask']) + patch_attention_mask = patch_attention_mask.unsqueeze(0).unsqueeze(0).to(self.device) + + print(f"\nImage processed. 
Pixel values shape: {pixel_values.shape}") + + return input_ids, pixel_values, patch_attention_mask + + @torch.no_grad() + def generate_greedy(self, prompt: str, image_path: Optional[str] = None, max_tokens: int = 50): + """Greedy generation with detailed logging.""" + + print("\n" + "="*80) + print("STARTING GENERATION") + print("="*80) + + input_ids, pixel_values, patch_attention_mask = self.prepare_inputs(prompt, image_path) + + print(f"\nInitial input_ids shape: {input_ids.shape}") + print(f"Starting generation loop...\n") + + generated = input_ids.clone() + + for step in range(max_tokens): + # Forward pass + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): + logits = self.model( + input_ids=generated, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Get next token (greedy) + next_token_logits = logits[:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + # Decode the token + token_text = self.tokenizer.decode([next_token.item()]) + + print(f"Step {step:3d} | Token: {next_token.item():5d} | Text: {repr(token_text)}") + + # Check for EOS + if next_token.item() == self.tokenizer.eos_id: + print("\n*** EOS token generated ***") + break + + # Append + generated = torch.cat([generated, next_token], dim=1) + + # Check for repetition + if step > 5: + last_tokens = generated[0, -6:].tolist() + if len(set(last_tokens)) <= 2: + print(f"\n*** WARNING: Repetition detected in last 6 tokens: {last_tokens} ***") + + print("\n" + "="*80) + print("GENERATION COMPLETE") + print("="*80) + + # Decode full response + generated_ids = generated[0].tolist() + full_text = self.tokenizer.decode(generated_ids) + + print(f"\nGenerated tokens: {generated_ids}") + print(f"\nFull decoded text:\n{full_text}") + + # Try to extract assistant response + if "<|im_start|>assistant" in full_text: + response = full_text.split("<|im_start|>assistant")[-1] + if "<|im_end|>" in response: + response = response.split("<|im_end|>")[0] + response = response.strip() + print(f"\nExtracted assistant response:\n{response}") + return response + + return full_text + + @torch.no_grad() + def test_forward_pass(self, prompt: str, image_path: Optional[str] = None): + """Test a single forward pass with detailed output.""" + + print("\n" + "="*80) + print("TESTING FORWARD PASS") + print("="*80) + + input_ids, pixel_values, patch_attention_mask = self.prepare_inputs(prompt, image_path) + + print(f"\nInput shapes:") + print(f" input_ids: {input_ids.shape}") + if pixel_values is not None: + print(f" pixel_values: {pixel_values.shape}") + print(f" patch_attention_mask: {patch_attention_mask.shape}") + + # Forward pass + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): + logits = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + print(f"\nOutput logits shape: {logits.shape}") + print(f"Logits dtype: {logits.dtype}") + print(f"Logits range: [{logits.min().item():.2f}, {logits.max().item():.2f}]") + + # Get next token predictions + last_logits = logits[0, -1, :] + print(f"\nLast position logits stats:") + print(f" Mean: {last_logits.mean().item():.4f}") + print(f" Std: {last_logits.std().item():.4f}") + print(f" Min: {last_logits.min().item():.4f}") + print(f" Max: {last_logits.max().item():.4f}") + + # Top 10 tokens + top_logits, top_indices = torch.topk(last_logits, k=10) + print(f"\nTop 10 predicted tokens:") + for i, (logit, idx) in enumerate(zip(top_logits, 
top_indices)):
+            token_text = self.tokenizer.decode([idx.item()])
+            print(f"  {i+1}. Token {idx.item():5d} (logit: {logit.item():7.2f}): {repr(token_text)}")
+
+        return logits
+
+
+def main():
+    # Minimal CLI for the debug flags used below; the flag names and defaults are
+    # placeholders, and parse_known_args leaves torchtitan's own config flags for
+    # ConfigManager to handle.
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("--prompt", type=str, default="Describe this image.")
+    parser.add_argument("--image", type=str, default=None)
+    parser.add_argument("--max_tokens", type=int, default=50)
+    parser.add_argument("--test_forward", action="store_true")
+    args, _ = parser.parse_known_args()
+
+    config_manager = ConfigManager()
+    job_config = config_manager.parse_args()
+
+    # Initialize logger
+    init_logger()
+
+    logger.info("Job config loaded:")
+    logger.info(f"  Model: {job_config.model.name} / {job_config.model.flavor}")
+    logger.info(f"  HF assets path: {job_config.model.hf_assets_path}")
+    logger.info(f"  Checkpoint folder: {job_config.checkpoint.folder}")
+    logger.info(f"  Load step: {job_config.checkpoint.load_step}")
+
+    # Create generator
+    generator = SimpleGenerator(job_config)
+
+    # Run test or generation
+    if args.test_forward:
+        generator.test_forward_pass(args.prompt, args.image)
+    else:
+        generator.generate_greedy(args.prompt, args.image, args.max_tokens)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torchtitan/models/README.md b/torchtitan/models/README.md
index 467031ce..456fe14b 100644
--- a/torchtitan/models/README.md
+++ b/torchtitan/models/README.md
@@ -40,7 +40,7 @@ The folder should be organized as follows
 - `__init__.py`
   - A dictionary of the actual model configurations, of the type `[str: ModelArgs]`.
   - Define `get_train_spec` to return a [`TrainSpec`](/torchtitan/protocols/train_spec.py), consisting a tuple of
-    - model name, model class, model args
+    - model class, model args
     - Model name should be the same as the folder name, which should be added to `torchtitan/models/__init__.py` or ``torchtitan/experiments/__init__.py``.
   - parallelizing function, pipelining function
   - builder functions for optimizer, lr scheduler, data loader, tokenizer, and loss function
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index 277d64be..92285777 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -6,238 +6,183 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
-from typing import Callable, ClassVar
+import functools
+from collections.abc import Callable
+from typing import ClassVar
 
 import torch
 import torch.nn.functional as F
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
+    AuxOutput,
     BlockMask,
     create_block_mask,
     flex_attention,
 )
 
-from torchtitan.tools.utils import has_cuda_capability
 
-# FlexAttention mask type. For each mask type, we initialize it at most once per
-# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to
-# track the initialized mask.
-FLEX_ATTN_MASK_T = tuple[str, int | None]
+__all__ = [
+    "FlexAttentionWrapper",
+    "ScaledDotProductAttentionWrapper",
+    "get_causal_mask_mod",
+    "get_document_mask_mod",
+    "get_fixed_block_mask_mod",
+    "create_attention_mask",
+]
 
 
-class FlexAttention(torch.nn.Module):
-    """FlexAttention module that uses torch.nn.attention.flex_attention.
+class FlexAttentionWrapper(torch.nn.Module):
+    """Wrapper around `flex_attention` to make it torch.compile and CP compatible.
 
-    This module is a wrapper around torch.nn.attention.flex_attention. This module
-    implements certain common attention types, such as causal and block_causal.
+    This wrapper serves two purposes:
+    1) Invoke `torch.compile` with a valid mode "max-autotune-no-cudagraphs" to
+       achieve good performance.
+    2) Being a wrapper allows us to apply _ContextParallel to it.
 
-    Args:
-        attn_mask_type (str): The type of attention mask. Currently, we support
-            "causal" and "block_causal". 
"causal" means the lower triangle of the - attention matrix is masked. "block_causal" means the attention matrix - is divided into blocks, where block boundary is defined by EOS token, - and the lower triangle of each block is masked. - fixed_block_size (int | None): The block size to be used to perform attention. - If specified, each sequence will be further divided to blocks, where each - block has the maximum size of ``fixed_block_size``. A query will only attend - to the keys within the same block. + Note: + The forward function must have q, k, v as the first three arguments, and + block_mask as a keyword argument to be compatible with _ContextParallel. """ - # We registered flex_attention related attributes as class variables as we - # need to amortize the cost of compilation. - flex_attn: ClassVar[Callable] = torch.compile( + _compiled_flex_attn: ClassVar[Callable] = torch.compile( flex_attention, mode="max-autotune-no-cudagraphs" ) - compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask) - used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set() - # Attention mask type to the created BlockMask. - # This allows us to keep track the created block masks for each - # new batch. We will use this to update the block mask when a - # new batch is created. This also allows user to create different - # block masks for different layers. - block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {} - - # Instance variables. - attn_mask_type: str - - def __init__( - self, attn_mask_type: str, fixed_block_size: int | None = None - ) -> None: - super().__init__() - if attn_mask_type not in ["causal", "block_causal"]: - raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") - self.attn_mask_type = attn_mask_type - self.fixed_block_size = fixed_block_size - - FlexAttention.used_attn_mask_types.add(self.mask_key) - - @property - def mask_key(self) -> FLEX_ATTN_MASK_T: - return (self.attn_mask_type, self.fixed_block_size) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *, + block_mask: BlockMask, scale: float | None = None, - ) -> torch.Tensor: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) - - @staticmethod - def _get_causal_mask_mod() -> _mask_mod_signature: - def causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - return q_idx >= kv_idx - - return causal_mask - - @staticmethod - def _get_block_causal_mask_mod( - batch: torch.Tensor, eos_id: int - ) -> _mask_mod_signature: - # batch is [b, s, h, d] shape - mask = batch == eos_id - mask[:, -1] = True - acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1) - seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32) - seq_idx[:, 1:] = acc_mask[:, :-1] - - def block_causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx) - - return block_causal_mask - - @staticmethod - def _fixed_block_mask_mod( - mask_mod: _mask_mod_signature, fixed_block_size: int - ) -> _mask_mod_signature: - """ - Given an arbitrary mask_mod, divide the input sequence to blocks - and only allow attention within the same block. - - Args: - mask_mod: The mask mod to apply to the documents - fixed_block_size: The number of tokens in each block. - """ - - # Credit to @drisspg. 
-        def blocked_mask_mod(
-            b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
-        ):
-            # Get the block index of the query and key
-            q_block = q_idx // fixed_block_size
-            kv_block = kv_idx // fixed_block_size
-            # Only allow attention within the same block
-            same_block = q_block == kv_block
-            # Apply the original mask mod
-            inner_mask = mask_mod(
-                b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size
-            )
-
-            return same_block & inner_mask
-
-        blocked_mask_mod.__name__ = (
-            f"blocked_mask_mod_{mask_mod.__name__}_fixed_block_size_{fixed_block_size}"
+    ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]:
+        # 1. _compiled_flex_attn has to be a class variable, otherwise there will
+        #    be multiple compiled flex_attention instances, which can be slow.
+        # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
+        #    as the first argument, which will cause an error.
+        #    `FlexAttentionWrapper._compiled_flex_attn` is correct.
+        return FlexAttentionWrapper._compiled_flex_attn(
+            q, k, v, block_mask=block_mask, scale=scale
         )
 
-        return blocked_mask_mod
-
-    @staticmethod
-    @torch.no_grad()
-    def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None:
-        # batch is [b, s, h, d] shape
-        for mask_key in FlexAttention.used_attn_mask_types:
-            attn_mask_type, fixed_block_size = mask_key
-            match attn_mask_type:
-                case "causal":
-                    if FlexAttention.block_masks.get(mask_key, None) is not None:
-                        continue
-                    # We don't care about batch dimension --
-                    # all samples have the same lower triangle mask.
-                    batch_dimension = 1
-                    mask_mod = FlexAttention._get_causal_mask_mod()
-                case "block_causal":
-                    if eos_id is None:
-                        raise RuntimeError(
-                            "eos_id must be provided for block_causal mask."
-                        )
-                    batch_dimension = batch.shape[0]
-                    mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id)
-                case _:
-                    raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
-
-            if fixed_block_size is not None and fixed_block_size > 0:
-                mask_mod = FlexAttention._fixed_block_mask_mod(
-                    mask_mod, fixed_block_size
-                )
-
-            seq_len = batch.shape[1]
-            block_mask = FlexAttention.compiled_create_block_mask(
-                mask_mod, batch_dimension, None, seq_len, seq_len
-            )
-            FlexAttention.block_masks[mask_key] = block_mask
-
-
-class ScaledDotProductAttention(torch.nn.Module):
-    backends: ClassVar[list[SDPBackend]] = []
-
-    def __init__(self, attn_mask_type: str) -> None:
+
+class ScaledDotProductAttentionWrapper(torch.nn.Module):
+    """Wrapper around `F.scaled_dot_product_attention` to make it CP compatible.
+
+    This wrapper is needed because `F.scaled_dot_product_attention` is not
+    a torch.nn.Module, and thus cannot be applied with _ContextParallel.
+    We need to wrap it into a torch.nn.Module.
+
+    Note:
+        The forward function must have q, k, v as the first three arguments to be
+        compatible with _ContextParallel.
+    """
+
+    # TODO: remove sdpa_backends after PyTorch 2.9 is released.
+    sdpa_backends: ClassVar[list[SDPBackend]] = []
+
+    def __init__(self, is_causal: bool = True) -> None:
         super().__init__()
-        if attn_mask_type != "causal":
-            raise ValueError(
-                "TorchTitan with SDPA currently only supports causal mask." 
-            )
-
-        ScaledDotProductAttention._init_backend()
-
-    @classmethod
-    def _init_backend(cls) -> None:
-        if cls.backends:
-            return
-
-        # Add CuDNN on B200 w/ highest priority
-        cls.backends = [
-            SDPBackend.FLASH_ATTENTION,
-            SDPBackend.EFFICIENT_ATTENTION,
-            SDPBackend.MATH,
-        ]
-        if has_cuda_capability(10, 0):
-            cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION)
+        if not self.sdpa_backends:
+            self.sdpa_backends = [
+                SDPBackend.CUDNN_ATTENTION,
+                SDPBackend.FLASH_ATTENTION,
+                SDPBackend.EFFICIENT_ATTENTION,
+            ]
+        self.is_causal = is_causal
 
     def forward(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
+        *,
         scale: float | None = None,
     ) -> torch.Tensor:
-        assert self.backends, "SDPA Backends should not be empty."
-        with sdpa_kernel(self.backends, set_priority=True):
-            return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
-
-
-def build_attention(
-    use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None
-):
-    if use_flex_attn:
-        return FlexAttention(attn_mask_type, fixed_block_size)
-    else:
-        if fixed_block_size is not None:
-            raise ValueError(
-                "TorchTitan with SDPA currently does not support fixed_block_size."
-            )
-        if attn_mask_type != "causal":
-            raise ValueError(
-                "TorchTitan with SDPA currently only supports causal mask."
-            )
-        return ScaledDotProductAttention(attn_mask_type)
-
-
-def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None:
-    FlexAttention.init_attention_mask(batch, eos_id)
+        with sdpa_kernel(self.sdpa_backends, set_priority=True):
+            return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=self.is_causal)
+
+
+# We cannot use an inner function/closure because we won't be able to cache it --
+# if we use an inner function, a new closure will be created every time
+# `get_causal_mask_mod` is called.
+def _causal_mask(
+    b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
+) -> torch.Tensor:
+    """Causal mask that prevents attention to future tokens."""
+    return q_idx >= kv_idx
+
+
+def get_causal_mask_mod() -> _mask_mod_signature:
+    """Returns a causal mask modifier for flex attention.
+
+    Returns:
+        A mask modifier function that implements causal masking.
+    """
+    return _causal_mask
+
+
+def get_document_mask_mod(batch: torch.Tensor, eos_id: int) -> _mask_mod_signature:
+    """Creates a document mask that prevents attention across document boundaries.
+
+    Args:
+        batch: Input batch tensor with shape [b, s, h, d]
+        eos_id: End-of-sequence token ID that marks document boundaries
+
+    Returns:
+        A mask modifier function that implements document-level masking.
+    """
+    # batch is [b, s, h, d] shape
+    eos_mask = batch == eos_id
+    eos_mask[:, -1] = True
+    cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1)
+    sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32)
+    sequence_indices[:, 1:] = cumulative_mask[:, :-1]
+
+    def document_mask(
+        b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
+    ) -> torch.Tensor:
+        return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx]
+
+    return document_mask
+
+
+def get_fixed_block_mask_mod(fixed_block_size: int) -> _mask_mod_signature:
+    """
+    Divide the input sequence into blocks and only allow attention within the same block.
+
+    Args:
+        fixed_block_size: The number of tokens in each block.
+
+    Returns:
+        A mask modifier function that implements block-wise attention masking.
+    """
+
+    # Credit to @drisspg.
+    def blocked_mask_mod(
+        b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
+    ) -> torch.Tensor:
+        # Get the block index of the query and key
+        q_block = q_idx // fixed_block_size
+        kv_block = kv_idx // fixed_block_size
+        # Only allow attention within the same block
+        return q_block == kv_block
+
+    blocked_mask_mod.__name__ = f"blocked_mask_mod_fixed_block_size_{fixed_block_size}"
+
+    return blocked_mask_mod
+
+
+_compiled_create_block_mask = torch.compile(create_block_mask)
+
+
+@functools.lru_cache(4)
+def create_attention_mask(*args, **kwargs):
+    """Create an attention mask using compiled create_block_mask.

+    This function is cached to avoid recreating BlockMasks for the same
+    arguments.
+    """
+    return _compiled_create_block_mask(*args, **kwargs)
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index a290ea7e..4e8d500b 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -161,7 +161,6 @@ def get_train_spec() -> TrainSpec:
 
     return TrainSpec(
-        name="deepseek_v3",
         model_cls=DeepSeekV3Model,
         model_args=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index c7cd45f4..fc79e5ba 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -84,6 +84,7 @@ def parallelize_deepseekv3(
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
+            use_flex_attn=use_flex_attn,
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
@@ -113,6 +114,7 @@ def parallelize_deepseekv3(
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
+            base_folder=job_config.job.dump_folder,
         )
 
     if model_compile_enabled:
@@ -180,6 +182,7 @@ def apply_non_moe_tp(
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
     enable_float8_tensorwise_tp: bool,
+    use_flex_attn: bool,
 ):
     """Apply tensor parallelism."""
     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -209,6 +212,18 @@ def apply_non_moe_tp(
         PrepareModuleInput,
     )
 
+    if use_flex_attn:
+        attention_kernel_plan = prepare_module_input(
+            input_layouts=(Shard(1), Shard(1), Shard(1)),
+            desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
+            use_local_output=True,
+        )
+    else:
+        attention_kernel_plan = prepare_module_input(
+            input_layouts=(Shard(1), Shard(1), Shard(1)),
+            desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
+            use_local_output=True,
+        )
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
     # by folding (and unfolding) the batch dimension and the sequence dimension. 
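For reference, a minimal sketch (not part of this patch; `ToyAttention` and its dimensions are made up) of how an attention module is expected to call the new wrappers, with q, k, v as the first three positional arguments and the `BlockMask` passed by keyword; this is the calling convention the `attention.inner_attention` tensor-parallel plan above relies on:

```python
import torch
from torch import nn
from torch.nn.attention.flex_attention import BlockMask

from torchtitan.models.attention import (
    FlexAttentionWrapper,
    ScaledDotProductAttentionWrapper,
)


class ToyAttention(nn.Module):
    """Hypothetical attention module wired for the new wrappers."""

    def __init__(self, dim: int, n_heads: int, use_flex_attn: bool):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.wq, self.wk, self.wv, self.wo = (nn.Linear(dim, dim) for _ in range(4))
        self.use_flex_attn = use_flex_attn
        if use_flex_attn:
            self.inner_attention = FlexAttentionWrapper()
        else:
            self.inner_attention = ScaledDotProductAttentionWrapper(is_causal=True)

    def forward(self, x: torch.Tensor, block_mask: BlockMask | None = None):
        b, s, _ = x.shape
        # [b, s, dim] -> [b, n_heads, s, head_dim]
        q, k, v = (
            w(x).view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
            for w in (self.wq, self.wk, self.wv)
        )
        if self.use_flex_attn:
            # The BlockMask comes from the model-level get_attention_masks().
            out = self.inner_attention(q, k, v, block_mask=block_mask)
        else:
            out = self.inner_attention(q, k, v)
        return self.wo(out.transpose(1, 2).reshape(b, s, -1))
```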
@@ -217,8 +232,8 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), + input_layouts=(Shard(1), Replicate(), None), + desired_input_layouts=(Replicate(), Replicate(), None), ), # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor # so that the intermedidate results k is generated as a DTensor and its gradient is @@ -227,11 +242,7 @@ def apply_non_moe_tp( "attention.wkv_b": colwise_parallel(use_local_output=False), "attention.kv_norm": NoParallel(use_local_output=False), # NOTE: use_local_output=True so that the inputs to FlexAttention are plain Tensors - "attention.sdpa": prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ), + "attention.inner_attention": attention_kernel_plan, "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index dc612faf..d5bc9b10 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -5,13 +5,22 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Tuple import torch from torch import nn -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import FeedForward, MoE +from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs @@ -58,7 +67,7 @@ def find_correction_dim( def find_correction_range( low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """ Computes the range of correction dimensions for rotary positional embeddings. @@ -70,7 +79,7 @@ def find_correction_range( max_seq_len (int): Maximum sequence length. Returns: - Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. """ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) @@ -175,12 +184,17 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 self.softmax_scale = self.softmax_scale * mscale * mscale - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass for the Multi-Head Latent Attention (MLA) Layer. 
@@ -231,7 +245,14 @@ def forward( k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) - output = self.sdpa(q, k, v, scale=self.softmax_scale) + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention( + q, k, v, block_mask=attention_masks, scale=self.softmax_scale + ) + else: + assert attention_masks is None + output = self.inner_attention(q, k, v, scale=self.softmax_scale) # Reshape and project output output = output.transpose( @@ -284,7 +305,12 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 self.layer_id = layer_id - def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): """ Forward pass for the Transformer block. @@ -295,7 +321,7 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): Returns: torch.Tensor: Output tensor with the same shape as the input. """ - x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) else: @@ -360,9 +386,31 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: b=cutoff_factor * final_out_std, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -385,7 +433,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/deepseek_v3/model/quantization.py b/torchtitan/models/deepseek_v3/model/quantization.py deleted file mode 100644 index a8ac6003..00000000 --- a/torchtitan/models/deepseek_v3/model/quantization.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -from torchtitan.tools.logging import logger - -# Fixed block size of 128x128 as specified in the algorithm -BLOCK_SIZE = 128 - - -def calculate_scale_shape( - weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE -) -> torch.Size: - # Calculate the scale tensor shape - orig_shape = weight.shape - - # Calculate number of blocks needed - block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE - block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = torch.Size((block_rows, block_cols)) - - return expected_scale_shape - - -def dequantize_from_fp8( - weight: torch.Tensor, - scale_inv: torch.Tensor, - dtype=torch.bfloat16, - BLOCK_SIZE: int = BLOCK_SIZE, -) -> torch.Tensor: - # Convert to float32 for computation - float_weight = weight.to(torch.float32) - # Get original dimensions - orig_shape = weight.shape - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE) - block_rows, block_cols = expected_scale_shape - if scale_inv.shape != expected_scale_shape: - logger.warning( - f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}" - ) - - # NOTE: When processing large models on-the-fly, misalignment between block boundaries - # and DTensor local shape partitioning can lead to silent numerical inaccuracies. - dequantized = float_weight.detach().clone().to(dtype=dtype) - - # Apply scaling factors to each block - for i in range(block_rows): - row_start = i * BLOCK_SIZE - row_end = min(row_start + BLOCK_SIZE, orig_shape[0]) - - for j in range(block_cols): - col_start = j * BLOCK_SIZE - col_end = min(col_start + BLOCK_SIZE, orig_shape[1]) - - # Get the block - block = float_weight[row_start:row_end, col_start:col_end] - - scale = scale_inv[i, j] - block = block * scale - - # Explicitly convert block to dtype - block_converted = block.to(dtype=torch.float32) - # Store the dequantized block - dequantized[row_start:row_end, col_start:col_end] = block_converted - - return dequantized diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index b366910f..11d54ffb 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -8,13 +8,14 @@ import re from typing import Any +import torch +from torch.distributed.checkpoint import HuggingFaceStorageReader + from torch.distributed.tensor import DTensor from torchtitan.models.utils import MoEStateDictAdapter from .args import DeepSeekV3ModelArgs -from .quantization import calculate_scale_shape, dequantize_from_fp8 - class DeepSeekV3StateDictAdapter(MoEStateDictAdapter): """ @@ -70,60 +71,33 @@ def __init__( } ) - def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: """ - Dequantize the weights from float8 to float32. + Override default get_hf_storage_reader function to return QuantizedHFStorageReader. 
""" + if from_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) - scale_inv_keys = [] - for key, weight in state_dict.items(): - if key.endswith(".weight") and key + "_scale_inv" in state_dict: - scale_inv = state_dict[key + "_scale_inv"] - dequantized_weight = dequantize_from_fp8( - weight, scale_inv, dtype=torch.float32 - ) - # update the weight and remove the scale_inv tensor - state_dict[key] = dequantized_weight - scale_inv_keys.append(key + "_scale_inv") - - for key in scale_inv_keys: - state_dict.pop(key) - - return state_dict - - def _add_quantization_scale_inv_tensors( - self, state_dict: dict[str, Any] - ) -> dict[str, Any]: - """ - Add quantization scale tensors the state_dict. - """ - non_quantized_keys = [ - "input_layernorm.weight", - "post_attention_layernorm.weight", - "norm.weight", - "lm_head.weight", - "embed_tokens.weight", - "mlp.gate.weight", - ] - - weight_scale_inv_state_dict = {} - for key, value in state_dict.items(): - if key.endswith(".weight") and not any( - non_quantized_key in key for non_quantized_key in non_quantized_keys - ): - expected_scale_shape = calculate_scale_shape(value) - # add weight_scale_inv to the state_dict - weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( - expected_scale_shape, dtype=torch.float32 - ) - - state_dict.update(weight_scale_inv_state_dict) - return state_dict + # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. + # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + BLOCK_SIZE = 128 + return QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. Convert between the HF shape and the torchtitan shape. - 2. Split the GroupedExperts' weight into separate expert's wegiht. + 2. Split the GroupedExperts' weight into separate expert's weight. """ to_hf_map = {v: k for k, v in self.from_hf_map.items()} @@ -172,24 +146,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: new_key = to_hf_map[key] hf_state_dict[new_key] = value - # Prepare for dequantization - hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( - hf_state_dict - ) - return hf_state_dict_with_scale_inv + return hf_state_dict def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. When loading from HF checkpoint, dequantize the weights from float8 to float32. 2. Convert between the HF shape and the torchtitan shape. - 3. Concate separate expert's wegiht into GroupedExperts' weight. + 3. Concat separate expert's weight into GroupedExperts' weight. 
""" - # dequantize the tensor in state_dict and remove the scale_inv tensor - - hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} - expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} for key, value in hf_state_dict.items(): @@ -215,7 +181,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: layer_num, value.device_mesh, ) - else: # keep this path to be compatibile with offline conversion + else: # keep this path to be compatible with offline conversion stacked_value = self._concatenate_expert_weights( expert_weights_by_layer, titan_abstract_key, diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index c6dee817..9d8625a2 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable=true +enable = true components = ["loss"] # ["model", "loss"] [quantize.linear.float8] diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 3d395b86..2966dbf1 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -51,7 +51,7 @@ "8B": TransformerModelArgs( dim=4096, - ffn_dim=8192, + ffn_dim=14336, n_layers=32, n_heads=32, n_kv_heads=8, @@ -83,7 +83,6 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="llama3", model_cls=Transformer, model_args=llama3_configs, parallelize_fn=parallelize_llama, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 9f98eaf2..4944af56 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -102,6 +102,7 @@ def parallelize_llama( model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP @@ -206,8 +207,8 @@ def apply_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), + input_layouts=(Shard(1), None, None), + desired_input_layouts=(Replicate(), None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 3273605d..83f1ed05 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -7,7 +7,7 @@ # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
-from dataclasses import dataclass
+from dataclasses import dataclass, field

 from torch import nn

@@ -17,6 +17,14 @@
 from torchtitan.tools.logging import logger


+@dataclass
+class RoPEScalingArgs:
+    scaling_factor: float = 8.0
+    low_freq_factor: float = 1.0
+    high_freq_factor: float = 4.0
+    original_max_position_embeddings: int = 8192
+
+
 @dataclass
 class TransformerModelArgs(BaseModelArgs):
     dim: int = 4096
@@ -28,8 +36,9 @@ class TransformerModelArgs(BaseModelArgs):
     ffn_dim_multiplier: float | None = None
     norm_eps: float = 1e-5
     rope_theta: float = 10000
+    rope_scaling_args: RoPEScalingArgs = field(default_factory=RoPEScalingArgs)

-    ffn_dim: int = 8192
+    ffn_dim: int = 1536
     max_seq_len: int = 131072

     # If `True`, then each transformer block init uses its layer ID, and if
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index b528e2e0..d45d2996 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -6,18 +6,33 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

+import math

 import torch
 import torch.nn.functional as F
 from torch import nn
-
-from torchtitan.models.attention import build_attention
+from torch.nn.attention.flex_attention import and_masks, BlockMask
+
+from torchtitan.components.tokenizer import BaseTokenizer
+from torchtitan.models.attention import (
+    create_attention_mask,
+    FlexAttentionWrapper,
+    get_causal_mask_mod,
+    get_document_mask_mod,
+    ScaledDotProductAttentionWrapper,
+)
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.protocols.train_spec import ModelProtocol

-from .args import TransformerModelArgs
+from .args import RoPEScalingArgs, TransformerModelArgs


-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    scaling_args: RoPEScalingArgs = RoPEScalingArgs(),
+) -> torch.Tensor:
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -29,11 +44,41 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te
         dim (int): Dimension of the frequency tensor.
         end (int): End index for precomputing frequencies.
         theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0.
-
+        scaling_args (RoPEScalingArgs): RoPE scaling arguments. Defaults to RoPEScalingArgs().
+            scaling_factor (float): RoPE scaling multiplier; larger values
+                stretch positions to support longer contexts. Defaults to 8.0.
+            low_freq_factor (float): Extra scaling applied to the low-frequency
+                (long-wavelength) RoPE bands. Defaults to 1.0.
+            high_freq_factor (float): Extra scaling applied to the high-frequency
+                (short-wavelength) RoPE bands. Defaults to 4.0.
+            original_max_position_embeddings (int): Maximum position embeddings
+                of the original (unscaled) model. Defaults to 8192.
     Returns:
         torch.Tensor: Precomputed frequency tensor with complex exponentials.
""" freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + # apply rope scaling + scaling_factor = scaling_args.scaling_factor + low_freq_factor = scaling_args.low_freq_factor + high_freq_factor = scaling_args.high_freq_factor + original_max_position_embeddings = scaling_args.original_max_position_embeddings + wavelen = 2 * math.pi / freqs + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by scaling factor + freqs = torch.where(wavelen > low_freq_wavelen, freqs / scaling_factor, freqs) + # wavelen in between: linear interpolation of the scaled freqs and the original freqs + smooth_factor = (original_max_position_embeddings / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_freqs = ( + 1 - smooth_factor + ) * freqs / scaling_factor + smooth_factor * freqs + is_medium_freqs = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + freqs = torch.where(is_medium_freqs, smoothed_freqs, freqs) + t = torch.arange(end, device=freqs.device) freqs = torch.outer(t, freqs).float() freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 @@ -129,11 +174,7 @@ class Attention(nn.Module): def __init__(self, model_args: TransformerModelArgs): super().__init__() self.n_heads = model_args.n_heads - self.n_kv_heads = ( - model_args.n_heads - if model_args.n_kv_heads is None - else model_args.n_kv_heads - ) + self.n_kv_heads = model_args.n_kv_heads self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.dim // model_args.n_heads @@ -145,7 +186,12 @@ def __init__(self, model_args: TransformerModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper(is_causal=True) def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -156,6 +202,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. 
@@ -189,7 +236,16 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv) + assert ( + isinstance(attention_masks, BlockMask) or attention_masks is None + ), attention_masks + + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -218,29 +274,30 @@ class FeedForward(nn.Module): def __init__( self, dim: int, - hidden_dim: int, - multiple_of: int | None=None, - ffn_dim_multiplier: float | None=None + hidden_dim: int | None = None, + multiple_of: int | None = None, + ffn_dim_multiplier: float | None = None, + ffn_dim: int | None = None, ): super().__init__() - """ - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - """ - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + if not ffn_dim: + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + ffn_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, ffn_dim, bias=False) + self.w2 = nn.Linear(ffn_dim, dim, bias=False) + self.w3 = nn.Linear(dim, ffn_dim, bias=False) def forward(self, x): - return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return self.w2(F.silu(self.w1(x)) * self.w3(x)) def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.gate_proj.weight, mean=0.0, std=0.02) - for linear in (self.up_proj, self.down_proj): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) @@ -271,7 +328,7 @@ def __init__(self, layer_id: int, model_args: TransformerModelArgs): self.attention = Attention(model_args) self.feed_forward = FeedForward( dim=model_args.dim, - hidden_dim=model_args.ffn_dim, + ffn_dim=model_args.ffn_dim, ) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) @@ -285,6 +342,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -297,7 +355,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. 
""" - h = x + self.attention(self.attention_norm(x), freqs_cis) + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -389,11 +447,34 @@ def _precompute_freqs_cis(self) -> torch.Tensor: # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, + self.model_args.rope_scaling_args, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -417,7 +498,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks=attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/models/llama3/train_configs/llama3_1b.toml b/torchtitan/models/llama3/train_configs/llama3_1b.toml index 56ea5864..4c27a807 100644 --- a/torchtitan/models/llama3/train_configs/llama3_1b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_1b.toml @@ -44,12 +44,14 @@ context_parallel_degree = 1 [checkpoint] enable = true -last_save_in_hf = true folder = "checkpoint" interval = 100 last_save_model_only = true +initial_load_in_hf = true +last_save_in_hf = true export_dtype = "float32" async_mode = "async" # ["disabled", "async", "async_with_pinned_mem"] +exclude_from_loading = ["dataloader", "optimizer", "train_state"] [compile] enable=false diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index ef86d783..3f610e84 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -43,11 +43,12 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" +initial_load_in_hf = true +load_only = true interval = 500 -last_save_model_only = true -export_dtype = "float32" +export_dtype = "bfloat16" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [compile] diff --git a/torchtitan/models/llama3_ft/__init__.py b/torchtitan/models/llama3_ft/__init__.py index 1dad5e72..f6337eeb 100644 --- a/torchtitan/models/llama3_ft/__init__.py +++ b/torchtitan/models/llama3_ft/__init__.py @@ -33,12 +33,10 @@ def get_train_spec() -> TrainSpec: return FaultTolerantTrainSpec( - name="llama3_ft", model_cls=Transformer, model_args=llama3_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, - fragment_fn=fragment_llm, build_optimizers_fn=build_optimizers, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, @@ -46,4 +44,5 @@ def get_train_spec() -> TrainSpec: build_loss_fn=build_cross_entropy_loss, build_validator_fn=build_validator, state_dict_adapter=Llama3StateDictAdapter, + 
fragment_fn=fragment_llm, ) diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index a4f28bc8..a713bec6 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -11,9 +11,16 @@ import torch import torch.nn as nn +from torch.nn.attention.flex_attention import BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer + from torchtitan.config import JobConfig +AttentionMasksType = dict[str, BlockMask] | BlockMask + + @dataclass class BaseModelArgs: """All ModelArgs should inherit from this class. @@ -53,3 +60,13 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: buffer_device: Optional device to place buffers on during initialization. """ pass + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + raise NotImplementedError( + "This model does not support attention masking/Flex Attention." + ) diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 5b441e9b..e22692bd 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -5,13 +5,14 @@ # LICENSE file in the root directory of this source tree. import json -import logging import os import re from abc import ABC, abstractmethod from typing import Any -logger = logging.getLogger() +from torch.distributed.checkpoint import HuggingFaceStorageReader + +from torchtitan.tools.logging import logger from .model import BaseModelArgs @@ -58,6 +59,21 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ pass + @abstractmethod + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + """Returns hf storage reader to read HF checkpoint + + Args: + path: the path to read HF checkpoint + + Returns: + The HuggingFace storage reader to read from HF checkpoint + + """ + pass + class StateDictAdapter(BaseStateDictAdapter): """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping""" @@ -86,3 +102,12 @@ def __init__( self.fqn_to_index_mapping[hf_key] = int(indx) else: self.fqn_to_index_mapping = None + + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + if from_quantized: + logger.warning( + "Loading from quantized checkpoint format is not supported for this model." 
+ ) + return HuggingFaceStorageReader(path) diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 71d2a98a..78c9d899 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -42,7 +42,6 @@ @dataclass class TrainSpec: - name: str model_cls: type[ModelProtocol] model_args: Mapping[str, BaseModelArgs] parallelize_fn: ParallelizeFunction @@ -60,13 +59,13 @@ class TrainSpec: _extra_train_specs: dict[str, TrainSpec] = {} -def register_train_spec(train_spec: TrainSpec) -> None: +def register_train_spec(name: str, train_spec: TrainSpec) -> None: global _extra_train_specs - if train_spec.name in _extra_train_specs: - raise ValueError(f"TrainSpec {train_spec.name} is already registered.") + if name in _extra_train_specs: + raise ValueError(f"TrainSpec {name} is already registered.") # user can define a TrainSpec from outside of torchtitan - _extra_train_specs[train_spec.name] = train_spec + _extra_train_specs[name] = train_spec def get_train_spec(name: str) -> TrainSpec: @@ -77,6 +76,7 @@ def get_train_spec(name: str) -> TrainSpec: from torchtitan.experiments import _supported_experiments from torchtitan.models import _supported_models + from torchtitan.vlr import _supported_vlr_models if name in _supported_models: module = import_module(f"torchtitan.models.{name}") @@ -84,5 +84,8 @@ def get_train_spec(name: str) -> TrainSpec: elif name in _supported_experiments: module = import_module(f"torchtitan.experiments.{name}") return module.get_train_spec() + elif name in _supported_vlr_models: + module = import_module(f"torchtitan.vlr.{name}") + return module.get_train_spec() raise ValueError(f"TrainSpec {name} is not registered.") diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 0e851d33..f398dba9 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -14,9 +14,6 @@ from torchtitan.config import Profiling as ProfilingConfig from torchtitan.tools.logging import logger -# the number of warmup steps before the active step in each profiling cycle -WARMUP = 3 - # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 @@ -34,7 +31,11 @@ def maybe_enable_profiling( if enable_profiling: trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder) - profile_freq = profiling_config.profile_freq + profile_freq, warmup, active = ( + profiling_config.profile_freq, + profiling_config.profiler_warmup, + profiling_config.profiler_active, + ) rank = torch.distributed.get_rank() @@ -58,7 +59,6 @@ def trace_handler(prof): if not os.path.exists(trace_dir): os.makedirs(trace_dir, exist_ok=True) - warmup, active = WARMUP, 1 wait = profile_freq - (active + warmup) assert ( wait >= 0 diff --git a/torchtitan/train.py b/torchtitan/train.py index e909964e..ad40e035 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -24,7 +24,6 @@ ) from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils -from torchtitan.models.attention import init_attention_mask from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -152,7 +151,7 @@ def __init__(self, job_config: JobConfig): self.model_args = model_args logger.info( - f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" + f"Building {job_config.model.name} 
{job_config.model.flavor} with {model_args}" ) with ( @@ -184,7 +183,7 @@ def __init__(self, job_config: JobConfig): ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) logger.info( - f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} " f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) @@ -199,7 +198,9 @@ def __init__(self, job_config: JobConfig): init_device = device_type buffer_device = None - self.loss_fn = self.train_spec.build_loss_fn(job_config) + self.loss_fn = self.train_spec.build_loss_fn( + job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager + ) # verify batch sizes global_batch_size = job_config.training.global_batch_size @@ -229,7 +230,7 @@ def __init__(self, job_config: JobConfig): if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: raise RuntimeError( - f"Pipeline Parallel is enabled but {self.train_spec.name} " + f"Pipeline Parallel is enabled but {job_config.model.name} " f"does not support pipelining" ) @@ -416,14 +417,21 @@ def forward_backward_step( inputs = input_dict["input"] extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} - # Create the FlexAttention mask according to the input + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_args are. + extra_args = {} + if getattr(self.model_args, "use_flex_attn", False): - init_attention_mask(inputs, self.tokenizer.eos_id) + extra_args["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage - inputs = input_dict["input"] - extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} + cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], @@ -446,13 +454,17 @@ def forward_backward_step( self.pp_schedule.step( inputs, **extra_inputs, + **extra_args, target=targets, losses=losses, input_batch=inputs, ) else: self.pp_schedule.step( - target=targets, losses=losses, input_batch=inputs + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) # accumulate losses across pipeline microbatches @@ -470,7 +482,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, **extra_inputs) + pred = model_parts[0](inputs, **extra_inputs, **extra_args) loss = self.loss_fn(pred, labels) # need to free pred before bwd to avoid peaking memory del pred @@ -551,8 +563,8 @@ def train_step( def train(self): job_config = self.job_config - self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}") + self.checkpointer.load(step=job_config.checkpoint.load_step) leaf_folder = ( "" diff --git a/torchtitan/vlr/__init__.py b/torchtitan/vlr/__init__.py new file mode 100644 index 00000000..429f313f --- /dev/null +++ b/torchtitan/vlr/__init__.py @@ -0,0 +1 @@ +_supported_vlr_models = frozenset(["smolvlm"]) diff --git a/torchtitan/vlr/smolvlm/__init__.py b/torchtitan/vlr/smolvlm/__init__.py index a4147920..1759f0b5 100644 --- 
a/torchtitan/vlr/smolvlm/__init__.py +++ b/torchtitan/vlr/smolvlm/__init__.py @@ -9,13 +9,14 @@ from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.components.validate import build_validator -from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from torchtitan.protocols.train_spec import TrainSpec from .datasets.mm_datasets import build_mm_dataloader from .infra.parallelize import parallelize_vlm # from .infra.pipeline import pipeline_llama from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs from .model.model import Llama3Siglip2Transformer +from .model.state_dict_adapter import SmolVLMStateDictAdapter __all__ = [ "parallelize_vlm", @@ -35,7 +36,7 @@ ), "256M": Siglip2ModelArgs( dim=768, - ffn_dim=2304, + ffn_dim=3072, n_layers=12, n_heads=12, ) @@ -63,20 +64,21 @@ "256M": Llama3Siglip2ModelArgs( encoder=siglip2_configs["256M"], dim=576, + ffn_dim=1536, n_layers=30, n_heads=9, n_kv_heads=3, - ffn_dim_multiplier=1.3, multiple_of=1024, rope_theta=100000, vocab_size=49280, + use_flex_attn = False, + attn_mask_type = "causal", ), } -register_train_spec( - TrainSpec( - name="llama3-siglip2", +def get_train_spec() -> TrainSpec: + return TrainSpec( model_cls=Llama3Siglip2Transformer, model_args=llama3_siglip2_configs, parallelize_fn=parallelize_vlm, @@ -87,6 +89,5 @@ build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, build_validator_fn=build_validator, - # state_dict_adapter=Llama3StateDictAdapter, + state_dict_adapter=SmolVLMStateDictAdapter, ) -) diff --git a/torchtitan/vlr/smolvlm/datasets/mm_datasets.py b/torchtitan/vlr/smolvlm/datasets/mm_datasets.py index 952d5005..840de827 100644 --- a/torchtitan/vlr/smolvlm/datasets/mm_datasets.py +++ b/torchtitan/vlr/smolvlm/datasets/mm_datasets.py @@ -440,5 +440,6 @@ def build_mm_dataloader( ) for sample in dataset: - print(sample) + #print(sample) + print(sample['input_ids'].v) exit() diff --git a/torchtitan/vlr/smolvlm/datasets/template.jinja b/torchtitan/vlr/smolvlm/datasets/template.jinja index 50b69a4b..01872f5f 100644 --- a/torchtitan/vlr/smolvlm/datasets/template.jinja +++ b/torchtitan/vlr/smolvlm/datasets/template.jinja @@ -1,2 +1,2 @@ -{%- for message in messages %}{{'<|im_start|>user' + '\\n' + message['user'] + '<|im_end|>' }} -{{'<|im_start|>assistant' + '\\n' + message['assistant'] + '<|im_end|>' }}{%- endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %} \ No newline at end of file +{%- for message in messages %}{{'<|im_start|>user' + message['user'] + '' + '' }} +{%- endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %} diff --git a/torchtitan/vlr/smolvlm/model/args.py b/torchtitan/vlr/smolvlm/model/args.py index 4386619c..a4f6e48c 100644 --- a/torchtitan/vlr/smolvlm/model/args.py +++ b/torchtitan/vlr/smolvlm/model/args.py @@ -22,7 +22,7 @@ class Siglip2ModelArgs: patch_size: int = 16 image_size: int = 512 - scale_factor: int = 2 + scale_factor: int = 4 layer_norm_eps: float = 1e-6 use_flex_attn: bool = True @@ -35,6 +35,7 @@ class Llama3Siglip2ModelArgs(Llama3Args): tokenizer_name: str = 'HuggingFaceTB/SmolLM2-360M-Instruct' img_token_id: int = 49190 vocab_size: int = 49280 + ffn_dim: int = 1536 def update_from_config(self, job_config: JobConfig, **kwargs) -> None: super().update_from_config(job_config, **kwargs) diff --git a/torchtitan/vlr/smolvlm/model/model.py b/torchtitan/vlr/smolvlm/model/model.py index 042b2f0a..18337170 100644 --- 
a/torchtitan/vlr/smolvlm/model/model.py
+++ b/torchtitan/vlr/smolvlm/model/model.py
@@ -8,19 +8,23 @@
 import torch
 from torch import nn

-from torchtitan.models.attention import init_attention_mask
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.models.attention import ScaledDotProductAttentionWrapper
 from torchtitan.models.llama3 import Transformer as Llama3

 from .args import Llama3Siglip2ModelArgs, Siglip2ModelArgs
 from .siglip2 import VisionTransformer


 class SmolVLMSimpleMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         # TODO: scale_factor to config
         input_size = config.encoder.dim * (config.encoder.scale_factor**2)
         output_size = config.dim
         self.proj = nn.Linear(input_size, output_size, bias=False)

     def init_weights(self):
         nn.init.trunc_normal_(self.proj.weight, mean=0.0, std=0.02)
@@ -35,7 +39,7 @@ def __init__(self, config):
         self.scale_factor = config.encoder.scale_factor
         self.modality_projection = SmolVLMSimpleMLP(config)

-    def pixel_shuffle(self, x, scale_factor=2):
+    def pixel_shuffle(self, x, scale_factor=4):
         bsz, seq, embed_dim = x.size()
         height = width = int(seq**0.5)
         x = x.view(bsz, height, width, embed_dim)
@@ -107,9 +111,11 @@ def get_image_features(
         pixel_attention_mask
     ):
         batch_size, num_images, num_channels, height, width = pixel_values.shape
-        pixel_values = pixel_values.bfloat16() # fp16 compatibility
+        pixel_values = pixel_values.to(dtype=torch.bfloat16) # fp16 compatibility
         pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

+        patch_size = 16
+
         # Remove padding images - padding images are full 0.
         nb_values_per_image = pixel_values.shape[1:].numel()
         real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
@@ -119,7 +125,7 @@ def get_image_features(
         real_images_inds[0] = True

         pixel_values = pixel_values[real_images_inds].contiguous()
-
+        # Handle the vision attention mask
         if pixel_attention_mask is None:
             pixel_attention_mask = torch.ones(
                 size=[pixel_values.shape[i] for i in (0, 2, 3)],
@@ -131,13 +137,12 @@ def get_image_features(
             pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
             pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

-        patch_size = 16
         patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
         patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
         patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

         image_hidden_states = self.encoder(pixel_values, patch_attention_mask)
-        #image_hidden_states = image_hidden_states.last_hidden_state

         image_hidden_states = image_hidden_states.bfloat16()
         image_hidden_states = self.projector(image_hidden_states)
@@ -146,30 +151,33 @@ def get_image_features(
     def forward(
         self,
         input_ids: torch.Tensor,
-        eos_id: int | None = None,
-        input_batch: torch.Tensor | None = None,
         pixel_values: torch.Tensor | None = None,
         patch_attention_mask: torch.BoolTensor | None = None,
-        #grid_thw: torch.Tensor | None = None,
+        attention_masks: AttentionMasksType | None = None,
     ):
-        if self.model_args.use_flex_attn:
-            init_attention_mask(
-                input_batch if input_batch is not None else input_ids, eos_id=self.eos_id
-            )
-
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         hidden_states = self.tok_embeddings(input_ids) if self.tok_embeddings else input_ids

         if self.encoder is not None and pixel_values is not None:
             vision_tokens = self.get_image_features(pixel_values, patch_attention_mask)
             hidden_states = self._fuse_vision_text(hidden_states, vision_tokens, input_ids)

         for layer in self.layers.values():
-            hidden_states = layer(hidden_states, self.freqs_cis)
+            hidden_states = layer(hidden_states, self.freqs_cis, attention_masks=attention_masks)

         hidden_states = self.norm(hidden_states)
-        output = self.output(hidden_states)
-        return output
+        logits = self.output(hidden_states)
+        return logits


 if __name__ == "__main__":
@@ -201,6 +209,8 @@ def forward(
             n_heads=9,
             n_kv_heads=3,
             ffn_dim=1536,
+            use_flex_attn=False,
+            attn_mask_type="causal",
         ),
     }
diff --git a/torchtitan/vlr/smolvlm/model/siglip2.py b/torchtitan/vlr/smolvlm/model/siglip2.py
index e94ec5f0..d606ee82 100644
--- a/torchtitan/vlr/smolvlm/model/siglip2.py
+++ b/torchtitan/vlr/smolvlm/model/siglip2.py
@@ -9,7 +9,7 @@
 import torch.nn.functional as F
 from torch import nn

-from torchtitan.models.attention import build_attention, init_attention_mask
+from torchtitan.models.attention import ScaledDotProductAttentionWrapper

 from .args import Siglip2ModelArgs

@@ -63,9 +63,18 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
                 nb_patches_h = p_attn_mask[:, 0].sum()
                 nb_patches_w = p_attn_mask[0].sum()

+            step_h = 1.0 / nb_patches_h
+            step_w = 1.0 / nb_patches_w
+
             h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
             w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)

+            fractional_coords_h = h_indices * step_h
+            fractional_coords_w = w_indices * step_w
+
+            fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
+            fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))
-            fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
-            fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)

@@ -79,6 +88,30 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
         return embeddings


+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor,
+    scaling: float,
+    dropout: float = 0.0,
+):
+    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    # NOTE: the additive attention_mask is currently not applied:
+    # if attention_mask is not None:
+    #     attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output
+
+
 class Attention(nn.Module):
     """
     Multi-head attention module.
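With the mask branch disabled, `eager_attention_forward` above computes plain softmax(Q K^T * scale) V, which is what `F.scaled_dot_product_attention` produces when no mask is given. A small standalone sanity check of that equivalence (float32, hypothetical shapes, not part of the patch):

import torch
import torch.nn.functional as F

B, H, S, D = 2, 12, 64, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

scale = D ** -0.5
weights = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
eager_out = weights @ v  # (B, H, S, D); the function above additionally transposes to (B, S, H, D)

sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
assert torch.allclose(eager_out, sdpa_out, atol=1e-5)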
@@ -100,28 +133,31 @@ def __init__(self, args: Siglip2ModelArgs): super().__init__() self.dim = args.dim self.head_dim = args.dim // args.n_heads + self.num_heads = args.n_heads + + self.scale = self.head_dim**-.5 self.q_proj = nn.Linear(self.dim, self.dim) self.k_proj = nn.Linear(self.dim, self.dim) self.v_proj = nn.Linear(self.dim, self.dim) self.out_proj = nn.Linear(self.dim, self.dim) - self.attn = build_attention( - use_flex_attn=False, attn_mask_type=args.attn_mask_type - ) + self.attn = ScaledDotProductAttentionWrapper(is_causal=False) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): + batch_size, seq_length, embed_dim = hidden_states.shape - def forward(self, x: torch.Tensor): - xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - # Use self.head_dim instead of `n_heads` to infer the actual - # local heads from sizes of xq, xk, and xv as TP may have sharded them - # after the above linear ops. - xq = E.rearrange(xq, "b l (h d) -> b h l d", d=self.head_dim) - xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim) - xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - output = self.attn(xq, xk, xv) - output = E.rearrange(output, "b h l d -> b l (h d)").contiguous() + output = eager_attention_forward(self, queries, keys, values, attention_mask, self.scale) + + output = output.reshape(batch_size, seq_length, embed_dim).contiguous() return self.out_proj(output) @@ -160,10 +196,22 @@ def __init__(self, args: Siglip2ModelArgs): self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) self.mlp = FeedForward(args) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.self_attn(self.layer_norm1(x)) - x = x + self.mlp(self.layer_norm2(x)) - return x + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states def init_weights(self): self.layer_norm1.reset_parameters() @@ -194,7 +242,7 @@ def forward( h = self.embeddings(pixel_values, patch_attention_mask) for layer in self.layers.values(): - h = layer(h) + h = layer(h, patch_attention_mask) h = self.post_layernorm(h) return h diff --git a/torchtitan/vlr/smolvlm/model/state_dict_adapter.py b/torchtitan/vlr/smolvlm/model/state_dict_adapter.py index c994a869..32508847 100644 --- a/torchtitan/vlr/smolvlm/model/state_dict_adapter.py +++ b/torchtitan/vlr/smolvlm/model/state_dict_adapter.py @@ -26,19 +26,57 @@ def __init__( self.model_args = model_args self.hf_assets_path = hf_assets_path self.from_hf_map = { - "model.embed_tokens.weight": "tok_embeddings.weight", - "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", - "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", - 
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", - "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - "model.layers.{}.self_attn.rotary_emb.inv_freq": None, - "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", - "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", - "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", - "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", - "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", - "model.norm.weight": "norm.weight", "lm_head.weight": "output.weight", + + "model.text_model.embed_tokens.weight": "tok_embeddings.weight", # check + + "model.text_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", # check + "model.text_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", # check + "model.text_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", # check + "model.text_model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", # check + + #"model.layers.{}.self_attn.rotary_emb.inv_freq": None, + + "model.text_model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", # check + "model.text_model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", # check + "model.text_model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", # check + + "model.text_model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", # check + "model.text_model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", # check + + "model.text_model.norm.weight": "norm.weight", # check + + "model.vision_model.embeddings.patch_embedding.weight": "encoder.embeddings.patch_embedding.weight", + "model.vision_model.embeddings.patch_embedding.bias": "encoder.embeddings.patch_embedding.bias", + + "model.vision_model.embeddings.position_embedding.weight": "encoder.embeddings.position_embedding.weight", + + "model.vision_model.post_layernorm.weight": "encoder.post_layernorm.weight", + "model.vision_model.post_layernorm.bias": "encoder.post_layernorm.bias", + + "model.vision_model.encoder.layers.{}.layer_norm1.weight": "encoder.layers.{}.layer_norm1.weight", + "model.vision_model.encoder.layers.{}.layer_norm1.bias": "encoder.layers.{}.layer_norm1.bias", + "model.vision_model.encoder.layers.{}.layer_norm2.weight": "encoder.layers.{}.layer_norm2.weight", + "model.vision_model.encoder.layers.{}.layer_norm2.bias": "encoder.layers.{}.layer_norm2.bias", + + "model.vision_model.encoder.layers.{}.mlp.fc1.weight": "encoder.layers.{}.mlp.fc1.weight", + "model.vision_model.encoder.layers.{}.mlp.fc1.bias": "encoder.layers.{}.mlp.fc1.bias", + "model.vision_model.encoder.layers.{}.mlp.fc2.weight": "encoder.layers.{}.mlp.fc2.weight", + "model.vision_model.encoder.layers.{}.mlp.fc2.bias": "encoder.layers.{}.mlp.fc2.bias", + + "model.vision_model.encoder.layers.{}.self_attn.k_proj.weight": "encoder.layers.{}.self_attn.k_proj.weight", + "model.vision_model.encoder.layers.{}.self_attn.k_proj.bias": "encoder.layers.{}.self_attn.k_proj.bias", + + "model.vision_model.encoder.layers.{}.self_attn.out_proj.weight": "encoder.layers.{}.self_attn.out_proj.weight", + "model.vision_model.encoder.layers.{}.self_attn.out_proj.bias": "encoder.layers.{}.self_attn.out_proj.bias", + + "model.vision_model.encoder.layers.{}.self_attn.q_proj.weight": "encoder.layers.{}.self_attn.q_proj.weight", + 
"model.vision_model.encoder.layers.{}.self_attn.q_proj.bias": "encoder.layers.{}.self_attn.q_proj.bias", + + "model.vision_model.encoder.layers.{}.self_attn.v_proj.weight": "encoder.layers.{}.self_attn.v_proj.weight", + "model.vision_model.encoder.layers.{}.self_attn.v_proj.bias": "encoder.layers.{}.self_attn.v_proj.bias", + + "model.connector.modality_projection.proj.weight": "projector.modality_projection.proj.weight", } # HuggingFace permutation function (exact copy from their conversion script) diff --git a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml index b3afdd29..c4ceb4b1 100644 --- a/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml +++ b/torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml @@ -4,7 +4,7 @@ custom_args_module = "torchtitan.vlr.smolvlm.assets.job_config" [job] -dump_folder = "./outputs_large" +dump_folder = "/data/users/tockier/outputs/" description = "Llama 3 Siglip2 VLM training" print_args = false @@ -23,11 +23,11 @@ save_tb_folder = "tb" enable_wandb = true [model] -name = "llama3-siglip2" +name = "smolvlm" flavor = "256M" # test folder with tokenizer.json, for debug purpose only # hf_assets_path = "torchtitan/experiments/vlm/assets/tokenizer" -hf_assets_path = "./assets/hf/SmolLM2-360M-Instruct" +hf_assets_path = "./assets/hf/SmolVLM2-256M-Video-Instruct" # converters = ["float8"] [optimizer] @@ -42,8 +42,8 @@ decay_type = "cosine" min_lr_factor = 0.0 [training] -local_batch_size = 9 -seq_len = 2048 +local_batch_size = 2 +seq_len = 4096 # packing_buffer_size = 100 max_norm = 1.0 # grad norm clipping steps = 13100 @@ -72,10 +72,11 @@ enable = true folder = "checkpoint" interval = 50 last_save_model_only = false -#initial_load_in_hf = true -#last_save_in_hf = true -export_dtype = "bfloat16" +initial_load_in_hf = true +last_save_in_hf = true +export_dtype = "float32" async_mode = "async" # ["disabled", "async", "async_with_pinned_mem"] +exclude_from_loading = ["dataloader", "optimizer", "train_state"] [activation_checkpoint] mode = "selective" # ["none", "selective", "full"]