Changes from all commits
29 changes: 29 additions & 0 deletions Dockerfile.sft
@@ -0,0 +1,29 @@
FROM nvcr.io/nvidian/tegra-audio/speech-mlops-llm:sft_train

WORKDIR /workspace

COPY . /workspace/NeMo

RUN pip install retrying

RUN pip install multi-storage-client
#RUN pip install /workspace/NeMo/multi-storage-client/.

RUN pip install pytorch_lightning==2.4.0
RUN pip install "lightning>=2.0.0"
RUN pip install apex
RUN pip install --upgrade "megatron-core[all]"

RUN pip install "boto3>=1.36"

ENV CXXFLAGS="-std=c++17"

# RUN pip install aistore

#RUN pip install /workspace/nemo-with-storageclient/megatron-lm-msc/.

RUN pip install /workspace/NeMo/.

ENV MSC_CONFIG /lustrefs/users/anmaster/avm_msc_config.json

ENV MSC_PROFILE_NAME curr
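
For context on the two ENV variables above, here is a minimal sketch (not part of this PR) of how a training job might consume MSC_CONFIG and MSC_PROFILE_NAME; the checkpoint URL is hypothetical, and only multistorageclient calls that appear elsewhere in this diff are used:

import os
import multistorageclient as msc

# Assumption based on the ENV lines above: MSC_CONFIG points multistorageclient at its
# profile definitions, and MSC_PROFILE_NAME selects the profile that msc:// URLs refer to.
profile = os.environ.get("MSC_PROFILE_NAME", "curr")
checkpoint_url = f"msc://{profile}/checkpoints/megatron_gpt.nemo"  # hypothetical object-store path

if msc.os.path.exists(checkpoint_url):
    msc.download_file(checkpoint_url, "/workspace/megatron_gpt.nemo")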
177 changes: 133 additions & 44 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -64,6 +64,11 @@
# since PyTorch 2.3 the path has changed
from torch.amp.grad_scaler import _refresh_per_optimizer_state

from concurrent.futures import ThreadPoolExecutor

import multistorageclient as msc
from multistorageclient.types import MSC_PROTOCOL

from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.transformer import AutocastTransformerLayer, ParallelTransformerLayer
from nemo.collections.nlp.parts import utils_funcs
@@ -203,6 +208,7 @@ class NLPDDPStrategy(DDPStrategy):
with FP32 gradient accumulation.
nccl_communicator_config_path: Path to the yaml file with NCCL communicator options
sharp: Apply SHARP to NCCL data-parallel communication.
multistorageclient_enabled: Whether to use multistorageclient for checkpointing
"""

def __init__(
@@ -1031,6 +1037,33 @@ def restore_checkpoint_after_setup(self) -> bool:
return True


def msc_download_dir(url: str, local_path: str):
logging.warning(f"Running msc_download_dir url {url} rank {torch.distributed.get_rank()}")

if not msc.os.path.exists(url):
raise Exception(f"Download Path doesn't exist: {url}")

base_name = os.path.basename(url) # url = "msc://my-profile/path/to/data", base_name = "data"
files = msc.list(url)

def download_file(item):
"""Helper function to download a single file."""
file_name = item.key # item.key = "msc://profile/path/to/data/file1.txt"
base_name_idx = file_name.find(base_name) # base_name_idx = 23
local_file_path = (
f"{local_path}/{file_name[base_name_idx:]}" # local_file_path = f"{local_path}/data/file1.txt"
)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
msc.download_file(item, local_file_path)
# msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path)

# Use ThreadPoolExecutor for parallel downloads
with ThreadPoolExecutor(max_workers=32) as executor: # Adjust max_workers as needed
executor.map(download_file, files)

logging.warning(f"msc_download_dir completed rank {torch.distributed.get_rank()}")
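
A brief usage sketch of the helper above (illustration only, not part of the diff): it mirrors how save_to calls it later in this file, downloading every object under an msc:// prefix into a local temporary directory; the msc:// path is hypothetical.

import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    # everything under the prefix lands in tmpdir/<base_name>/...
    msc_download_dir("msc://curr/checkpoints/model_weights", tmpdir)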


class NLPSaveRestoreConnector(SaveRestoreConnector):
"""Custom connector to support saving and restoring states."""

@@ -1067,11 +1100,19 @@ def save_to(self, model, save_path: str):
if (app_state.model_parallel_size is not None and app_state.model_parallel_size > 1) or dist_ckpt:

dir_name = os.path.dirname(save_path)
is_msc_enabled = False
if MSC_PROTOCOL in dir_name:
is_msc_enabled = True

# dist ckpt calls save on every rank
if dist_ckpt:
# model weights is a directory
dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))

if is_msc_enabled:
filename = os.path.join(dir_name, self.model_weights_ckpt)
dist_ckpt_dir = os.path.splitext(filename)[0]

# dist checkpoint needs torch.distributed to save the checkpoint
if not parallel_state.is_initialized():

@@ -1127,19 +1168,33 @@ def dummy():

if should_move_data:
with tempfile.TemporaryDirectory() as tmpdir:

if dist_ckpt:
shutil.move(str(dist_ckpt_dir), tmpdir)
if is_msc_enabled:
msc_download_dir(dist_ckpt_dir, tmpdir)
else:
shutil.move(str(dist_ckpt_dir), tmpdir)

elif app_state.pipeline_model_parallel_size == 1:
# move weights to the tmpdir
for tp_rank in range(app_state.tensor_model_parallel_size):
os.makedirs(os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}'))
mp_model_weights = os.path.join(
dir_name, f'mp_rank_{tp_rank:02d}_' + self.model_weights_ckpt
)
shutil.move(
mp_model_weights,
os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt),
)

if is_msc_enabled:
print(f"Downloading {mp_model_weights} to {tmpdir}")
msc_dest = os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt)
logging.warning(
f"msc_download_dir mp_model_weights from {mp_model_weights} to {msc_dest} rank {torch.distributed.get_rank()}"
)
msc_download_dir(mp_model_weights, msc_dest)
else:
shutil.move(
mp_model_weights,
os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt),
)
else:
# move weights to the tmpdir
for tp_rank, pp_rank in itertools.product(
@@ -1150,12 +1205,23 @@ def dummy():
mp_model_weights = os.path.join(
dir_name, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}_' + self.model_weights_ckpt
)
shutil.move(
mp_model_weights,
os.path.join(
tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt
),
)

if is_msc_enabled:
print(f"Downloading {mp_model_weights} to {tmpdir}")
msc_dest = os.path.join(
tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt
)
logging.warning(
f"msc_download_dir mp_model_weights from {mp_model_weights} to {msc_dest} rank {torch.distributed.get_rank()}"
)
msc_download_dir(mp_model_weights, msc_dest)
else:
shutil.move(
mp_model_weights,
os.path.join(
tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt
),
)

# create config and artifacts in tmpdir
config_yaml = os.path.join(tmpdir, self.model_config_yaml)
@@ -1300,15 +1366,32 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None):
uninject_model_weights = uninject_model_parallel_rank(model_weights)

# legacy model_weights will have mp rank injected
if os.path.isfile(model_weights):
if msc.os.path.isfile(model_weights):
return super()._load_state_dict_from_disk(model_weights, map_location)

# dist checkpoint will be a dir
elif os.path.isdir(os.path.splitext(uninject_model_weights)[0]):
elif msc.os.path.isdir(os.path.splitext(uninject_model_weights)[0]):
return None
else:
raise ValueError(f'Expected {model_weights} to be a file or directory.')

def _download_nemo_file(self, restore_path: str, tmpdir: str) -> str:
# .nemo filename
fname = os.path.basename(restore_path)

# check if msc path exists
if not msc.os.path.exists(restore_path):
raise FileNotFoundError(f".nemo file doesn't exist at {restore_path}")

# download .nemo file to tempdir
os.makedirs(tmpdir, exist_ok=True)
logging.warning(f"Starting .nemo download {restore_path}")
msc.download_file(restore_path, f"{tmpdir}/{fname}")

# update restore_path to point to downloaded .nemo
updated_restore_path = os.path.join(tmpdir, fname)
logging.warning(f".nemo download complete; updated_restore_path to {updated_restore_path}")
return updated_restore_path

def restore_from(
self,
calling_cls,
@@ -1343,36 +1426,41 @@ def restore_from(
Returns:
An instance of type cls or its underlying config (if return_config is set).
"""
# tempdir creation is moved here so that the updated restore_path can be passed to super().load_config_and_state_dict;
# when the .nemo file lives in the object store, it first has to be downloaded locally
with tempfile.TemporaryDirectory() as tmpdir:
if MSC_PROTOCOL in restore_path:
restore_path = self._download_nemo_file(restore_path=restore_path, tmpdir=tmpdir)

# Get path where the command is executed - the artifacts will be "retrieved" there
# (original .nemo behavior)
loaded_params = super().load_config_and_state_dict(
calling_cls,
restore_path,
override_config_path,
map_location,
strict,
return_config,
trainer,
validate_access_integrity,
)
if not isinstance(loaded_params, tuple) or return_config is True:
return loaded_params
conf, instance, state_dict = loaded_params

# Get path where the command is executed - the artifacts will be "retrieved" there
# (original .nemo behavior)
loaded_params = super().load_config_and_state_dict(
calling_cls,
restore_path,
override_config_path,
map_location,
strict,
return_config,
trainer,
validate_access_integrity,
)
if not isinstance(loaded_params, tuple) or return_config is True:
return loaded_params
conf, instance, state_dict = loaded_params

# if we're using dist checkpointing then state_dict will be None
if state_dict is None:
# dist checkpointing needs torch.distributed to load the checkpoint
if not parallel_state.is_initialized():
# if we're using dist checkpointing then state_dict will be None
if state_dict is None:
# dist checkpointing needs torch.distributed to load the checkpoint
if not parallel_state.is_initialized():

def dummy():
return
def dummy():
return

if trainer.strategy.launcher is not None:
trainer.strategy.launcher.launch(dummy, trainer=trainer)
trainer.strategy.setup_environment()
if trainer.strategy.launcher is not None:
trainer.strategy.launcher.launch(dummy, trainer=trainer)
trainer.strategy.setup_environment()

with tempfile.TemporaryDirectory() as tmpdir:
# with tempfile.TemporaryDirectory() as tmpdir:
# Check if self.model_extracted_dir is set, and is a valid path
if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
# Log that NeMo will use the provided `model_extracted_dir`
@@ -1423,11 +1511,12 @@ def dummy():
if hasattr(instance, 'setup_transformer_engine_tp_groups'):
instance.setup_transformer_engine_tp_groups()

else:
state_dict = self.modify_state_dict(conf, state_dict)
super().load_instance_with_state_dict(instance, state_dict, strict)
logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.')
return instance
else:
state_dict = self.modify_state_dict(conf, state_dict)
super().load_instance_with_state_dict(instance, state_dict, strict)

logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.')
return instance


class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin):
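In short, when restore_path (or the checkpoint directory) is an msc:// URL, the connector first mirrors the remote data into a local temporary directory and then falls back to the unchanged restore logic. A minimal sketch of that download-then-restore pattern, using only the calls introduced above (the msc:// URL and local names are hypothetical):

import tempfile
import multistorageclient as msc
from multistorageclient.types import MSC_PROTOCOL

restore_path = "msc://curr/experiments/run1/megatron_gpt.nemo"  # hypothetical

with tempfile.TemporaryDirectory() as tmpdir:
    if MSC_PROTOCOL in restore_path:
        # mirrors _download_nemo_file: pull the .nemo archive out of the object
        # store before the usual unpack/restore flow runs
        if not msc.os.path.exists(restore_path):
            raise FileNotFoundError(f".nemo file doesn't exist at {restore_path}")
        local_nemo = f"{tmpdir}/megatron_gpt.nemo"
        msc.download_file(restore_path, local_nemo)
        restore_path = local_nemo
    # ... continue with load_config_and_state_dict(...) as in restore_from ...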
5 changes: 3 additions & 2 deletions nemo/core/connectors/save_restore_connector.py
@@ -22,6 +22,7 @@
import uuid
from contextlib import contextmanager
from typing import Callable, Generator, Optional, Set, Union

import torch
from lightning.pytorch.trainer.trainer import Trainer
from omegaconf import DictConfig, OmegaConf
@@ -743,11 +744,11 @@ def _unpack_nemo_file_with_multistorageclient(

@staticmethod
def _save_state_dict_to_disk(state_dict, filepath):
torch.save(state_dict, filepath)
multistorageclient.torch.save(state_dict, filepath)

@staticmethod
def _load_state_dict_from_disk(model_weights, map_location=None):
return torch.load(model_weights, map_location='cpu', weights_only=False)
return multistorageclient.torch.load(model_weights, map_location='cpu', weights_only=False)

@property
def model_config_yaml(self) -> str:
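These two helpers now route state-dict I/O through multistorageclient.torch, which, as used here, takes the same arguments as torch.save/torch.load and appears to accept either local paths or msc:// URLs. A minimal round-trip sketch under that assumption (the msc:// path is hypothetical):

import torch
import multistorageclient

state_dict = {"layer.weight": torch.zeros(2, 2)}
ckpt_path = "msc://curr/checkpoints/model_weights.ckpt"  # hypothetical object-store path

# save/load mirror the torch APIs, per the calls used in this connector
multistorageclient.torch.save(state_dict, ckpt_path)
restored = multistorageclient.torch.load(ckpt_path, map_location="cpu", weights_only=False)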
7 changes: 5 additions & 2 deletions nemo/utils/callbacks/dist_ckpt_io.py
@@ -25,6 +25,7 @@
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import Callback
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from multistorageclient.types import MSC_PROTOCOL

from nemo.utils import logging

@@ -276,8 +277,10 @@ def save_checkpoint(
path (_PATH): checkpoint directory
storage_options (Any, optional): Optional parameters when saving the checkpoint
"""
fs = get_filesystem(path)
fs.makedirs(path, exist_ok=True)

if MSC_PROTOCOL not in path:
fs = get_filesystem(path)
fs.makedirs(path, exist_ok=True)

validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
self.validated_consistency = True
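The added guard only creates a local directory when the checkpoint path is not an msc:// URL, since an object-store prefix has no local directory to make. A simplified sketch of that check (the callback itself goes through Lightning's get_filesystem; plain os.makedirs is substituted here for brevity):

import os
from multistorageclient.types import MSC_PROTOCOL

def ensure_checkpoint_dir(path: str) -> None:
    # only local filesystem paths need a directory created up front;
    # msc:// destinations are handled by the storage client at write time
    if MSC_PROTOCOL not in path:
        os.makedirs(path, exist_ok=True)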
24 changes: 19 additions & 5 deletions nemo/utils/callbacks/nemo_model_checkpoint.py
@@ -228,6 +228,11 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
maybe_injected_best_model_path = self.best_model_path

if self.save_best_model:

if self.multistorageclient_enabled:
if not multistorageclient.os.path.exists(maybe_injected_best_model_path):
return

if not os.path.exists(maybe_injected_best_model_path):
return

@@ -237,7 +242,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):

self.previous_best_path = self.best_model_path
old_state_dict = deepcopy(pl_module.state_dict())
checkpoint = torch.load(maybe_injected_best_model_path, map_location='cpu', weights_only=False)
checkpoint = multistorageclient.torch.load(
maybe_injected_best_model_path, map_location='cpu', weights_only=False
)
if 'state_dict' in checkpoint:
checkpoint = checkpoint['state_dict']
# get a new instance of the model
@@ -289,8 +296,16 @@ def on_train_end(self, trainer, pl_module):
"were found. Saving latest model instead."
)
else:
if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
self.best_model_path = self.best_model_path.split('.ckpt')[0]
if self.multistorageclient_enabled:
if multistorageclient.os.path.exists(self.best_model_path) and multistorageclient.os.path.isdir(
self.best_model_path
):
self.best_model_path = self.best_model_path.split('.ckpt')[0]

else:
if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
self.best_model_path = self.best_model_path.split('.ckpt')[0]

self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
trainer._checkpoint_connector.restore(self.best_model_path)

@@ -529,7 +544,7 @@ def file_exists(
) -> bool:
"""Checks if a file or a file without a suffix (distributed checkpoint) exists."""
if self.multistorageclient_enabled:
exists = self._fs.exists(filepath)
exists = self._fs.exists(filepath) # todo(avm): unsure if we need this check
else:
exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath)))

@@ -651,7 +666,6 @@ def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
def _saved_checkpoint_paths(self) -> Iterable[Path]:
# distributed checkpoints are directories so we check for them here
# we filter out unfinished checkpoints, these should be deleted during next cleanup

if self.multistorageclient_enabled:
# TODO: support multistorageclient distributed checkpointing
return NeMoModelCheckpoint._derive_saved_checkpoint_paths_with_multistorageclient(self.dirpath)
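The callback changes above follow one pattern: when multistorageclient_enabled is set, path existence and distributed-checkpoint (directory) checks go through multistorageclient.os.path instead of os.path. A standalone sketch of that branching (not the callback's actual method; paths are hypothetical):

import os
import multistorageclient

def checkpoint_exists(filepath: str, multistorageclient_enabled: bool) -> bool:
    if multistorageclient_enabled:
        return multistorageclient.os.path.exists(filepath)
    # a distributed checkpoint is a directory named like the .ckpt file without its suffix
    return os.path.exists(filepath) or os.path.isdir(filepath.split('.ckpt')[0])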