diff --git a/Dockerfile.sft b/Dockerfile.sft
new file mode 100644
index 000000000000..b148f06c9ad5
--- /dev/null
+++ b/Dockerfile.sft
@@ -0,0 +1,29 @@
+FROM nvcr.io/nvidian/tegra-audio/speech-mlops-llm:sft_train
+
+WORKDIR /workspace
+
+COPY . /workspace/NeMo
+
+RUN pip install retrying
+
+RUN pip install multi-storage-client
+#RUN pip install /workspace/NeMo/multi-storage-client/.
+
+RUN pip install pytorch_lightning==2.4.0
+RUN pip install "lightning>=2.0.0"
+RUN pip install apex
+RUN pip install --upgrade megatron-core[all]
+
+RUN pip install "boto3>=1.36"
+
+ENV CXXFLAGS="-std=c++17"
+
+# RUN pip install aistore
+
+#RUN pip install /workspace/nemo-with-storageclient/megatron-lm-msc/.
+
+RUN pip install /workspace/NeMo/.
+
+ENV MSC_CONFIG /lustrefs/users/anmaster/avm_msc_config.json
+
+ENV MSC_PROFILE_NAME curr
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 4d265b750ff6..7ec88d6b63cc 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -64,6 +64,11 @@
 # since PyTorch 2.3 the path has changed
 from torch.amp.grad_scaler import _refresh_per_optimizer_state
 
+from concurrent.futures import ThreadPoolExecutor
+
+import multistorageclient as msc
+from multistorageclient.types import MSC_PROTOCOL
+
 from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.modules.common.megatron.transformer import AutocastTransformerLayer, ParallelTransformerLayer
 from nemo.collections.nlp.parts import utils_funcs
@@ -203,6 +208,7 @@ class NLPDDPStrategy(DDPStrategy):
             with FP32 gradient accumulation.
         nccl_communicator_config_path: Path to the yaml file with NCCL communicator options
         sharp: Apply SHARP to NCCL data-parallel communication.
+        multistorageclient_enabled: Whether to use multistorageclient for checkpointing
     """
 
     def __init__(
@@ -1031,6 +1037,33 @@ def restore_checkpoint_after_setup(self) -> bool:
         return True
 
 
+def msc_download_dir(url: str, local_path: str):
+    logging.warning(f"Running msc_download_dir url {url} rank {torch.distributed.get_rank()}")
+
+    if not msc.os.path.exists(url):
+        raise Exception(f"Download Path doesn't exist: {url}")
+
+    base_name = os.path.basename(url)  # url = "msc://my-profile/path/to/data", base_name = "data"
+    files = msc.list(url)
+
+    def download_file(item):
+        """Helper function to download a single file."""
+        file_name = item.key  # item.key = "msc://profile/path/to/data/file1.txt"
+        base_name_idx = file_name.find(base_name)  # base_name_idx = 23
+        local_file_path = (
+            f"{local_path}/{file_name[base_name_idx:]}"  # local_file_path = f"{local_path}/data/file1.txt"
+        )
+        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+        msc.download_file(item, local_file_path)
+        # msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path)
+
+    # Use ThreadPoolExecutor for parallel downloads
+    with ThreadPoolExecutor(max_workers=32) as executor:  # Adjust max_workers as needed
+        executor.map(download_file, files)
+
+    logging.warning(f"msc_download_dir completed rank {torch.distributed.get_rank()}")
+
+
 class NLPSaveRestoreConnector(SaveRestoreConnector):
     """Custom connector to support saving and restoring states."""
 
@@ -1067,11 +1100,19 @@ def save_to(self, model, save_path: str):
         if (app_state.model_parallel_size is not None and app_state.model_parallel_size > 1) or dist_ckpt:
 
             dir_name = os.path.dirname(save_path)
 
+            is_msc_enabled = False
+            if MSC_PROTOCOL in dir_name:
+                is_msc_enabled = True
+
             # dist ckpt calls save on every rank
             if dist_ckpt:
                 # model weights is a directory
                 dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
+
+                if is_msc_enabled:
+                    filename = os.path.join(dir_name, self.model_weights_ckpt)
+                    dist_ckpt_dir = os.path.splitext(filename)[0]
+
                 # dist checkpoint needs torch.distributed to save the checkpoint
                 if not parallel_state.is_initialized():
@@ -1127,8 +1168,13 @@ def dummy():
 
             if should_move_data:
                 with tempfile.TemporaryDirectory() as tmpdir:
+
                     if dist_ckpt:
-                        shutil.move(str(dist_ckpt_dir), tmpdir)
+                        if is_msc_enabled:
+                            msc_download_dir(dist_ckpt_dir, tmpdir)
+                        else:
+                            shutil.move(str(dist_ckpt_dir), tmpdir)
+
                     elif app_state.pipeline_model_parallel_size == 1:
                         # move weights to the tmpdir
                         for tp_rank in range(app_state.tensor_model_parallel_size):
@@ -1136,10 +1182,19 @@ def dummy():
                             mp_model_weights = os.path.join(
                                 dir_name, f'mp_rank_{tp_rank:02d}_' + self.model_weights_ckpt
                             )
-                            shutil.move(
-                                mp_model_weights,
-                                os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt),
-                            )
+
+                            if is_msc_enabled:
+                                print(f"Downloading {mp_model_weights} to {tmpdir}")
+                                msc_dest = os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt)
+                                logging.warning(
+                                    f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}"
+                                )
+                                msc_download_dir(mp_model_weights, msc_dest)
+                            else:
+                                shutil.move(
+                                    mp_model_weights,
+                                    os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt),
+                                )
                     else:
                         # move weights to the tmpdir
                         for tp_rank, pp_rank in itertools.product(
@@ -1150,12 +1205,23 @@ def dummy():
                             mp_model_weights = os.path.join(
                                 dir_name, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}_' + self.model_weights_ckpt
                             )
-                            shutil.move(
-                                mp_model_weights,
-                                os.path.join(
+
+                            if is_msc_enabled:
+                                print(f"Downloading {mp_model_weights} to {tmpdir}")
+                                msc_dest = os.path.join(
                                     tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt
-                                ),
-                            )
+                                )
+                                logging.warning(
+                                    f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}"
+                                )
+                                msc_download_dir(mp_model_weights, msc_dest)
+                            else:
+                                shutil.move(
+                                    mp_model_weights,
+                                    os.path.join(
+                                        tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt
+                                    ),
+                                )
 
             # create config and artifacts in tmpdir
             config_yaml = os.path.join(tmpdir, self.model_config_yaml)
@@ -1300,15 +1366,32 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None):
         uninject_model_weights = uninject_model_parallel_rank(model_weights)
 
         # legacy model_weights will have mp rank injected
-        if os.path.isfile(model_weights):
+        if msc.os.path.isfile(model_weights):
             return super()._load_state_dict_from_disk(model_weights, map_location)
-        # dist checkpoint will be a dir
-        elif os.path.isdir(os.path.splitext(uninject_model_weights)[0]):
+        elif msc.os.path.isdir(os.path.splitext(uninject_model_weights)[0]):
             return None
         else:
             raise ValueError(f'Expected {model_weights} to be a file or directory.')
 
+    def _download_nemo_file(self, restore_path: str, tmpdir: str) -> str:
+        # .nemo filename
+        fname = os.path.basename(restore_path)
+
+        # check if msc path exists
+        if not msc.os.path.exists(restore_path):
+            raise FileNotFoundError(f".nemo file doesn't exist at {restore_path}")
+
+        # download .nemo file to tempdir
+        os.makedirs(tmpdir, exist_ok=True)
+        logging.warning(f"Starting .nemo download {restore_path}")
+        msc.download_file(restore_path, f"{tmpdir}/{fname}")
+
+        # update restore_path to point to downloaded .nemo
+        updated_restore_path = os.path.join(tmpdir, fname)
+        logging.warning(f".nemo download complete; updated_restore_path to {updated_restore_path}")
+        return updated_restore_path
+
     def restore_from(
         self,
         calling_cls,
@@ -1343,36 +1426,41 @@
         Returns:
             An instance of type cls or its underlying config (if return_config is set).
""" + # tempdir creation is moved here so that updated restore_path can be passed to super().load_config_and_state_dict + # since .nemo file is in the object store, the .nemo file first needs to be downloaded + with tempfile.TemporaryDirectory() as tmpdir: + if MSC_PROTOCOL in restore_path: + restore_path = self._download_nemo_file(restore_path=restore_path, tmpdir=tmpdir) + + # Get path where the command is executed - the artifacts will be "retrieved" there + # (original .nemo behavior) + loaded_params = super().load_config_and_state_dict( + calling_cls, + restore_path, + override_config_path, + map_location, + strict, + return_config, + trainer, + validate_access_integrity, + ) + if not isinstance(loaded_params, tuple) or return_config is True: + return loaded_params + conf, instance, state_dict = loaded_params - # Get path where the command is executed - the artifacts will be "retrieved" there - # (original .nemo behavior) - loaded_params = super().load_config_and_state_dict( - calling_cls, - restore_path, - override_config_path, - map_location, - strict, - return_config, - trainer, - validate_access_integrity, - ) - if not isinstance(loaded_params, tuple) or return_config is True: - return loaded_params - conf, instance, state_dict = loaded_params - - # if we're using dist checkpointing then state_dict will be None - if state_dict is None: - # dist checkpointing needs torch.distributed to load the checkpoint - if not parallel_state.is_initialized(): + # if we're using dist checkpointing then state_dict will be None + if state_dict is None: + # dist checkpointing needs torch.distributed to load the checkpoint + if not parallel_state.is_initialized(): - def dummy(): - return + def dummy(): + return - if trainer.strategy.launcher is not None: - trainer.strategy.launcher.launch(dummy, trainer=trainer) - trainer.strategy.setup_environment() + if trainer.strategy.launcher is not None: + trainer.strategy.launcher.launch(dummy, trainer=trainer) + trainer.strategy.setup_environment() - with tempfile.TemporaryDirectory() as tmpdir: + # with tempfile.TemporaryDirectory() as tmpdir: # Check if self.model_extracted_dir is set, and is a valid path if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir): # Log that NeMo will use the provided `model_extracted_dir` @@ -1423,11 +1511,12 @@ def dummy(): if hasattr(instance, 'setup_transformer_engine_tp_groups'): instance.setup_transformer_engine_tp_groups() - else: - state_dict = self.modify_state_dict(conf, state_dict) - super().load_instance_with_state_dict(instance, state_dict, strict) - logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.') - return instance + else: + state_dict = self.modify_state_dict(conf, state_dict) + super().load_instance_with_state_dict(instance, state_dict, strict) + + logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.') + return instance class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin): diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index 46710480a093..9edbb9b1b12b 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -22,6 +22,7 @@ import uuid from contextlib import contextmanager from typing import Callable, Generator, Optional, Set, Union + import torch from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, OmegaConf @@ -743,11 +744,11 @@ def 
 
     @staticmethod
     def _save_state_dict_to_disk(state_dict, filepath):
-        torch.save(state_dict, filepath)
+        multistorageclient.torch.save(state_dict, filepath)
 
     @staticmethod
     def _load_state_dict_from_disk(model_weights, map_location=None):
-        return torch.load(model_weights, map_location='cpu', weights_only=False)
+        return multistorageclient.torch.load(model_weights, map_location='cpu', weights_only=False)
 
     @property
     def model_config_yaml(self) -> str:
diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py
index 48f0db017a3c..b031c53097ea 100644
--- a/nemo/utils/callbacks/dist_ckpt_io.py
+++ b/nemo/utils/callbacks/dist_ckpt_io.py
@@ -25,6 +25,7 @@
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch import Callback
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
+from multistorageclient.types import MSC_PROTOCOL
 
 from nemo.utils import logging
 
@@ -276,8 +277,10 @@ def save_checkpoint(
             path (_PATH): checkpoint directory
             storage_options (Any, optional): Optional parameters when saving the checkpoint
         """
-        fs = get_filesystem(path)
-        fs.makedirs(path, exist_ok=True)
+
+        if MSC_PROTOCOL not in path:
+            fs = get_filesystem(path)
+            fs.makedirs(path, exist_ok=True)
 
         validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
         self.validated_consistency = True
diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py
index b028c02e8714..8114766e24cd 100644
--- a/nemo/utils/callbacks/nemo_model_checkpoint.py
+++ b/nemo/utils/callbacks/nemo_model_checkpoint.py
@@ -228,6 +228,11 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         maybe_injected_best_model_path = self.best_model_path
 
         if self.save_best_model:
+
+            if self.multistorageclient_enabled:
+                if not multistorageclient.os.path.exists(maybe_injected_best_model_path):
+                    return
+
             if not os.path.exists(maybe_injected_best_model_path):
                 return
 
@@ -237,7 +242,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
             self.previous_best_path = self.best_model_path
             old_state_dict = deepcopy(pl_module.state_dict())
-            checkpoint = torch.load(maybe_injected_best_model_path, map_location='cpu', weights_only=False)
+            checkpoint = multistorageclient.torch.load(
+                maybe_injected_best_model_path, map_location='cpu', weights_only=False
+            )
             if 'state_dict' in checkpoint:
                 checkpoint = checkpoint['state_dict']
             # get a new instanace of the model
@@ -289,8 +296,16 @@ def on_train_end(self, trainer, pl_module):
                     "were found. Saving latest model instead."
                 )
             else:
-                if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
-                    self.best_model_path = self.best_model_path.split('.ckpt')[0]
+                if self.multistorageclient_enabled:
+                    if multistorageclient.os.path.exists(self.best_model_path) and multistorageclient.os.path.isdir(
+                        self.best_model_path
+                    ):
+                        self.best_model_path = self.best_model_path.split('.ckpt')[0]
+
+                else:
+                    if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
+                        self.best_model_path = self.best_model_path.split('.ckpt')[0]
+
                 self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
                 trainer._checkpoint_connector.restore(self.best_model_path)
@@ -529,7 +544,7 @@ def file_exists(
     ) -> bool:
         """Checks if a file or a file without a suffix (distributed checkpoint) exists."""
         if self.multistorageclient_enabled:
-            exists = self._fs.exists(filepath)
+            exists = self._fs.exists(filepath)  # todo(avm): unsure if we need this check
         else:
             exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath)))
 
@@ -651,7 +666,6 @@ def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
     def _saved_checkpoint_paths(self) -> Iterable[Path]:
         # distributed checkpoints are directories so we check for them here
         # we filter out unfinished checkpoints, these should be deleted during next cleanup
-
         if self.multistorageclient_enabled:
             # TODO: support multistorageclient distributed checkpointing
             return NeMoModelCheckpoint._derive_saved_checkpoint_paths_with_multistorageclient(self.dirpath)
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index fc082ecf5831..4da0202eb12a 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -57,6 +57,14 @@
     "megatron.core.num_microbatches_calculator", "get_current_global_batch_size"
 )
 
+try:
+    import multistorageclient
+    from multistorageclient.types import MSC_PROTOCOL as MUTLISTORAGECLIENT_PROTOCOL
+
+    MUTLISTORAGECLIENT_AVAILABLE = True
+except (ImportError, ModuleNotFoundError):
+    MUTLISTORAGECLIENT_AVAILABLE = False
+
 try:
     # `ptl_resiliency` is included in `gwe_resiliency_pkg` package
 
@@ -918,7 +926,6 @@ def check_resume(
     # If we are using S3 checkpointing, we want check_resume to only execute on a single rank
     # to avoid throttling S3.
-
     if is_global_rank_zero() or not (is_s3_url(dirpath) and is_multistorageclient_url(dirpath)):
         checkpoint_dir_exists = False
        if is_s3_url(dirpath):
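
Reviewer note: the sketch below is a minimal, hypothetical usage example of the download path this change introduces; it is not part of the patch. It mirrors what NLPSaveRestoreConnector._download_nemo_file and msc_download_dir do above, and it only uses multistorageclient calls that already appear in this diff (msc.os.path.exists, msc.download_file). The msc:// profile name and checkpoint path are placeholders.

# Hypothetical example (placeholder profile/path): stage an object-store .nemo locally before restoring.
import os
import tempfile

import multistorageclient as msc

# MSC_CONFIG and MSC_PROFILE_NAME are exported in Dockerfile.sft above;
# "curr" and the checkpoint path below are illustrative only.
restore_path = "msc://curr/experiments/megatron_gpt/megatron_gpt.nemo"

with tempfile.TemporaryDirectory() as tmpdir:
    # fail early if the object does not exist, as _download_nemo_file does
    if not msc.os.path.exists(restore_path):
        raise FileNotFoundError(f".nemo file doesn't exist at {restore_path}")

    # download the .nemo archive into the temporary directory
    local_nemo = os.path.join(tmpdir, os.path.basename(restore_path))
    msc.download_file(restore_path, local_nemo)

    # local_nemo can now be handed to the unchanged local .nemo restore logic,
    # just as restore_from() does after calling _download_nemo_file().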