2 changes: 1 addition & 1 deletion docs/extension.md
@@ -14,7 +14,7 @@ The extension points and protocols mentioned in this note are subject to change.
The coarse level abstraction tries to hit a balance between flexible component swapping and a straightforward train script ([train.py](../torchtitan/train.py)).
Note that among all training components, currently [`CheckpointManager`](../torchtitan/components/checkpoint.py) and [`FTManager`](../torchtitan/components/ft.py) are not configurable since we do not expect them to be customized, but we are open to requests.

To register a `TrainSpec`, please follow the example of [Llama 3.1](../torchtitan/models/llama3/__init__.py) to `register_train_spec`. Please make sure the registration code is called before training initialization. In torchtitan, it is performed during [module import](../torchtitan/__init__.py).
To register a `TrainSpec`, please use the `register_train_spec` API, and make sure registration happens before `get_train_spec` is called during training initialization. In torchtitan, `get_train_spec` will dynamically look for models in `torchtitan/models` or `torchtitan/experiments`.
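For illustration, a minimal registration could look like the sketch below. The name-first `register_train_spec` signature mirrors the updated unit test later in this diff; the import path, `MyModel`, and the builder functions are hypothetical placeholders and may differ from the actual torchtitan APIs.

```python
# Minimal sketch (not from this PR): register a custom TrainSpec under a name,
# then fetch it back. MyModel, MyModelArgs, and the builder functions are
# hypothetical placeholders; the import path may differ across versions.
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec, get_train_spec

spec = TrainSpec(
    model_cls=MyModel,
    model_args={"base": MyModelArgs()},
    parallelize_fn=parallelize_my_model,
    # ... remaining builder fields (optimizers, lr schedulers, dataloader) omitted
    build_tokenizer_fn=build_hf_tokenizer,
    build_loss_fn=build_cross_entropy_loss,
)

register_train_spec("my_model", spec)      # must run before training initialization
assert get_train_spec("my_model") is spec  # looked up by name later
```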


### `ModelConverter`
24 changes: 24 additions & 0 deletions run_generate.sh
@@ -0,0 +1,24 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_generate.sh
NGPU=${NGPU:-"1"}
export LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/vlr/smolvlm/train_configs/llama_siglip_256.toml"}
INFERENCE_FILE=${INFERENCE_FILE:-"torchtitan.generate"}


NCCL_P2P_DISABLE=1 \
TORCH_NCCL_DUMP_ON_TIMEOUT=1 \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-m ${INFERENCE_FILE} --job.config_file ${CONFIG_FILE} "$@" \
    --checkpoint.exclude-from-loading dataloader,optimizer,lr_scheduler
23 changes: 23 additions & 0 deletions run_generate_llama3.sh
@@ -0,0 +1,23 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_generate_llama3.sh
NGPU=${NGPU:-"1"}
export LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.toml"}
INFERENCE_FILE=${INFERENCE_FILE:-"torchtitan.generate_llama3"}


NCCL_P2P_DISABLE=1 \
TORCH_NCCL_DUMP_ON_TIMEOUT=1 \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-m ${INFERENCE_FILE} --job.config_file ${CONFIG_FILE} "$@"
2 changes: 1 addition & 1 deletion scripts/estimate/estimation.py
@@ -95,7 +95,7 @@ def estimate_memory(job_config: JobConfig):
else contextlib.nullcontext()
):
logger.info(
f"Building {train_spec.name} {job_config.model.flavor} with {model_args}"
f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}"
)
with torch.device("meta"):
model = train_spec.model_cls(model_args)
28 changes: 23 additions & 5 deletions scripts/generate/test_generate.py
@@ -32,6 +32,9 @@
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.utils import device_module, device_type

from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.tokenizer import HuggingFaceTokenizer

# support running w/o installing as package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
@@ -143,11 +146,26 @@ def test_generate(

state_dict = model.state_dict()

# Checkpoint Loading
begin = time.monotonic()
logger.info(f"Loading chkpt at: {checkpoint_path}")
dcp.load(state_dict, checkpoint_id=checkpoint_path)
logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.")
# Setup checkpoint manager for loading
checkpointer = CheckpointManager(
dataloader=None, # No dataloader needed for generation
model_parts=[model],
optimizers=None, # No optimizer needed for generation
lr_schedulers=None, # No lr_scheduler needed for generation
states={},
checkpoint_config=config.checkpoint,
sd_adapter=(
train_spec.state_dict_adapter(
model_args, config.model.hf_assets_path
)
),
base_folder=config.job.dump_folder,
ft_manager=None, # No fault tolerance for generation
)

# Load checkpoint
checkpointer.load(step=config.checkpoint.load_step)
logger.info(f"Loaded checkpoint from step {config.checkpoint.load_step}")

device_mem_stats = device_memory_monitor.get_peak_stats()
logger.info(
6 changes: 2 additions & 4 deletions tests/unit_tests/test_train_spec.py
@@ -76,7 +76,6 @@ class TestTrainSpec:
def test_register_train_spec(self):
fake_config = {"fake": BaseModelArgs()}
spec = TrainSpec(
name="fake",
model_cls=FakeModel,
model_args=fake_config,
parallelize_fn=parallelize_llama,
@@ -87,7 +86,7 @@ def test_register_train_spec(self):
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
register_train_spec(spec)
register_train_spec("fake", spec)
new_spec = get_train_spec("fake")
assert new_spec == spec

@@ -98,7 +97,6 @@ def test_optim_hook(self):
fake_config = {"fake": BaseModelArgs()}

spec = TrainSpec(
name="fake2",
model_cls=FakeModel,
model_args=fake_config,
parallelize_fn=parallelize_llama,
@@ -109,7 +107,7 @@ def test_optim_hook(self):
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
register_train_spec(spec)
register_train_spec("fake2", spec)
new_spec = get_train_spec("fake2")

model = new_spec.model_cls(BaseModelArgs())
26 changes: 20 additions & 6 deletions torchtitan/components/checkpoint.py
@@ -19,10 +19,7 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import (
HuggingFaceStorageReader,
HuggingFaceStorageWriter,
)
from torch.distributed.checkpoint import HuggingFaceStorageWriter
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
consolidate_safetensors_files_on_every_rank,
)
@@ -249,6 +246,9 @@ def load_state_dict(state_dict):
self.initial_load_model_only = checkpoint_config.initial_load_model_only
self.initial_load_in_hf = checkpoint_config.initial_load_in_hf
self.initial_load_path = checkpoint_config.initial_load_path
self.initial_load_in_hf_quantized = (
checkpoint_config.initial_load_in_hf_quantized
)
self.last_save_model_only = checkpoint_config.last_save_model_only
self.last_save_in_hf = checkpoint_config.last_save_in_hf
if self.last_save_in_hf:
@@ -339,7 +339,7 @@ def dcp_save(
checkpoint_id (str): The checkpoint id to save.
async_mode (AsyncMode): Whether the checkpoint is async.
enable_garbage_collection (bool): Whether to enable garbage collection after save.
to_hf (bool): Whether to save in HF model definition and safetensors format.
to_hf (bool): Whether to save in HF model definition and safetensors format.

Returns:
Future: The future object if the checkpoint is async, otherwise None.
@@ -418,6 +418,7 @@ def dcp_load(
state_dict: dict[str, Any],
checkpoint_id: str,
from_hf: bool,
from_quantized: bool,
) -> None:
"""Load the checkpoint with dcp.
Args:
@@ -432,10 +433,13 @@
self.sd_adapter is not None
), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(state_dict)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
checkpoint_id, from_quantized
)

dcp.load(
hf_state_dict,
storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
storage_reader=hf_storage_reader,
)

state_dict = self.sd_adapter.from_hf(hf_state_dict)
@@ -544,13 +548,21 @@ def load(self, step: int = -1) -> bool:

model_only = False
from_hf = False
from_quantized = False
if not os.path.exists(self.folder):
model_only = self.initial_load_model_only
from_hf = self.initial_load_in_hf
from_quantized = self.initial_load_in_hf_quantized
if from_hf:
assert (
model_only
), "Only model can be loaded when loading from HF's safetensors checkpoint."

if from_quantized:
assert (
from_hf
), "Quantized checkpoint can only be loaded from HuggingFace format."

if self.initial_load_path:
checkpoint_id = self.initial_load_path
if not os.path.isdir(checkpoint_id):
@@ -602,6 +614,7 @@ def load(self, step: int = -1) -> bool:
states,
checkpoint_id=checkpoint_id,
from_hf=from_hf,
from_quantized=from_quantized,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
@@ -679,6 +692,7 @@ def _ft_load(self) -> None:
checkpoint_id=checkpoint_id,
# FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
from_hf=False,
from_quantized=False,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
3 changes: 2 additions & 1 deletion torchtitan/components/loss.py
@@ -23,7 +23,8 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor
)


def build_cross_entropy_loss(job_config: JobConfig):
def build_cross_entropy_loss(job_config: JobConfig, **kwargs):
del kwargs  # discard any unused keyword arguments
loss_fn = cross_entropy_loss
if job_config.compile.enable and "loss" in job_config.compile.components:
logger.info("Compiling the loss function with torch.compile")
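The `**kwargs` added to `build_cross_entropy_loss` above keeps every `build_loss_fn` implementation callable with a single convention; this builder simply discards the extras. A hypothetical call site (the `tokenizer` argument is only an example of an ignored extra):

```python
# Hypothetical usage: extra keyword arguments are accepted and dropped by
# build_cross_entropy_loss, so callers can pass a uniform set of kwargs.
loss_fn = build_cross_entropy_loss(job_config, tokenizer=tokenizer)
loss = loss_fn(pred, labels)
```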
53 changes: 51 additions & 2 deletions torchtitan/config/job_config.py
@@ -34,6 +34,20 @@ class Profiling:
profile_freq: int = 10
"""How often to collect profile traces, in iterations"""

profiler_active: int = 1
"""
The number of steps the profiler is active for in each profiling cycle.

This is used to configure torch.profiler.schedule.
"""

profiler_warmup: int = 3
"""
The number of warmup steps before the active steps in each profiling cycle.

This is used to configure torch.profiler.schedule.
"""

enable_memory_snapshot: bool = False
"""Whether to dump memory snapshot"""

@@ -289,9 +303,11 @@ class Parallelism:
within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
trading off memory and communication. See torch's `fully_shard` API for more documentation
on `reshard_after_forward`.

The supported policies include "default", "always" and "never":

- "default" applies default resharding behavior, implementing "smart defaults" for known optimal
scenarios.
scenarios.
- "always" will enable `reshard_after_forward` for all forward passes.
- "never" will disable `reshard_after_forward` for all forward passes.
"""
@@ -399,15 +415,21 @@ class Parallelism:
expert_parallel_degree: int = 1
"""
Expert parallelism degree. 1 means disabled. No effect for non-MoE models.

Currently, it is supported with the following constraints:

- when etp = tp:

- cp <= ep <= dp_shard * cp
- ep % cp == 0
- dp_shard * cp % ep == 0

- when etp = 1:

- cp * tp <= ep <= dp_shard * cp * tp
- ep % (cp * tp) == 0
- dp_shard * cp * tp % ep == 0

Note that this is still an experimental feature. Some constraints will be
relaxed soon when we have more flexible DeviceMesh support.
"""
@@ -473,6 +495,14 @@ class Checkpoint:
non-tensors. The default value is False.
"""

initial_load_in_hf_quantized: bool = False
"""
Enable loading of HuggingFace's safetensors format with quantized state dict keys. This option
is only used when `initial_load_path` and `initial_load_in_hf` are specified. It loads
checkpoints in HF's model definition and dequantizes model weights where necessary. To support
this option, the model needs to provide a proper HuggingFaceStorageReader that performs the dequantization.
"""

last_save_model_only: bool = True
"""
When last_save_model_only=True, only the model will be saved at the end of training,
@@ -501,6 +531,7 @@ class Checkpoint:
async_mode: Literal["disabled", "async", "async_with_pinned_mem"] = "disabled"
"""
Which async checkpoint mode to use. Currently there are 3 different modes.

- "disabled": synchronized checkpointing will be used.
- "async": torch.distributed.checkpoint.async_save will be used.
- "async_with_pinned_mem": this option utilizes a dedicated pinned memory space and creates a
@@ -558,7 +589,7 @@ class Checkpoint:

@dataclass
class ActivationCheckpoint:
mode: Literal["selective", "full", "none"] = "selective"
mode: Literal["selective", "full", "memory_budget", "none"] = "selective"
"""Type of activation checkpointing to use"""

selective_ac_option: str = "2"
@@ -587,6 +618,24 @@ class ActivationCheckpoint:
rematerialized.
"""

memory_budget: float = 0.5
"""
When mode is set to "memory_budget", this value determines how much the
compiler's partitioner should trade off compute for memory.
0.0 corresponds to the activation memory from applying
activation checkpointing to the full compiled region, and 1.0 corresponds to
the activation memory from the default runtime-optimized strategy. Read here:
https://pytorch.org/blog/activation-checkpointing-techniques/
"""

visualize_memory_budget_pareto: bool = False
"""
This dumps out an SVG visualization of the expected runtime vs. activation
memory tradeoffs for all memory budget values from 0 to 1 in increments of
0.05, in the {--job.dump_folder}/memory_budget_pareto folder. See an example here:
https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
"""


@dataclass
class Compile: