From d08cbbd8c45174e307c27611799589e20999c71d Mon Sep 17 00:00:00 2001 From: Ankit Master Date: Thu, 8 May 2025 09:47:51 -0700 Subject: [PATCH 1/3] feat: ported changes, unsure if they work --- nemo/collections/nlp/parts/nlp_overrides.py | 181 +++++++++++++----- .../core/connectors/save_restore_connector.py | 77 +++++++- nemo/utils/callbacks/dist_ckpt_io.py | 7 +- nemo/utils/callbacks/nemo_model_checkpoint.py | 144 ++++++++++---- nemo/utils/data_utils.py | 5 +- nemo/utils/exp_manager.py | 27 ++- 6 files changed, 350 insertions(+), 91 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 4d265b750ff6..9a72507d7128 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -73,6 +73,10 @@ from nemo.utils import AppState, logging from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank +from concurrent.futures import ThreadPoolExecutor +from multistorageclient.types import MSC_PROTOCOL +import multistorageclient as msc + try: from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam @@ -203,6 +207,7 @@ class NLPDDPStrategy(DDPStrategy): with FP32 gradient accumulation. nccl_communicator_config_path: Path to the yaml file with NCCL communicator options sharp: Apply SHARP to NCCL data-parallel communication. + multistorageclient_enabled: Whether to use multistorageclient for checkpointing """ def __init__( @@ -1031,6 +1036,32 @@ def restore_checkpoint_after_setup(self) -> bool: return True +def msc_download_dir(url: str, local_path: str): + logging.warning(f"Running msc_download_dir url {url} rank {torch.distributed.get_rank()}") + + if not msc.os.path.exists(url): + raise Exception(f"Download Path doesn't exist: {url}") + + base_name = os.path.basename(url) #url = "msc://my-profile/path/to/data", base_name = "data" + files = msc.list(url) + + def download_file(item): + """Helper function to download a single file.""" + file_name = item.key #item.key = "msc://profile/path/to/data/file1.txt" + base_name_idx = file_name.find(base_name) # base_name_idx = 23 + local_file_path = f"{local_path}/{file_name[base_name_idx:]}" #local_file_path = f"{local_path}/data/file1.txt" + os.makedirs(os.path.dirname(local_file_path), exist_ok=True) + msc.download_file(item, local_file_path) + #msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path) + + # Use ThreadPoolExecutor for par allel downloads + with ThreadPoolExecutor(max_workers=32) as executor: # Adjust max_workers as needed + executor.map(download_file, files) + + logging.warning(f"msc_download_dir completed rank {torch.distributed.get_rank()}") + + + class NLPSaveRestoreConnector(SaveRestoreConnector): """Custom connector to support saving and restoring states.""" @@ -1052,6 +1083,7 @@ def __init__(self) -> None: ) super().__init__() + def save_to(self, model, save_path: str): """Save model to save path.""" app_state = AppState() @@ -1067,11 +1099,20 @@ def save_to(self, model, save_path: str): if (app_state.model_parallel_size is not None and app_state.model_parallel_size > 1) or dist_ckpt: dir_name = os.path.dirname(save_path) - + is_msc_enabled = False + if MSC_PROTOCOL in dir_name: + is_msc_enabled = True + # dist ckpt calls save on every rank if dist_ckpt: # model weights is a directory dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt)) + + if is_msc_enabled: + filename = os.path.join(dir_name, self.model_weights_ckpt) 
+ dist_ckpt_dir = os.path.splitext(filename)[0] + + # dist checkpoint needs torch.distributed to save the checkpoint if not parallel_state.is_initialized(): @@ -1127,8 +1168,13 @@ def dummy(): if should_move_data: with tempfile.TemporaryDirectory() as tmpdir: + if dist_ckpt: - shutil.move(str(dist_ckpt_dir), tmpdir) + if is_msc_enabled: + msc_download_dir(dist_ckpt_dir, tmpdir) + else: + shutil.move(str(dist_ckpt_dir), tmpdir) + elif app_state.pipeline_model_parallel_size == 1: # move weights to the tmpdir for tp_rank in range(app_state.tensor_model_parallel_size): @@ -1136,10 +1182,17 @@ def dummy(): mp_model_weights = os.path.join( dir_name, f'mp_rank_{tp_rank:02d}_' + self.model_weights_ckpt ) - shutil.move( - mp_model_weights, - os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt), - ) + + if is_msc_enabled: + print(f"Downloading {mp_model_weights} to {tmpdir}") + msc_dest=os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt) + logging.warning(f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}") + msc_download_dir(mp_model_weights, msc_dest) + else: + shutil.move( + mp_model_weights, + os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt), + ) else: # move weights to the tmpdir for tp_rank, pp_rank in itertools.product( @@ -1150,12 +1203,19 @@ def dummy(): mp_model_weights = os.path.join( dir_name, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}_' + self.model_weights_ckpt ) - shutil.move( - mp_model_weights, - os.path.join( - tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt - ), - ) + + if is_msc_enabled: + print(f"Downloading {mp_model_weights} to {tmpdir}") + msc_dest = os.path.join(tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt) + logging.warning(f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}") + msc_download_dir(mp_model_weights, msc_dest) + else: + shutil.move( + mp_model_weights, + os.path.join( + tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt + ), + ) # create config and artifacts in tmpdir config_yaml = os.path.join(tmpdir, self.model_config_yaml) @@ -1300,15 +1360,36 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None): uninject_model_weights = uninject_model_parallel_rank(model_weights) # legacy model_weights will have mp rank injected - if os.path.isfile(model_weights): + if msc.os.path.isfile(model_weights): return super()._load_state_dict_from_disk(model_weights, map_location) - # dist checkpoint will be a dir - elif os.path.isdir(os.path.splitext(uninject_model_weights)[0]): + elif msc.os.path.isdir(os.path.splitext(uninject_model_weights)[0]): return None else: raise ValueError(f'Expected {model_weights} to be a file or directory.') + + def _download_nemo_file(self, + restore_path: str, + tmpdir: str) -> str: + # .nemo filename + fname = os.path.basename(restore_path) + + #check if msc path exists + if not msc.os.path.exists(restore_path): + raise FileNotFoundError(f".nemo file doesn't exist at {restore_path}") + + #download .nemo file to tempdir + os.makedirs(tmpdir, exist_ok=True) + logging.warning(f"Starting .nemo download {restore_path}") + msc.download_file(restore_path, f"{tmpdir}/{fname}") + + #update restore_path to point to downloaded .nemo + updated_restore_path = os.path.join(tmpdir, fname) + logging.warning(f".nemo download complete; updated_restore_path to 
{updated_restore_path}") + return updated_restore_path + + def restore_from( self, calling_cls, @@ -1343,37 +1424,42 @@ def restore_from( Returns: An instance of type cls or its underlying config (if return_config is set). """ + # tempdir creation is moved here so that updated restore_path can be passed to super().load_config_and_state_dict + # since .nemo file is in the object store, the .nemo file first needs to be downloaded + with tempfile.TemporaryDirectory() as tmpdir: + if MSC_PROTOCOL in restore_path: + restore_path = self._download_nemo_file(restore_path=restore_path, tmpdir=tmpdir) + + # Get path where the command is executed - the artifacts will be "retrieved" there + # (original .nemo behavior) + loaded_params = super().load_config_and_state_dict( + calling_cls, + restore_path, + override_config_path, + map_location, + strict, + return_config, + trainer, + validate_access_integrity, + ) + if not isinstance(loaded_params, tuple) or return_config is True: + return loaded_params + conf, instance, state_dict = loaded_params - # Get path where the command is executed - the artifacts will be "retrieved" there - # (original .nemo behavior) - loaded_params = super().load_config_and_state_dict( - calling_cls, - restore_path, - override_config_path, - map_location, - strict, - return_config, - trainer, - validate_access_integrity, - ) - if not isinstance(loaded_params, tuple) or return_config is True: - return loaded_params - conf, instance, state_dict = loaded_params - - # if we're using dist checkpointing then state_dict will be None - if state_dict is None: - # dist checkpointing needs torch.distributed to load the checkpoint - if not parallel_state.is_initialized(): + # if we're using dist checkpointing then state_dict will be None + if state_dict is None: + # dist checkpointing needs torch.distributed to load the checkpoint + if not parallel_state.is_initialized(): - def dummy(): - return + def dummy(): + return - if trainer.strategy.launcher is not None: - trainer.strategy.launcher.launch(dummy, trainer=trainer) - trainer.strategy.setup_environment() + if trainer.strategy.launcher is not None: + trainer.strategy.launcher.launch(dummy, trainer=trainer) + trainer.strategy.setup_environment() - with tempfile.TemporaryDirectory() as tmpdir: - # Check if self.model_extracted_dir is set, and is a valid path + # with tempfile.TemporaryDirectory() as tmpdir: + # Check if self.model_extracted_dir is set, and is a valid path if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir): # Log that NeMo will use the provided `model_extracted_dir` logging.info( @@ -1423,11 +1509,12 @@ def dummy(): if hasattr(instance, 'setup_transformer_engine_tp_groups'): instance.setup_transformer_engine_tp_groups() - else: - state_dict = self.modify_state_dict(conf, state_dict) - super().load_instance_with_state_dict(instance, state_dict, strict) - logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.') - return instance + else: + state_dict = self.modify_state_dict(conf, state_dict) + super().load_instance_with_state_dict(instance, state_dict, strict) + + logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.') + return instance class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin): diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index 9ae6cae835a0..1602dffaa0fc 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ 
b/nemo/core/connectors/save_restore_connector.py @@ -19,6 +19,7 @@ import tarfile import tempfile import uuid +import time from contextlib import contextmanager from typing import Callable, Generator, Optional, Set, Union @@ -33,6 +34,13 @@ from nemo.utils.get_rank import is_global_rank_zero from nemo.utils.model_utils import inject_model_parallel_rank +try: + import multistorageclient + from multistorageclient.types import MSC_PROTOCOL as MULTISTORAGECLIENT_PROTOCOL + + MULTISTORAGECLIENT_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + MULTISTORAGECLIENT_AVAILABLE = False class SaveRestoreConnector: def __init__(self) -> None: @@ -586,10 +594,28 @@ def _inject_model_parallel_rank_for_ckpt(self, dirname, basename): @staticmethod def _make_nemo_file_from_folder(filename, source_dir): - dirname = os.path.dirname(filename) - os.makedirs(dirname, exist_ok=True) - with tarfile.open(filename, "w:") as tar: - tar.add(source_dir, arcname=".") + is_multistorageclient_url = MULTISTORAGECLIENT_AVAILABLE and filename.startswith(MULTISTORAGECLIENT_PROTOCOL) + + if is_multistorageclient_url: + SaveRestoreConnector._make_nemo_file_from_folder_with_multistorageclient(filename, source_dir) + else: + dirname = os.path.dirname(filename) + os.makedirs(dirname, exist_ok=True) + with tarfile.open(filename, "w:") as tar: + tar.add(source_dir, arcname=".") + + @staticmethod + def _make_nemo_file_from_folder_with_multistorageclient(filename, source_dir): + filename_with_extension = filename.split("/")[-1] # get the filename and extension + with tempfile.TemporaryDirectory() as tmpdir: + tar_file = os.path.join(tmpdir, filename_with_extension) + with tarfile.open(tar_file, "w:") as tar: + tar.add(source_dir, arcname=".") + start_time = time.time() + multistorageclient.upload_file(filename, tar_file) + logging.debug( + f"time spent for multistorageclient.upload from {tar_file} to {filename}: {time.time() - start_time:.4f}" + ) @staticmethod def _is_safe_path(member, extract_to): @@ -671,20 +697,57 @@ def _tar_open(path2file: str) -> Generator[tarfile.TarFile, None, None]: @staticmethod def _unpack_nemo_file(path2file: str, out_folder: str, members: Optional[list[str]] = None) -> str: - with SaveRestoreConnector._tar_open(path2file) as tar: + is_multistorageclient_url = MULTISTORAGECLIENT_AVAILABLE and path2file.startswith(MULTISTORAGECLIENT_PROTOCOL) + if is_multistorageclient_url: + out_folder = SaveRestoreConnector._unpack_nemo_file_with_multistorageclient(path2file, out_folder, members) + else: + with SaveRestoreConnector._tar_open(path2file) as tar: + if members is None: + SaveRestoreConnector._safe_extract(tar, out_folder) + else: + SaveRestoreConnector._safe_extract(tar, out_folder, members) + return out_folder + + @staticmethod + def _unpack_nemo_file_with_multistorageclient( + path2file: str, out_folder: str, members: Optional[list[str]] = None + ) -> str: + if not multistorageclient.os.path.exists(path2file): + raise FileNotFoundError(f"{path2file} does not exist") + + with tempfile.TemporaryDirectory() as tmpdir: + filename_with_extension = path2file.split("/")[-1] # get the filename with extension + downloaded_file_path = os.path.join(tmpdir, filename_with_extension) + start_time = time.time() + multistorageclient.download_file(path2file, downloaded_file_path) + logging.info( + f"time spent for multistorageclient.download_file from {downloaded_file_path}: {time.time() - start_time:.4f}" + ) + + # we start with an assumption of uncompressed tar, + # which should be true for versions 1.7.0 
and above + tar_header = "r:" + try: + tar_test = tarfile.open(downloaded_file_path, tar_header) + tar_test.close() + except tarfile.ReadError: + # can be older checkpoint => try compressed tar + tar_header = "r:gz" + tar = tarfile.open(downloaded_file_path, tar_header) if members is None: SaveRestoreConnector._safe_extract(tar, out_folder) else: SaveRestoreConnector._safe_extract(tar, out_folder, members) + tar.close() return out_folder @staticmethod def _save_state_dict_to_disk(state_dict, filepath): - torch.save(state_dict, filepath) + multistorageclient.torch.save(state_dict, filepath) @staticmethod def _load_state_dict_from_disk(model_weights, map_location=None): - return torch.load(model_weights, map_location='cpu', weights_only=False) + return multistorageclient.torch.load(model_weights, map_location='cpu', weights_only=False) @property def model_config_yaml(self) -> str: diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index a51a353d46de..3802433ee13b 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -27,6 +27,7 @@ from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO from nemo.utils import logging +from multistorageclient.types import MSC_PROTOCOL try: from megatron.core import dist_checkpointing @@ -276,8 +277,10 @@ def save_checkpoint( path (_PATH): checkpoint directory storage_options (Any, optional): Optional parameters when saving the checkpoint """ - fs = get_filesystem(path) - fs.makedirs(path, exist_ok=True) + + if MSC_PROTOCOL not in path: + fs = get_filesystem(path) + fs.makedirs(path, exist_ok=True) validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure) self.validated_consistency = True diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index 0c19f69a9aae..f06b1aaed799 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -26,6 +26,7 @@ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol from lightning.pytorch.trainer import call from lightning.pytorch.utilities import rank_zero_info +from torch import Tensor from nemo.collections.common.callbacks import EMA from nemo.utils import logging @@ -34,6 +35,14 @@ from nemo.utils.get_rank import is_global_rank_zero from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank +try: + import multistorageclient + from multistorageclient.types import MSC_PROTOCOL as MULTISTORAGECLIENT_PROTOCOL + + MULTISTORAGECLIENT_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + MULTISTORAGECLIENT_AVAILABLE = False + class NeMoModelCheckpoint(ModelCheckpoint): """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end. @@ -85,6 +94,10 @@ def __init__( else: self.prefix = "" + # flag for enabling multistorageclient checkpointing + if 'multistorageclient_enabled' in kwargs: + self.multistorageclient_enabled = MULTISTORAGECLIENT_AVAILABLE and kwargs.pop('multistorageclient_enabled') + # Call the parent class constructor with the remaining kwargs. 
super().__init__(**kwargs) @@ -215,6 +228,11 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint): maybe_injected_best_model_path = self.best_model_path if self.save_best_model: + + if self.multistorageclient_enabled: + if not multistorageclient.os.path.exists(maybe_injected_best_model_path): + return + if not os.path.exists(maybe_injected_best_model_path): return @@ -224,7 +242,7 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint): self.previous_best_path = self.best_model_path old_state_dict = deepcopy(pl_module.state_dict()) - checkpoint = torch.load(maybe_injected_best_model_path, map_location='cpu', weights_only=False) + checkpoint = multistorageclient.torch.load(maybe_injected_best_model_path, map_location='cpu', weights_only=False) if 'state_dict' in checkpoint: checkpoint = checkpoint['state_dict'] # get a new instanace of the model @@ -276,8 +294,14 @@ def on_train_end(self, trainer, pl_module): "were found. Saving latest model instead." ) else: - if os.path.isdir(self.best_model_path.split('.ckpt')[0]): - self.best_model_path = self.best_model_path.split('.ckpt')[0] + if self.multistorageclient_enabled: + if multistorageclient.os.path.exists(self.best_model_path) and multistorageclient.os.path.isdir(self.best_model_path): + self.best_model_path = self.best_model_path.split('.ckpt')[0] + + else: + if os.path.isdir(self.best_model_path.split('.ckpt')[0]): + self.best_model_path = self.best_model_path.split('.ckpt')[0] + self.best_model_path = trainer.strategy.broadcast(self.best_model_path) trainer._checkpoint_connector.restore(self.best_model_path) @@ -314,12 +338,18 @@ def _backup_existing_nemo_ckpt(self, trainer) -> Optional[str]: return None if trainer.is_global_zero: logging.info(f'{base_path} already exists, moving existing checkpoint to {available_path}') - shutil.move(base_path, available_path) + if self.multistorageclient_enabled: + # TODO: multistorageclient doesn't have "rename" function, therefore no-op but we should refactor this once multistorageclient have rename function supported. 
+ pass + else: + shutil.move(base_path, available_path) trainer.strategy.barrier() return available_path def _format_nemo_checkpoint_name(self, ver: Optional[int] = None) -> str: version_infix = '' if ver is None else f'{self.CHECKPOINT_JOIN_CHAR}v{ver}' + if self.multistorageclient_enabled: + return f"{self.dirpath}/{self.prefix + version_infix + self.postfix}" return os.path.abspath( os.path.expanduser(os.path.join(self.dirpath, self.prefix + version_infix + self.postfix)) ) @@ -509,7 +539,11 @@ def file_exists( self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True # noqa: F821 ) -> bool: """Checks if a file or a file without a suffix (distributed checkpoint) exists.""" - exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath))) + if self.multistorageclient_enabled: + exists = self._fs.exists(filepath) # todo(avm): unsure if we need this check + else: + exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath))) + return trainer.strategy.broadcast(exists) def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None: # noqa: F821 @@ -628,13 +662,23 @@ def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool: def _saved_checkpoint_paths(self) -> Iterable[Path]: # distributed checkpoints are directories so we check for them here # we filter out unfinished checkpoints, these should be deleted during next cleanup - dist_checkpoints = [d for d in Path(self.dirpath).glob("*") if d.is_dir()] + if self.multistorageclient_enabled: + # TODO: support multistorageclient distributed checkpointing + return NeMoModelCheckpoint._derive_saved_checkpoint_paths_with_multistorageclient(self.dirpath) + else: + dist_checkpoints = [d for d in Path(self.dirpath).glob("*") if d.is_dir()] + if dist_checkpoints: return filter(lambda p: not self.is_checkpoint_unfinished(p), dist_checkpoints) else: checkpoint_files = [f for f in Path(self.dirpath).rglob("*.ckpt")] return filter(lambda p: not self.is_checkpoint_unfinished(p), checkpoint_files) + @staticmethod + def _derive_saved_checkpoint_paths_with_multistorageclient(dirpath: str) -> Iterable[Path]: + return multistorageclient.glob(f"{dirpath}/*.ckpt") + + @staticmethod def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None: @@ -644,32 +688,45 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None: if not is_global_rank_zero(): raise AssertionError("_remove_unfinished_checkpoints should run only on rank 0") - checkpoint_dir = Path(checkpoint_dir) - - existing_marker_filepaths = { - f.resolve() - for f in checkpoint_dir.glob(f"*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}") - if f.is_file() - } - - checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")} - for ckpt_filepath in checkpoint_filepaths: - possible_marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_filepath) - if possible_marker_path in existing_marker_filepaths: - logging.warning(f'Removing unfinished checkpoint: {ckpt_filepath}') - os.remove(ckpt_filepath) - - # some directories might be distributed checkpoints, we remove these if they have a unfinished marker - all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()} - for ckpt_dirpath in all_dirpaths: - possible_marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_dirpath) - if possible_marker_path in existing_marker_filepaths: - logging.warning(f'Removing 
unfinished dist checkpoint: {ckpt_dirpath}') - shutil.rmtree(ckpt_dirpath) - - # delete markers - for marker_path in existing_marker_filepaths: - os.remove(marker_path) + multistorageclient_enabled = MULTISTORAGECLIENT_AVAILABLE and str(checkpoint_dir).startswith( + MULTISTORAGECLIENT_PROTOCOL + ) + + # TODO: add multistorageclient support for distributed checkpointing + if multistorageclient_enabled: + existing_marker_filepaths = multistorageclient.glob( + f"{checkpoint_dir}*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}" + ) + fs = get_filesystem(checkpoint_dir) + for ckpt_filepath in existing_marker_filepaths: + fs.rm(ckpt_filepath) + else: + checkpoint_dir = Path(checkpoint_dir) + + existing_marker_filepaths = { + f.resolve() + for f in checkpoint_dir.glob(f"*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}") + if f.is_file() + } + + checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")} + for ckpt_filepath in checkpoint_filepaths: + possible_marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_filepath) + if possible_marker_path in existing_marker_filepaths: + logging.warning(f'Removing unfinished checkpoint: {ckpt_filepath}') + os.remove(ckpt_filepath) + + # some directories might be distributed checkpoints, we remove these if they have a unfinished marker + all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()} + for ckpt_dirpath in all_dirpaths: + possible_marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_dirpath) + if possible_marker_path in existing_marker_filepaths: + logging.warning(f'Removing unfinished dist checkpoint: {ckpt_dirpath}') + shutil.rmtree(ckpt_dirpath) + + # delete markers + for marker_path in existing_marker_filepaths: + os.remove(marker_path) def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool: # noqa: F821 """Checks if the previous checkpoint should be deleted. @@ -697,3 +754,26 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren raise ValueError(f"{self.__class__}.dirpath is None.") dirpath = Path(self.dirpath).absolute() return dirpath in previous.parents + + def format_checkpoint_name( + self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None + ) -> str: + """ + Override the format_checkpoint_name behavior from lightning's ModelCheckpoint to support multistorageclient. + Specifically, if multistorageclient_enabled = true, use string formatting to construct the full path; + Otherwise, reuse the original logic of os.path.join to construct the full path + """ + + filename = filename or self.filename + filename = self._format_checkpoint_name( + filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name + ) + + if ver is not None: + filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) + + ckpt_name = f"{filename}{self.FILE_EXTENSION}" + if self.multistorageclient_enabled: + return f"{self.dirpath}/{ckpt_name}" if self.dirpath else ckpt_name + else: + return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name \ No newline at end of file diff --git a/nemo/utils/data_utils.py b/nemo/utils/data_utils.py index 6ef603df85e5..7823dbbbd453 100644 --- a/nemo/utils/data_utils.py +++ b/nemo/utils/data_utils.py @@ -295,7 +295,10 @@ def datastore_path_to_webdataset_url(store_path: str): URL which can be directly used with WebDataset. 
""" if is_datastore_path(store_path): - url = f'pipe:ais get {store_path} - || true' + if not store_path.startswith("msc://"): + url = f'pipe:ais get {store_path} - || true' + else: + return store_path else: raise ValueError(f'Unknown store path format: {store_path}') diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 9d098b9b5bb3..73e9c45c7bc4 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -57,6 +57,14 @@ "megatron.core.num_microbatches_calculator", "get_current_global_batch_size" ) +try: + import multistorageclient + from multistorageclient.types import MSC_PROTOCOL as MUTLISTORAGECLIENT_PROTOCOL + + MUTLISTORAGECLIENT_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + MUTLISTORAGECLIENT_AVAILABLE = False + try: # `ptl_resiliency` is included in `gwe_resiliency_pkg` package @@ -146,7 +154,7 @@ class CallbackParams: async_save: Optional[bool] = False # save the checkpoint asynchronously # a number of last checkpoints to be saved with optimizer states save_last_n_optim_states: Optional[int] = -1 - + multistorageclient_enabled: Optional[bool] = False @dataclass class StepTimingParams: @@ -909,7 +917,7 @@ def check_resume( # If we are using S3 checkpointing, we want check_resume to only execute on a single rank # to avoid throttling S3. - if is_global_rank_zero() or not is_s3_url(dirpath): + if is_global_rank_zero() or not (is_s3_url(dirpath) and is_multistorageclient_url(dirpath)): checkpoint_dir_exists = False if is_s3_url(dirpath): checkpoint_dir = dirpath @@ -924,6 +932,16 @@ def check_resume( else: end_checkpoints = [] last_checkpoints = [] + elif is_multistorageclient_url(dirpath): + checkpoint_dir = dirpath + all_keys = multistorageclient.glob(f"{dirpath}**/*.ckpt") + checkpoint_dir_exists = True if all_keys else False + if all_keys: + end_checkpoints = sorted([k for k in all_keys if k.endswith('end.ckpt')], reverse=True) + last_checkpoints = sorted([k for k in all_keys if k.endswith('last.ckpt')], reverse=True) + else: + end_checkpoints = [] + last_checkpoints = [] else: # default non-s3 implementation # Use /checkpoints/ unless `dirpath` is set checkpoint_dir = Path(dirpath) if dirpath else Path(Path(log_dir) / "checkpoints") @@ -1489,3 +1507,8 @@ def clean_exp_ckpt(exp_log_dir: Union[str, Path], remove_ckpt: bool = True, remo for filepath in nemo_files: os.remove(filepath) logging.info(f"Deleted file : {filepath}") + + + +def is_multistorageclient_url(dirpath): + return MUTLISTORAGECLIENT_AVAILABLE and dirpath and dirpath.startswith(MUTLISTORAGECLIENT_PROTOCOL) \ No newline at end of file From f36a1b3a03e94b299e9f07e2f47e225b12c287c8 Mon Sep 17 00:00:00 2001 From: Ankit Master Date: Wed, 14 May 2025 17:19:54 -0700 Subject: [PATCH 2/3] not working --- Dockerfile.sft | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 Dockerfile.sft diff --git a/Dockerfile.sft b/Dockerfile.sft new file mode 100644 index 000000000000..b148f06c9ad5 --- /dev/null +++ b/Dockerfile.sft @@ -0,0 +1,29 @@ +FROM nvcr.io/nvidian/tegra-audio/speech-mlops-llm:sft_train + +WORKDIR /workspace + +COPY . /workspace/NeMo + +RUN pip install retrying + +RUN pip install multi-storage-client +#RUN pip install /workspace/NeMo/multi-storage-client/. 
+ +RUN pip install pytorch_lightning==2.4.0 +RUN pip install lightning>=2.0.0 +RUN pip install apex +RUN pip install --upgrade megatron-core[all] + +RUN pip install "boto3>=1.36" + +ENV CXXFLAGS="-std=c++17" + +# RUN pip install aistore + +#RUN pip install /workspace/nemo-with-storageclient/megatron-lm-msc/. + +RUN pip install /workspace/NeMo/. + +ENV MSC_CONFIG /lustrefs/users/anmaster/avm_msc_config.json + +ENV MSC_PROFILE_NAME curr From 62e6bae202958bb38ef1e5089e5a083c43eeb0a9 Mon Sep 17 00:00:00 2001 From: ankitmaster08 Date: Thu, 15 May 2025 14:51:07 +0000 Subject: [PATCH 3/3] Apply isort and black reformatting Signed-off-by: ankitmaster08 --- nemo/collections/nlp/parts/nlp_overrides.py | 70 ++++++++++--------- .../core/connectors/save_restore_connector.py | 3 +- nemo/utils/callbacks/dist_ckpt_io.py | 2 +- nemo/utils/callbacks/nemo_model_checkpoint.py | 14 ++-- nemo/utils/exp_manager.py | 1 + 5 files changed, 49 insertions(+), 41 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 9a72507d7128..7ec88d6b63cc 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -64,6 +64,11 @@ # since PyTorch 2.3 the path has changed from torch.amp.grad_scaler import _refresh_per_optimizer_state +from concurrent.futures import ThreadPoolExecutor + +import multistorageclient as msc +from multistorageclient.types import MSC_PROTOCOL + from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.transformer import AutocastTransformerLayer, ParallelTransformerLayer from nemo.collections.nlp.parts import utils_funcs @@ -73,10 +78,6 @@ from nemo.utils import AppState, logging from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank -from concurrent.futures import ThreadPoolExecutor -from multistorageclient.types import MSC_PROTOCOL -import multistorageclient as msc - try: from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam @@ -1042,25 +1043,26 @@ def msc_download_dir(url: str, local_path: str): if not msc.os.path.exists(url): raise Exception(f"Download Path doesn't exist: {url}") - base_name = os.path.basename(url) #url = "msc://my-profile/path/to/data", base_name = "data" + base_name = os.path.basename(url) # url = "msc://my-profile/path/to/data", base_name = "data" files = msc.list(url) def download_file(item): """Helper function to download a single file.""" - file_name = item.key #item.key = "msc://profile/path/to/data/file1.txt" - base_name_idx = file_name.find(base_name) # base_name_idx = 23 - local_file_path = f"{local_path}/{file_name[base_name_idx:]}" #local_file_path = f"{local_path}/data/file1.txt" + file_name = item.key # item.key = "msc://profile/path/to/data/file1.txt" + base_name_idx = file_name.find(base_name) # base_name_idx = 23 + local_file_path = ( + f"{local_path}/{file_name[base_name_idx:]}" # local_file_path = f"{local_path}/data/file1.txt" + ) os.makedirs(os.path.dirname(local_file_path), exist_ok=True) msc.download_file(item, local_file_path) - #msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path) + # msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path) # Use ThreadPoolExecutor for par allel downloads with ThreadPoolExecutor(max_workers=32) as executor: # Adjust max_workers as needed executor.map(download_file, files) logging.warning(f"msc_download_dir completed rank 
{torch.distributed.get_rank()}") - - + class NLPSaveRestoreConnector(SaveRestoreConnector): """Custom connector to support saving and restoring states.""" @@ -1083,7 +1085,6 @@ def __init__(self) -> None: ) super().__init__() - def save_to(self, model, save_path: str): """Save model to save path.""" app_state = AppState() @@ -1102,17 +1103,16 @@ def save_to(self, model, save_path: str): is_msc_enabled = False if MSC_PROTOCOL in dir_name: is_msc_enabled = True - + # dist ckpt calls save on every rank if dist_ckpt: # model weights is a directory dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt)) if is_msc_enabled: - filename = os.path.join(dir_name, self.model_weights_ckpt) + filename = os.path.join(dir_name, self.model_weights_ckpt) dist_ckpt_dir = os.path.splitext(filename)[0] - - + # dist checkpoint needs torch.distributed to save the checkpoint if not parallel_state.is_initialized(): @@ -1185,8 +1185,10 @@ def dummy(): if is_msc_enabled: print(f"Downloading {mp_model_weights} to {tmpdir}") - msc_dest=os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt) - logging.warning(f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}") + msc_dest = os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt) + logging.warning( + f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}" + ) msc_download_dir(mp_model_weights, msc_dest) else: shutil.move( @@ -1206,8 +1208,12 @@ def dummy(): if is_msc_enabled: print(f"Downloading {mp_model_weights} to {tmpdir}") - msc_dest = os.path.join(tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt) - logging.warning(f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}") + msc_dest = os.path.join( + tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt + ) + logging.warning( + f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}" + ) msc_download_dir(mp_model_weights, msc_dest) else: shutil.move( @@ -1368,28 +1374,24 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None): else: raise ValueError(f'Expected {model_weights} to be a file or directory.') - - def _download_nemo_file(self, - restore_path: str, - tmpdir: str) -> str: - # .nemo filename + def _download_nemo_file(self, restore_path: str, tmpdir: str) -> str: + # .nemo filename fname = os.path.basename(restore_path) - - #check if msc path exists + + # check if msc path exists if not msc.os.path.exists(restore_path): raise FileNotFoundError(f".nemo file doesn't exist at {restore_path}") - - #download .nemo file to tempdir + + # download .nemo file to tempdir os.makedirs(tmpdir, exist_ok=True) logging.warning(f"Starting .nemo download {restore_path}") msc.download_file(restore_path, f"{tmpdir}/{fname}") - - #update restore_path to point to downloaded .nemo + + # update restore_path to point to downloaded .nemo updated_restore_path = os.path.join(tmpdir, fname) logging.warning(f".nemo download complete; updated_restore_path to {updated_restore_path}") return updated_restore_path - def restore_from( self, calling_cls, @@ -1459,7 +1461,7 @@ def dummy(): trainer.strategy.setup_environment() # with tempfile.TemporaryDirectory() as tmpdir: - # Check if self.model_extracted_dir is set, and is a valid path + # Check if self.model_extracted_dir is set, and is a valid path if 
self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir): # Log that NeMo will use the provided `model_extracted_dir` logging.info( @@ -1512,7 +1514,7 @@ def dummy(): else: state_dict = self.modify_state_dict(conf, state_dict) super().load_instance_with_state_dict(instance, state_dict, strict) - + logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.') return instance diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index f0ab9d76bd26..9edbb9b1b12b 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -20,9 +20,9 @@ import tempfile import time import uuid -import time from contextlib import contextmanager from typing import Callable, Generator, Optional, Set, Union + import torch from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, OmegaConf @@ -42,6 +42,7 @@ except (ImportError, ModuleNotFoundError): MULTISTORAGECLIENT_AVAILABLE = False + class SaveRestoreConnector: def __init__(self) -> None: self._model_config_yaml = "model_config.yaml" diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index ff5ead3564b3..b031c53097ea 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -25,9 +25,9 @@ from lightning.fabric.utilities.types import _PATH from lightning.pytorch import Callback from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO +from multistorageclient.types import MSC_PROTOCOL from nemo.utils import logging -from multistorageclient.types import MSC_PROTOCOL try: from megatron.core import dist_checkpointing diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index 2d6e8dacaa48..8114766e24cd 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -26,7 +26,7 @@ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol from lightning.pytorch.trainer import call from lightning.pytorch.utilities import rank_zero_info -from torch import Tensor +from torch import Tensor from nemo.collections.common.callbacks import EMA from nemo.utils import logging @@ -232,7 +232,7 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint): if self.multistorageclient_enabled: if not multistorageclient.os.path.exists(maybe_injected_best_model_path): return - + if not os.path.exists(maybe_injected_best_model_path): return @@ -242,7 +242,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint): self.previous_best_path = self.best_model_path old_state_dict = deepcopy(pl_module.state_dict()) - checkpoint = multistorageclient.torch.load(maybe_injected_best_model_path, map_location='cpu', weights_only=False) + checkpoint = multistorageclient.torch.load( + maybe_injected_best_model_path, map_location='cpu', weights_only=False + ) if 'state_dict' in checkpoint: checkpoint = checkpoint['state_dict'] # get a new instanace of the model @@ -295,7 +297,9 @@ def on_train_end(self, trainer, pl_module): ) else: if self.multistorageclient_enabled: - if multistorageclient.os.path.exists(self.best_model_path) and multistorageclient.os.path.isdir(self.best_model_path): + if multistorageclient.os.path.exists(self.best_model_path) and multistorageclient.os.path.isdir( + self.best_model_path + ): self.best_model_path = 
self.best_model_path.split('.ckpt')[0] else: @@ -540,7 +544,7 @@ def file_exists( ) -> bool: """Checks if a file or a file without a suffix (distributed checkpoint) exists.""" if self.multistorageclient_enabled: - exists = self._fs.exists(filepath) # todo(avm): unsure if we need this check + exists = self._fs.exists(filepath) # todo(avm): unsure if we need this check else: exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath))) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index ab80fca70520..4da0202eb12a 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -164,6 +164,7 @@ class CallbackParams: save_last_n_optim_states: Optional[int] = -1 multistorageclient_enabled: Optional[bool] = False + @dataclass class StepTimingParams: """StepTimingParams POD"""
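
A minimal usage sketch (not part of the patch) for the msc-backed directory download helper added to nlp_overrides.py in this series. It assumes torch.distributed is already initialized (the helper logs the caller's rank), that MSC_CONFIG points at a valid multi-storage-client configuration (as Dockerfile.sft sets), and that the msc:// profile and checkpoint path below are placeholders, not paths from the patch:

    import tempfile

    import torch.distributed as dist

    from nemo.collections.nlp.parts.nlp_overrides import msc_download_dir

    # msc_download_dir() checks that the msc:// prefix exists, lists every object
    # under it, and downloads the files with a 32-worker thread pool while
    # recreating the layout (starting at the URL's basename) under the local
    # destination directory.
    assert dist.is_initialized(), "helper calls torch.distributed.get_rank() for logging"
    with tempfile.TemporaryDirectory() as tmpdir:
        msc_download_dir("msc://my-profile/checkpoints/model_weights", tmpdir)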