diff --git a/pyproject.toml b/pyproject.toml
index af9671b4f..f42ccd1b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
     "packaging",
     "tqdm",
     "pyyaml",
-    "transformers",
+    "transformers!=4.57.2,!=4.57.3",  # Exclude versions with known issues. Can probably be removed if a version >4.57.3 is released.
     "datasets",
     "protobuf",
     "SentencePiece",
diff --git a/src/modalities/checkpointing/convert_dcp_to_torch.py b/src/modalities/checkpointing/convert_dcp_to_torch.py
new file mode 100644
index 000000000..f765202ca
--- /dev/null
+++ b/src/modalities/checkpointing/convert_dcp_to_torch.py
@@ -0,0 +1,111 @@
+import os
+from pathlib import Path
+
+import torch
+from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
+from torch.distributed.checkpoint.filesystem import FileSystemReader
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+
+from modalities.config.config import ConfigDictType, load_app_config_dict, save_yaml_config_dict
+from modalities.utils.env import EnvOverride
+
+
+def convert_dcp_to_torch(dcp_checkpoint_dir: str, output_dir: str, model_key: str = "model_raw") -> str:
+    """Converts a DCP (Distributed Checkpoint) checkpoint—including
+    FSDP2, PP, or TP checkpoints—to a standard PyTorch checkpoint.
+
+    Args:
+        dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files (may include FSDP2, PP, or TP).
+        output_dir (str): Directory to save the converted PyTorch checkpoint.
+        model_key (str): Key of the model configuration in the modalities config.
+    Returns:
+        str: Path to the converted config file.
+    """
+    os.makedirs(output_dir, exist_ok=True)
+    torch_checkpoint_file = os.path.join(output_dir, "pytorch_model.bin")
+    torch_config_file = convert_config_file(dcp_checkpoint_dir, output_dir, model_key, torch_checkpoint_file)
+    # TODO This is the (adapted) code from torch's dcp_to_torch_save(dcp_checkpoint_dir, torch_checkpoint_file)
+    # since we only want to convert the model state dict here. In future torch versions this function might
+    # support converting only parts of the checkpoint.
+    # (from torch.distributed.checkpoint.format_utils import dcp_to_torch_save)
+    sd: STATE_DICT_TYPE = {}
+    planner = _EmptyStateDictLoadPlanner(keys=["app.model"], allow_partial_load=True)
+    _load_state_dict(sd, storage_reader=FileSystemReader(dcp_checkpoint_dir), planner=planner, no_dist=True)
+    torch.save(sd["app"]["model"], torch_checkpoint_file)
+    return torch_config_file
+
+
+def convert_config_file(dcp_checkpoint_dir: str, output_dir: str, model_key: str, torch_checkpoint_file: str) -> str:
+    """Converts the modalities config file for DCP to a config file for standard PyTorch checkpoint loading.
+    Args:
+        dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files.
+        output_dir (str): Directory to save the converted config file.
+        model_key (str): Key of the model configuration in the modalities config.
+        torch_checkpoint_file (str): Path to the converted PyTorch checkpoint file.
+    Returns:
+        str: Path to the converted config file.
+    """
+    config_src, dcp_config = load_dcp_config(dcp_checkpoint_dir)
+    config_dst: str = os.path.join(output_dir, os.path.basename(config_src))
+    if os.path.exists(config_dst):
+        raise FileExistsError(f"Config file '{config_dst}' already exists.")
+    torch_config: ConfigDictType = {
+        "checkpointed_model": {
+            "component_key": "model",
+            "variant_key": "fsdp1_checkpointed",
+            "config": {
+                "checkpoint_loading": {
+                    "component_key": "checkpoint_loading",
+                    "variant_key": "torch",
+                    "config": {
+                        "device": "cpu",
+                        "precision": "FP32",
+                    },
+                },
+                "model": {
+                    "instance_key": "model",
+                    "pass_type": "BY_REFERENCE",
+                },
+                "checkpoint_path": torch_checkpoint_file,
+            },
+        },
+    }
+    if model_key not in dcp_config:
+        raise KeyError(
+            f"Model key '{model_key}' not found in config file '{config_src}'."
+            f" Available keys: {list(dcp_config.keys())}"
+        )
+    torch_config["model"] = dcp_config[model_key]
+    torch_config["model"]["config"]["use_meta_device"] = False
+    save_yaml_config_dict(torch_config, Path(config_dst))
+    return config_dst
+
+
+def load_dcp_config(dcp_checkpoint_dir: str) -> tuple[str, ConfigDictType]:
+    with EnvOverride({"LOCAL_RANK": "0", "RANK": "0", "WORLD_SIZE": "1"}):
+        config_src: str | None = find_yaml_config_in_dir(dcp_checkpoint_dir)
+        if config_src is None:
+            config_src = find_yaml_config_in_dir(str(Path(dcp_checkpoint_dir).parent))
+        if config_src is None:
+            raise FileNotFoundError("No YAML config file found in checkpoint directory or its parent.")
+        dcp_config = load_app_config_dict(Path(config_src), experiment_id="-1")
+    return config_src, dcp_config
+
+
+def find_yaml_config_in_dir(directory: str) -> str | None:
+    """Finds the first YAML config file in the given directory.
+
+    Args:
+        directory (str): Directory to search for YAML files.
+
+    Returns:
+        str | None: Path to the found YAML file or None if not found.
+    """
+    if not os.path.isdir(directory) or not os.access(directory, os.R_OK):
+        # Directory does not exist or is not accessible
+        return None
+    for filename in os.listdir(directory):
+        if filename.endswith(".yaml") or filename.endswith(".yml"):
+            return os.path.join(directory, filename)
+    return None
diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 5ae8c0822..e65a55ef1 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from functools import partial
 from pathlib import Path
@@ -497,13 +499,14 @@ class ParallelDegreeConfig(BaseModel):
 # Recursive type representing arbitrary-depth YAML config structures.
 YAMLPrimitive = str | int | float | bool | None
 YAMLValue: TypeAlias = YAMLPrimitive | Path | list["YAMLValue"] | dict[str, "YAMLValue"]
+ConfigDictType: TypeAlias = dict[str, YAMLValue]
 
 
 def load_app_config_dict(
     config_file_path: Path,
     experiment_id: Optional[str] = None,
     additional_resolver_funs: Optional[dict[str, Resolver]] = None,
-) -> dict[str, YAMLValue]:
+) -> ConfigDictType:
     """Load the application configuration from the given YAML file.
 
     Args:
@@ -512,7 +515,7 @@ def load_app_config_dict(
         additional_resolver_funs (dict[str, Resolver], optional): Additional resolver functions.
 
     Returns:
-        dict[str, YAMLValue]: Dictionary representation of the config file with arbitrary depth.
+        ConfigDictType: Dictionary representation of the config file with arbitrary depth.
""" def cuda_env_resolver_fun(var_name: str) -> int | str | None: @@ -528,6 +531,7 @@ def modalities_env_resolver_fun(var_name: str, kwargs: dict[str, Any]) -> str | def node_env_resolver_fun(var_name: str) -> int | None: if var_name == "num_cpus": return os.cpu_count() + return None OmegaConf.register_new_resolver("cuda_env", cuda_env_resolver_fun, replace=True) modalities_env_kwargs: dict[str, Any] = { @@ -546,6 +550,17 @@ def node_env_resolver_fun(var_name: str) -> int | None: OmegaConf.register_new_resolver(resolver_name, resolver_fun, replace=True) cfg = OmegaConf.load(config_file_path) - config_dict = cast(dict[str, YAMLValue], OmegaConf.to_container(cfg, resolve=True)) + config_dict = cast(ConfigDictType, OmegaConf.to_container(cfg, resolve=True)) return config_dict + + +def save_yaml_config_dict(config_dict: ConfigDictType, output_file_path: Path) -> None: + """Saves the given config dictionary as a YAML file. + + Args: + config_dict (ConfigDictType): Configuration dictionary to save. + output_file_path (Path): Path to the output YAML file. + """ + cfg = OmegaConf.create(config_dict) + OmegaConf.save(cfg, output_file_path) diff --git a/src/modalities/conversion/gpt2/conversion_model.py b/src/modalities/conversion/gpt2/conversion_model.py index f44ff33e6..6e1b12e2a 100644 --- a/src/modalities/conversion/gpt2/conversion_model.py +++ b/src/modalities/conversion/gpt2/conversion_model.py @@ -1,40 +1,49 @@ +import warnings + import torch import torch.nn as nn from tqdm import tqdm +from modalities.checkpointing.convert_dcp_to_torch import load_dcp_config +from modalities.config.config import ConfigDictType, PrecisionEnum, ProcessGroupBackendType from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config from modalities.conversion.gpt2.modeling_gpt2 import GPT2DecoderLayer, GPT2ForCausalLM from modalities.models.components.layer_norms import LayerNormConfig from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block, PositionTypes from modalities.models.model import SwiGLU from modalities.models.utils import ModelTypeEnum, get_model_from_config +from modalities.running_env.cuda_env import MultiProcessingCudaEnv +from modalities.running_env.env_utils import PyTorchDtypes -def convert_model_checkpoint(modalities_config: dict) -> tuple[GPT2ForCausalLM, GPT2LLM]: +def convert_model_checkpoint(modalities_config: ConfigDictType) -> tuple[GPT2ForCausalLM, GPT2LLM]: """Converts the modalities model to a Huggingface transformers model. Both the loaded modalities model and the converted Huggingface model are returned so that they can be compared. Args: - modalities_config (dict): Modalities config dictionary. + modalities_config (ConfigDictType): Modalities config dictionary. Returns: tuple[GPT2ForCausalLM, GPT2LLM]: Converted Hugging Face model and the original modalities model. 
""" gpt2_config = convert_model_config(modalities_config) - hf_model = GPT2ForCausalLM(gpt2_config).to(dtype=torch.bfloat16) + dtype = PrecisionEnum( + modalities_config["checkpointed_model"]["config"]["checkpoint_loading"]["config"]["precision"] + ) + hf_model = GPT2ForCausalLM(gpt2_config).to(dtype=dtype.value) modalities_model = get_model_from_config(modalities_config, model_type=ModelTypeEnum.CHECKPOINTED_MODEL) _copy_weights_model(hf_model, modalities_model) return hf_model, modalities_model -def convert_model_config(modalities_config: dict) -> GPT2Config: +def convert_model_config(modalities_config: ConfigDictType) -> GPT2Config: """Converts the modalities model configuration to a Huggingface transformers configuration. For this the model_raw or model section of the modalities config is used. Corresponding entries are mapped to the Huggingface configuration. Args: - modalities_config (dict): Modalities config dictionary. + modalities_config (ConfigDictType): Modalities config dictionary. Returns: GPT2Config: Converted Huggingface model configuration. @@ -43,6 +52,12 @@ def convert_model_config(modalities_config: dict) -> GPT2Config: _check_conversion_criteria(config) ffn_norm_key = "ffn_norm_config" + attention_type = _map_attention_type(config) + if attention_type != "sdpa": + warnings.warn( + f"transformers checkpoint will not save the attention implementation " + f"(set to {attention_type}) and use sdpa by default." + ) return GPT2Config( vocab_size=config["vocab_size"], @@ -62,11 +77,24 @@ def convert_model_config(modalities_config: dict) -> GPT2Config: layer_norm_bias=_get_layer_norm_value(config[ffn_norm_key]["config"], "bias"), max_position_embeddings=config["sequence_length"], rope_theta=config["attention_config"]["qkv_transforms"][0]["config"]["base_freq"], - _attn_implementation=_map_attention_type(config), + attn_implementation=attention_type, output_attentions=False, ) +def check_converted_dcp_model( + hf_model_dir: str, dcp_dir: str, num_testruns: int, device_id_modalities: str | int, device_hf: str +): + new_config: ConfigDictType = _build_single_node_dcp_config(dcp_dir) + hf_model = _load_hf_model_for_dcp_comparison(hf_model_dir, new_config, device_hf) + vocab_size: int = new_config["model_raw" if "model_raw" in new_config else "model"]["config"]["vocab_size"] + if isinstance(device_id_modalities, str): + device_id_modalities = int(device_id_modalities.replace("cuda:", "")) + with MultiProcessingCudaEnv(ProcessGroupBackendType.nccl, 0, 0, 1, 24570, device_id=device_id_modalities): + modalities_model = get_model_from_config(new_config, model_type=ModelTypeEnum.DCP_CHECKPOINTED_MODEL) + check_converted_model(hf_model, modalities_model, num_testruns=num_testruns, vocab_size=vocab_size) + + def check_converted_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM, num_testruns: int, vocab_size: int): """Tests the converted model by inputting a random token sequence and comparing the output logits of both models. 
@@ -85,14 +113,85 @@ def check_converted_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM,
         modalities_logits = modalities_model(inputs)[modalities_model.prediction_key].to("cpu")
 
         assert llama_logits.shape == modalities_logits.shape
+        assert llama_logits.dtype == modalities_logits.dtype
         assert torch.equal(llama_logits, modalities_logits)
 
 
-def _check_conversion_criteria(model_config: dict) -> None:
+def _build_single_node_dcp_config(dcp_dir: str) -> ConfigDictType:
+    """Builds a modalities config dictionary for loading a DCP checkpointed model on a single node.
+
+    Args:
+        dcp_dir (str): Directory containing the DCP checkpoint.
+
+    Returns:
+        ConfigDictType: New modalities config dictionary for loading the DCP checkpointed model.
+    """
+    _, dcp_config = load_dcp_config(dcp_dir)
+    model_key = "model_raw" if "model_raw" in dcp_config else "model"
+    new_config: ConfigDictType = {
+        "fsdp_model": dcp_config["fsdp_model"],
+        "initialized_model": dcp_config["initialized_model"],
+        model_key: dcp_config[model_key],
+    }
+    if "settings" in dcp_config:
+        new_config["settings"] = dcp_config["settings"]
+        new_config["settings"]["config_file_path"] = "converted_dcp_config.yaml"
+    if "dp_degree" in dcp_config:
+        new_config["dp_degree"] = dcp_config["dp_degree"]
+    if "optimizer" in dcp_config:
+        new_config["optimizer"] = dcp_config["optimizer"]
+    if "lr_scheduler" in dcp_config:
+        new_config["lr_scheduler"] = dcp_config["lr_scheduler"]
+    new_config["app_state"] = {
+        "component_key": "app_state",
+        "variant_key": "dcp",
+        "config": {
+            "raw_app_state": dcp_config["app_state_raw" if "app_state_raw" in dcp_config else "app_state"],
+            "checkpoint_dir_path": dcp_dir,
+        },
+    }
+    new_config["device_mesh"] = {
+        "component_key": "device_mesh",
+        "variant_key": "default",
+        "config": {
+            "device_type": "cuda",
+            "data_parallel_shard_degree": 1,
+            "world_size": 1,
+        },
+    }
+    new_config["fsdp_model"]["config"]["model"]["instance_key"] = model_key
+    new_config["initialized_model"]["config"]["model"] = {"instance_key": "fsdp_model", "pass_type": "BY_REFERENCE"}
+    return new_config
+
+
+def _load_hf_model_for_dcp_comparison(
+    hf_model_dir: str, dcp_modalities_config: ConfigDictType, device_hf: str
+) -> GPT2ForCausalLM:
+    # Need execution dtype of FSDP2 to get same outputs from model.
+    dtype = dcp_modalities_config["fsdp_model"]["config"]["mixed_precision_settings"]["param_dtype"]
+    hf_model: GPT2ForCausalLM = (
+        GPT2ForCausalLM.from_pretrained(hf_model_dir, local_files_only=True, trust_remote_code=True)
+        .to(device=device_hf)
+        .to(PyTorchDtypes(dtype).value)
+    )
+    # Need to match attention implementation
+    hf_model.config._attn_implementation = _map_attention_type(
+        dcp_modalities_config["model_raw" if "model_raw" in dcp_modalities_config else "model"]["config"]
+    )
+    # Rotary embedding frequencies are not downcasted in FSDP2.
+    # Therefore, we need to ensure they remain in the original precision.
+    hf_model.model.rotary_emb.inv_freq = hf_model.model.rotary_emb.original_inv_freq.to(
+        hf_model.model.rotary_emb.inv_freq.device
+    )
+
+    return hf_model
+
+
+def _check_conversion_criteria(model_config: ConfigDictType) -> None:
     """Checks that the modalities config fulfills criteria necessary for conversion
 
     Args:
-        model_config (dict): model or model_raw part of the Modalities config dictionary.
+        model_config (ConfigDictType): model or model_raw part of the Modalities config dictionary.
 
     Returns:
         None
@@ -116,12 +215,12 @@ def _check_conversion_criteria(model_config: dict) -> None:
     ), "All norms must have the same eps setting."
 
 
-def _get_layer_norm_value(config: dict, field: str) -> bool | float | int:
+def _get_layer_norm_value(config: ConfigDictType, field: str) -> bool | float | int:
     default = LayerNormConfig.model_fields[field].default
     return config.get(field, default)
 
 
-def _map_attention_type(config: dict):
+def _map_attention_type(config: ConfigDictType) -> str:
     if config["attention_implementation"] == "pytorch_flash":
         attention_impl = "sdpa"
     elif config["attention_implementation"] == "manual":
diff --git a/src/modalities/conversion/gpt2/convert_gpt2.py b/src/modalities/conversion/gpt2/convert_gpt2.py
index a137ff09e..8fe54f972 100644
--- a/src/modalities/conversion/gpt2/convert_gpt2.py
+++ b/src/modalities/conversion/gpt2/convert_gpt2.py
@@ -1,11 +1,12 @@
 """
 usage: convert_gpt2.py [-h] [--num_testruns NUM_TESTRUNS] [--device_modalities DEVICE_MODALITIES]
-                       [--device_hf DEVICE_HF] modalities_config output_dir
+                       [--device_hf DEVICE_HF] [--dcp] [--model_key MODEL_KEY]
+                       modalities_input output_dir
 
 Convert GPT-2 model checkpoint to Huggingface transformers format.
 
 positional arguments:
-  modalities_config     Path to the modalities config file.
+  modalities_input      Path to the modalities config file or the dcp checkpoint dir.
   output_dir            Directory to save the converted model.
 
 options:
@@ -16,21 +17,99 @@
                         Device for the modalities model.
   --device_hf DEVICE_HF
                         Device for the Hugging Face model.
+  --dcp                 Indicates that the provided modalities checkpoint is in DCP format.
+  --model_key MODEL_KEY Key of the model configuration in the modalities config.
 """
 
 import argparse
+import gc
 import logging
 import os
 from pathlib import Path
+from tempfile import TemporaryDirectory
 
+import torch
+
+from modalities.checkpointing.convert_dcp_to_torch import convert_dcp_to_torch
 from modalities.config.config import load_app_config_dict
 from modalities.conversion.gpt2.conversion_code import transfer_model_code
-from modalities.conversion.gpt2.conversion_model import check_converted_model, convert_model_checkpoint
+from modalities.conversion.gpt2.conversion_model import (
+    check_converted_dcp_model,
+    check_converted_model,
+    convert_model_checkpoint,
+)
 from modalities.conversion.gpt2.conversion_tokenizer import convert_tokenizer
 
 logger = logging.getLogger(__name__)
 
 
+def main():
+    _ensure_logging()
+
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["RANK"] = "0"
+
+    parser = argparse.ArgumentParser(description="Convert GPT-2 model checkpoint to Huggingface transformers format.")
+    parser.add_argument(
+        "modalities_input", type=str, help="Path to the modalities config file or the dcp checkpoint dir."
+    )
+    parser.add_argument("output_dir", type=str, help="Directory to save the converted model.")
+    parser.add_argument("--num_testruns", type=int, default=0, help="Number of test runs to perform.")
+    parser.add_argument("--device_modalities", type=str, default="cpu", help="Device for the modalities model.")
+    parser.add_argument("--device_hf", type=str, default="cpu", help="Device for the Hugging Face model.")
+    parser.add_argument(
+        "--dcp", action="store_true", help="Indicates that the provided modalities checkpoint is in DCP format."
+    )
+    parser.add_argument(
+        "--model_key", type=str, default="model_raw", help="Key of the model configuration in the modalities config."
+    )
+
+    args = parser.parse_args()
+
+    logger.info("Starting GPT-2 conversion script...")
+    if args.dcp:
+        convert_gpt2_dcp(
+            args.modalities_input,
+            args.output_dir,
+            args.num_testruns,
+            args.device_modalities,
+            args.device_hf,
+            args.model_key,
+        )
+    else:
+        convert_gpt2(
+            args.modalities_input,
+            args.output_dir,
+            args.num_testruns,
+            args.device_modalities,
+            args.device_hf,
+        )
+
+
+def convert_gpt2_dcp(
+    distributed_cp_dir: str,
+    output_dir: str,
+    num_testruns: int = 0,
+    device_id_modalities: str | int = 0,
+    device_hf: str = "cpu",
+    model_key: str = "model_raw",
+) -> None:
+    with TemporaryDirectory() as temp_dir:
+        logger.info("Converting DCP checkpoint to standard PyTorch checkpoint...")
+        modalities_config_path = convert_dcp_to_torch(distributed_cp_dir, temp_dir, model_key=model_key)
+        logger.info("Converting standard PyTorch checkpoint to Huggingface transformers format...")
+        convert_gpt2(modalities_config_path, output_dir)
+    # Clear GPU and CPU memory before running tests
+    torch.cuda.empty_cache()
+    gc.collect()
+
+    check_converted_dcp_model(
+        output_dir, distributed_cp_dir, num_testruns, device_id_modalities=device_id_modalities, device_hf=device_hf
+    )
+
+
 def convert_gpt2(
     modalities_config_path: str,
     output_dir: str,
@@ -77,10 +156,6 @@ def convert_gpt2(
     elif len(sentence_piece_tokenizer_configs) == 1:
         tokenizer_model = modalities_config["tokenizer"]["config"]["tokenizer_model_file"]
         bos_token_id, eos_token_id, pad_token_id, _ = convert_tokenizer(tokenizer_model, output_dir)
-        # The values bos=1, eos=2 and pad=None are set by default in the model config (as taken from Llama).
-        # Overwrite them, with the actual values from the internal SentencePiece tokenizer.
-        # Note, that the LlamaTokenizer wrapping around the SentencePiece tokenizer does not know about these values.
-        # The unk token id is not set in the model config.
         hf_model.config.bos_token_id = bos_token_id
         hf_model.config.eos_token_id = eos_token_id
         hf_model.config.pad_token_id = pad_token_id
@@ -95,24 +170,15 @@ def convert_gpt2(
     transfer_model_code(output_dir)
 
 
-if __name__ == "__main__":
-    os.environ["LOCAL_RANK"] = "0"
-    os.environ["WORLD_SIZE"] = "1"
-    os.environ["RANK"] = "0"
-
-    parser = argparse.ArgumentParser(description="Convert GPT-2 model checkpoint to Huggingface transformers format.")
-    parser.add_argument("modalities_config", type=str, help="Path to the modalities config file.")
-    parser.add_argument("output_dir", type=str, help="Directory to save the converted model.")
-    parser.add_argument("--num_testruns", type=int, default=0, help="Number of test runs to perform.")
-    parser.add_argument("--device_modalities", type=str, default="cpu", help="Device for the modalities model.")
-    parser.add_argument("--device_hf", type=str, default="cpu", help="Device for the Hugging Face model.")
+def _ensure_logging():
+    if not logger.hasHandlers():
+        handler = logging.StreamHandler()
+        handler.setLevel(logging.INFO)
+        formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+    logger.setLevel(logging.INFO)
 
-    args = parser.parse_args()
-    convert_gpt2(
-        args.modalities_config,
-        args.output_dir,
-        args.num_testruns,
-        args.device_modalities,
-        args.device_hf,
-    )
+
+if __name__ == "__main__":
+    main()
diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py
index 0a846b38a..10efe42ee 100644
--- a/src/modalities/models/gpt2/gpt2_model.py
+++ b/src/modalities/models/gpt2/gpt2_model.py
@@ -175,7 +175,7 @@ def _update_cos_sin_tables(self, x):
         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
             self._seq_len_cached = seq_len
             t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
             self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
             self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
diff --git a/src/modalities/models/utils.py b/src/modalities/models/utils.py
index 69e0c9e37..e33ecce97 100644
--- a/src/modalities/models/utils.py
+++ b/src/modalities/models/utils.py
@@ -1,9 +1,11 @@
 from enum import Enum
 
+import torch.nn as nn
 from pydantic import BaseModel
 
 from modalities.config.component_factory import ComponentFactory
-from modalities.config.pydantic_if_types import PydanticPytorchModuleType
+from modalities.config.config import ConfigDictType
+from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticPytorchModuleType
 from modalities.registry.components import COMPONENTS
 from modalities.registry.registry import Registry
@@ -15,22 +17,24 @@ class ModelTypeEnum(Enum):
     Attributes:
         MODEL (str): Represents a regular model.
         CHECKPOINTED_MODEL (str): Represents a checkpointed model.
+        DCP_CHECKPOINTED_MODEL (str): Represents a distributed checkpointed model.
     """
 
     MODEL = "model"
     CHECKPOINTED_MODEL = "checkpointed_model"
+    DCP_CHECKPOINTED_MODEL = "dcp_checkpointed_model"
 
 
-def get_model_from_config(config: dict, model_type: ModelTypeEnum):
+def get_model_from_config(config: ConfigDictType, model_type: ModelTypeEnum) -> nn.Module:
     """
     Retrieves a model from the given configuration based on the specified model type.
 
     Args:
-        config (dict): The configuration dictionary.
+        config (ConfigDictType): The configuration dictionary.
         model_type (ModelTypeEnum): The type of the model to retrieve.
 
     Returns:
-        Any: The model object based on the specified model type.
+        nn.Module: The model object based on the specified model type.
 
     Raises:
         NotImplementedError: If the model type is not supported.
@@ -49,6 +53,15 @@ class PydanticConfig(BaseModel):
         class PydanticConfig(BaseModel):
             checkpointed_model: PydanticPytorchModuleType
 
+    elif model_type.value == "dcp_checkpointed_model":
+
+        class PydanticConfig(BaseModel):
+            app_state: PydanticAppStateType
+
+            @property
+            def dcp_checkpointed_model(self) -> PydanticPytorchModuleType:
+                return self.app_state.model
+
     else:
         raise NotImplementedError()
diff --git a/src/modalities/running_env/cuda_env.py b/src/modalities/running_env/cuda_env.py
index b58efc7f4..ac6c8854c 100644
--- a/src/modalities/running_env/cuda_env.py
+++ b/src/modalities/running_env/cuda_env.py
@@ -6,6 +6,7 @@
 import torch.distributed as dist
 
 from modalities.config.config import ProcessGroupBackendType
+from modalities.utils.env import EnvOverride
 
 
 class CudaEnv:
@@ -15,14 +16,18 @@ def __init__(
         self,
         process_group_backend: ProcessGroupBackendType,
         timeout_s: int = 600,
+        **process_group_kwargs: Any,
     ) -> None:
         """Initializes the CudaEnv context manager with the process group backend.
 
         Args:
             process_group_backend (ProcessGroupBackendType): Process group backend to be used for distributed training.
+            timeout_s (int, optional): Timeout in seconds for process group initialization. Defaults to 600.
+            **process_group_kwargs: Additional keyword arguments for process group initialization.
         """
         self.process_group_backend = process_group_backend
         self._timeout_s = timeout_s
+        self._process_group_kwargs = process_group_kwargs
 
     def __enter__(self) -> "CudaEnv":
         """Sets the CUDA environment for distributed training.
@@ -30,7 +35,9 @@ def __enter__(self) -> "CudaEnv":
         Returns:
             CudaEnv: Instance of the CudaEnv context manager.
         """
-        dist.init_process_group(self.process_group_backend.value, timeout=timedelta(seconds=self._timeout_s))
+        dist.init_process_group(
+            self.process_group_backend.value, timeout=timedelta(seconds=self._timeout_s), **self._process_group_kwargs
+        )
         local_rank = int(os.getenv("LOCAL_RANK", "-1"))
         if local_rank == -1:
             raise ValueError("LOCAL_RANK environment variable is not set. Please set it before using CudaEnv.")
@@ -55,3 +62,40 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                 dist.destroy_process_group()
             except Exception as e:
                 print(f"[Rank {local_rank}] Error during process group cleanup: {e}")
+
+
+class MultiProcessingCudaEnv(CudaEnv):
+    """Context manager to set the CUDA environment for distributed training."""
+
+    def __init__(
+        self,
+        process_group_backend: ProcessGroupBackendType,
+        global_rank: int,
+        local_rank: int,
+        world_size: int,
+        rdvz_port: int,
+        timeout_s: int = 600,
+        **process_group_kwargs: Any,
+    ) -> None:
+        super().__init__(process_group_backend=process_group_backend, timeout_s=timeout_s, **process_group_kwargs)
+        self._env_override = EnvOverride(
+            {
+                "MASTER_ADDR": "localhost",
+                "MASTER_PORT": str(rdvz_port),
+                "RANK": str(global_rank),
+                "LOCAL_RANK": str(local_rank),
+                "WORLD_SIZE": str(world_size),
+            }
+        )
+
+    def __enter__(self) -> "MultiProcessingCudaEnv":
+        # Set environment overrides
+        self._env_override.__enter__()
+        # Initialize CUDA environment
+        super().__enter__()
+        return self
+
+    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any | None):
+        # Restore original environment variables
+        self._env_override.__exit__(exc_type, exc_val, exc_tb)
+        super().__exit__(exc_type, exc_val, exc_tb)
diff --git a/src/modalities/utils/debug.py b/src/modalities/utils/debug.py
index fe7461f2b..d31dac06a 100644
--- a/src/modalities/utils/debug.py
+++ b/src/modalities/utils/debug.py
@@ -97,4 +97,6 @@ def print_forward_hook(
     )
     if not print_shape_only:
         print(f">>> Input:\n{input}")
+        if hasattr(module, "weight"):
+            print(f">>> Weights:\n{module.weight}")
     print(f">>> Output:\n{output}")
diff --git a/src/modalities/utils/env.py b/src/modalities/utils/env.py
new file mode 100644
index 000000000..0bad721c9
--- /dev/null
+++ b/src/modalities/utils/env.py
@@ -0,0 +1,20 @@
+import os
+from typing import Any
+
+
+class EnvOverride:
+    def __init__(self, overrides: dict[str, str]):
+        self._overrides = overrides
+        self._original: dict[str, str | None] = {}
+
+    def __enter__(self):
+        for key, value in self._overrides.items():
+            self._original[key] = os.environ.get(key)
+            os.environ[key] = value
+
+    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any | None):
+        for key, value in self._original.items():
+            if value is None:
+                os.environ.pop(key, None)
+            else:
+                os.environ[key] = value
diff --git a/tests/checkpointing/test_fsdp1_to_disc_checkpointing.py b/tests/checkpointing/test_fsdp1_to_disc_checkpointing.py
index 5b5f0428b..03bb2e70e 100644
--- a/tests/checkpointing/test_fsdp1_to_disc_checkpointing.py
+++ b/tests/checkpointing/test_fsdp1_to_disc_checkpointing.py
@@ -20,10 +20,10 @@
 from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2LLMConfig
 from modalities.models.model_factory import ModelFactory
 from modalities.optimizers.optimizer_factory import OptimizerFactory
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.running_env.env_utils import MixedPrecisionSettings
 from modalities.training.training_progress import TrainingProgress
 from tests.checkpointing.checkpointing_test_utils import CheckpointingTestUtils
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
 def get_gpt2_model(gpt2_model_config_dict: GPT2LLMConfig) -> GPT2LLM:
diff --git a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py
index 6ecd65751..2634f7dc0 100644
--- a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py
+++ b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py
@@ -21,9 +21,9 @@
 from modalities.checkpointing.stateful.app_state import AppState
 from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
 from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticPipelineType
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.training.training_progress import TrainingProgress
 from tests.checkpointing.checkpointing_test_utils import CheckpointingTestUtils
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 from tests.utility import monitor_child_processes
 
 
diff --git a/tests/conversion/gpt2/conftest.py b/tests/conversion/gpt2/conftest.py
index 611b1d92f..ebd27c847 100644
--- a/tests/conversion/gpt2/conftest.py
+++ b/tests/conversion/gpt2/conftest.py
@@ -1,14 +1,30 @@
+import logging
+import multiprocessing as py_mp
 import os
 import shutil
+import traceback
+from multiprocessing import Queue
+from multiprocessing.managers import ListProxy
 from pathlib import Path
 
 import pytest
 import torch
-
-from modalities.config.config import load_app_config_dict
+import torch.multiprocessing as mp
+from pydantic import BaseModel
+
+from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction
+from modalities.checkpointing.fsdp.fsdp_checkpoint_saving import DCPCheckpointSaving
+from modalities.config.component_factory import ComponentFactory
+from modalities.config.config import ConfigDictType, ProcessGroupBackendType, load_app_config_dict
+from modalities.config.pydantic_if_types import PydanticAppStateType
 from modalities.models.gpt2.gpt2_model import GPT2LLM
 from modalities.models.utils import ModelTypeEnum, get_model_from_config
+from modalities.registry.components import COMPONENTS
+from modalities.registry.registry import Registry
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
+from modalities.training.training_progress import TrainingProgress
 from tests.conftest import _ROOT_DIR
+from tests.utility import find_free_port, monitor_child_processes
 
 
 @pytest.fixture
@@ -43,7 +59,7 @@ def corrupt_model_head_key_in_state_dict(request: pytest.FixtureRequest) -> bool
 
 
 @pytest.fixture()
-def initialized_model(set_env, modalities_config_dict: dict) -> GPT2LLM:
+def initialized_model(set_env: None, modalities_config_dict: ConfigDictType) -> GPT2LLM:
     model = get_model_from_config(config=modalities_config_dict, model_type=ModelTypeEnum.MODEL)
     assert isinstance(model, GPT2LLM)
     return model
@@ -57,7 +73,7 @@ def set_env():
 
 
 @pytest.fixture()
-def modalities_config_dict(config_file_path: Path) -> dict:
+def modalities_config_dict(config_file_path: Path) -> ConfigDictType:
     return load_app_config_dict(config_file_path=config_file_path)
 
 
@@ -67,6 +83,116 @@ def config_file_path(config_file_name: str) -> Path:
     return config_file_path
 
 
-@pytest.fixture(params=["gpt2_config_test.yaml"])
-def config_file_name(request) -> str:
-    return request.param
+@pytest.fixture()
+def config_file_name() -> str:
+    return "gpt2_config_test.yaml"
+
+
+@pytest.fixture()
+def dcp_checkpoint(tmpdir_factory: pytest.TempdirFactory, corrupt_model_head_key_in_state_dict: bool) -> str:
+    tmp_path = tmpdir_factory.mktemp("dcp_checkpoint_test")
+    config_file = _ROOT_DIR / "tests" / "conversion" / "test_configs" / "gpt2_dcp_config.yaml"
"gpt2_dcp_config.yaml" + world_size = 8 + port = find_free_port() + manager = py_mp.Manager() + try: + error_queue = manager.Queue() + return_list = manager.list([None] * world_size) + + proc_ctx = mp.spawn( + _create_dcp_checkpoint_worker, + args=( + world_size, + port, + tmp_path, + corrupt_model_head_key_in_state_dict, + config_file, + error_queue, + return_list, + ), + nprocs=world_size, + join=False, + ) + + monitor_child_processes(manager, error_queue, proc_ctx, shutdown_manager=False) + + checkpoint_path = return_list[0] + if checkpoint_path is None: + raise RuntimeError("DCP checkpoint creation failed.") + + finally: + manager.shutdown() + + yield checkpoint_path + + +def _create_dcp_checkpoint_worker( + device_idx: int, + world_size: int, + port: int, + output_dir: str, + corrupt_model_head_key_in_state_dict: bool, + config_file: str, + error_queue: Queue, + return_list: ListProxy, +): + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=device_idx, + local_rank=device_idx, + world_size=world_size, + rdvz_port=port, + ): + try: + modalities_config_dict = load_app_config_dict(config_file_path=config_file) + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + + class Components(BaseModel): + app_state: PydanticAppStateType + + components: Components = component_factory.build_components( + config_dict=modalities_config_dict, components_model_type=Components + ) + model: GPT2LLM = components.app_state.model + if corrupt_model_head_key_in_state_dict and hasattr(model.transformer, "lm_head"): + # Rename the key transformer.lm_head.weight to old_lm_head.weight + # simulating the old format used in modalities' gpt2 models. + model.transformer["old_lm_head"] = model.transformer.lm_head + del model.transformer["lm_head"] + + experiment_id = "0" + checkpoint_saving_execution = DCPCheckpointSaving( + checkpoint_path=Path(output_dir), experiment_id=experiment_id, global_rank=device_idx + ) + + checkpointing_instruction = CheckpointingInstruction(save_current=True, checkpoints_to_delete=[]) + training_progress = TrainingProgress( + num_seen_steps_current_run=0, + num_seen_tokens_current_run=0, + num_target_steps=16, # dummy value + num_target_tokens=256, # dummy value + ) + checkpoint_saving_execution.run_checkpoint_instruction( + checkpointing_instruction, training_progress, components.app_state + ) + # FIXME: Hack to get the checkpoint folder path + full_path = checkpoint_saving_execution._get_checkpointing_folder_path( + experiment_id=experiment_id, + num_seen_steps=training_progress.num_seen_steps_current_run, + num_seen_tokens=training_progress.num_seen_tokens_current_run, + num_target_steps=training_progress.num_target_steps, + num_target_tokens=training_progress.num_target_tokens, + ) + # Copy yaml config file to output dir + shutil.copy(config_file, Path(full_path) / "config.yaml") + return_list[device_idx] = full_path + except Exception as e: + tb = traceback.format_exc() + logging.error(f"Process {device_idx} encountered an error:\n{e}") + logging.error(tb) + try: + error_queue.put((device_idx, tb)) + except Exception: + logging.error("Failed to put exception info into error queue.") + os._exit(1) diff --git a/tests/conversion/gpt2/helper.py b/tests/conversion/gpt2/helper.py index 99adbacbc..2eeb333c3 100644 --- a/tests/conversion/gpt2/helper.py +++ b/tests/conversion/gpt2/helper.py @@ -1,13 +1,15 @@ import torch import torch.nn as nn +from torch.distributed.tensor import DTensor from 
 from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block
 
 
+@torch.no_grad()
 def check_same_weight_model(converted_model: GPT2ForCausalLM, modalities_model: GPT2LLM):
     converted_model.to(device=modalities_model.transformer.h["0"].attn.q_attn.weight.device)
-    assert torch.equal(converted_model.model.embed_tokens.weight, modalities_model.transformer.wte.weight)
+    assert torch.equal(converted_model.model.embed_tokens.weight, to_local(modalities_model.transformer.wte.weight))
     for i, (llama_layer, modalities_layer_idx) in enumerate(
         zip(converted_model.model.layers, modalities_model.transformer.h)
     ):
@@ -37,5 +39,13 @@ def check_same_weight_layer_norms(llama_layer: GPT2DecoderLayer, modalities_laye
 
 
 def check_same_weight_base_modules(l1: nn.Linear | nn.LayerNorm, l2: nn.Linear | nn.LayerNorm):
-    assert torch.equal(l1.weight, l2.weight)
-    assert (l1.bias is None and l2.bias is None) or torch.equal(l1.bias, l2.bias)
+    assert torch.equal(l1.weight, to_local(l2.weight))
+    assert (l1.bias is None and l2.bias is None) or torch.equal(l1.bias, to_local(l2.bias))
+
+
+@torch.no_grad()
+def to_local(tensor: torch.Tensor | DTensor) -> torch.Tensor:
+    """Convert a tensor or distributed tensor to a local tensor."""
+    if isinstance(tensor, DTensor):
+        return tensor.to_local()
+    return tensor
diff --git a/tests/conversion/gpt2/test_convert_gpt2.py b/tests/conversion/gpt2/test_convert_gpt2.py
index 74569f94c..548b318f4 100644
--- a/tests/conversion/gpt2/test_convert_gpt2.py
+++ b/tests/conversion/gpt2/test_convert_gpt2.py
@@ -4,11 +4,16 @@
 import torch
 from transformers import AutoModelForCausalLM, PreTrainedModel
 
-from modalities.config.config import load_app_config_dict
-from modalities.conversion.gpt2.conversion_model import check_converted_model
-from modalities.conversion.gpt2.convert_gpt2 import convert_gpt2
+from modalities.config.config import ConfigDictType, ProcessGroupBackendType, load_app_config_dict
+from modalities.conversion.gpt2.conversion_model import (
+    _build_single_node_dcp_config,
+    check_converted_dcp_model,
+    check_converted_model,
+)
+from modalities.conversion.gpt2.convert_gpt2 import convert_gpt2, convert_gpt2_dcp
 from modalities.models.gpt2.gpt2_model import GPT2LLM
 from modalities.models.utils import ModelTypeEnum, get_model_from_config
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from tests.conversion.gpt2.helper import check_same_weight_model
@@ -24,6 +29,21 @@ def test_converting_gpt2_does_not_change_outputs(
     )
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 8, reason="This test requires 8 GPUs.")
+def test_converting_dcp_gpt2_does_not_change_weights(converted_dcp_model: PreTrainedModel, dcp_checkpoint: str):
+    new_config: ConfigDictType = _build_single_node_dcp_config(dcp_checkpoint)
+    with MultiProcessingCudaEnv(ProcessGroupBackendType.nccl, 0, 0, 1, 24570, device_id=0):
+        modalities_model = get_model_from_config(new_config, model_type=ModelTypeEnum.DCP_CHECKPOINTED_MODEL)
+        check_same_weight_model(converted_dcp_model, modalities_model)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 8, reason="This test requires 8 GPUs.")
+def test_converting_dcp_gpt2_does_not_change_outputs(run_convert_gpt2_dcp: None, output_dir: Path, dcp_checkpoint: str):
+    check_converted_dcp_model(
+        hf_model_dir=str(output_dir), dcp_dir=dcp_checkpoint, num_testruns=1, device_id_modalities=0, device_hf="cuda:1"
+    )
+
+
 @pytest.fixture
 def converted_model(run_convert_gpt2: None, output_dir: Path) -> PreTrainedModel:
     return AutoModelForCausalLM.from_pretrained(output_dir, local_files_only=True, trust_remote_code=True).to(
@@ -31,11 +51,21 @@ def converted_model(run_convert_gpt2: None, output_dir: Path) -> PreTrainedModel
     )
 
 
+@pytest.fixture
+def converted_dcp_model(run_convert_gpt2_dcp: None, output_dir: Path) -> PreTrainedModel:
+    return AutoModelForCausalLM.from_pretrained(output_dir, local_files_only=True, trust_remote_code=True)
+
+
 @pytest.fixture
 def run_convert_gpt2(gpt2_config_path: Path, output_dir: Path):
     convert_gpt2(str(gpt2_config_path), str(output_dir))
 
 
+@pytest.fixture
+def run_convert_gpt2_dcp(dcp_checkpoint: str, output_dir: Path):
+    convert_gpt2_dcp(dcp_checkpoint, str(output_dir))
+
+
 @pytest.fixture
 def original_model(gpt2_config_path: Path) -> GPT2LLM:
     modalities_config = load_app_config_dict(gpt2_config_path)
diff --git a/tests/dataloader/distributed/test_distributed_multidim_dataloader.py b/tests/dataloader/distributed/test_distributed_multidim_dataloader.py
index a2767bae4..98796a968 100644
--- a/tests/dataloader/distributed/test_distributed_multidim_dataloader.py
+++ b/tests/dataloader/distributed/test_distributed_multidim_dataloader.py
@@ -8,10 +8,10 @@
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.dataloader.dataloader_factory import DataloaderFactory
 from modalities.dataloader.sampler_factory import SamplerFactory
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_device_mesh, get_mesh_for_parallelism_method
 from tests.dataloader.distributed.mocks import MultiProcessingCudaEnvMock
 from tests.dataloader.dummy_sequential_dataset import TestDataset
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 from tests.mocks import MockDeviceMesh
 from tests.utility import find_free_port, tensors_equal_across_mesh, tensors_pairwise_not_equal_across_mesh
diff --git a/tests/end2end_tests/custom_components.py b/tests/end2end_tests/custom_components.py
index 57f51a5eb..143f7076e 100644
--- a/tests/end2end_tests/custom_components.py
+++ b/tests/end2end_tests/custom_components.py
@@ -1,13 +1,10 @@
-import os
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel
 
 from modalities.batch import EvaluationResultBatch
-from modalities.config.config import ProcessGroupBackendType
 from modalities.logging_broker.messages import Message
 from modalities.logging_broker.subscriber import MessageSubscriberIF
-from modalities.running_env.cuda_env import CudaEnv
 
 
 class SaveAllResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]):
@@ -24,48 +21,3 @@ def consume_dict(self, message_dict: dict[str, Any]):
 
 
 class SaveAllResultSubscriberConfig(BaseModel):
     pass
-
-
-class MultiProcessingCudaEnv(CudaEnv):
-    """Context manager to set the CUDA environment for distributed training."""
-
-    def __init__(
-        self,
-        process_group_backend: ProcessGroupBackendType,
-        global_rank: int,
-        local_rank: int,
-        world_size: int,
-        rdvz_port: int,
-        timeout_s: int = 600,
-    ) -> None:
-        super().__init__(process_group_backend=process_group_backend, timeout_s=timeout_s)
-        self.global_rank = global_rank
-        self.local_rank = local_rank
-        self.world_size = world_size
-        self.rdvz_port = rdvz_port
-        self._original_env: dict[str, Optional[str]] = {}
-
-    def __enter__(self):
-        # Store original values
-        for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "LOCAL_RANK", "WORLD_SIZE"]:
-            self._original_env[key] = os.environ.get(key)
-
-        # Set new environment variables
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = str(self.rdvz_port)
-        os.environ["RANK"] = str(self.global_rank)
-        os.environ["LOCAL_RANK"] = str(self.local_rank)
-        os.environ["WORLD_SIZE"] = str(self.world_size)
-
-        # Initialize CUDA environment
-        super().__enter__()
-        return self
-
-    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any | None):
-        # Restore original environment variables
-        for key, value in self._original_env.items():
-            if value is None:
-                os.environ.pop(key, None)
-            else:
-                os.environ[key] = value
-        super().__exit__(exc_type, exc_val, exc_tb)
diff --git a/tests/end2end_tests/system_tests/test_fsdp_loss_convergence.py b/tests/end2end_tests/system_tests/test_fsdp_loss_convergence.py
index f374cb55d..5bf36c1de 100644
--- a/tests/end2end_tests/system_tests/test_fsdp_loss_convergence.py
+++ b/tests/end2end_tests/system_tests/test_fsdp_loss_convergence.py
@@ -10,11 +10,8 @@
 from modalities.config.config import ProcessGroupBackendType
 from modalities.config.instantiation_models import TrainingComponentsInstantiationModel
 from modalities.logging_broker.messages import Message
-from tests.end2end_tests.custom_components import (
-    MultiProcessingCudaEnv,
-    SaveAllResultSubscriber,
-    SaveAllResultSubscriberConfig,
-)
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
+from tests.end2end_tests.custom_components import SaveAllResultSubscriber, SaveAllResultSubscriberConfig
 
 
 @pytest.mark.skipif(
diff --git a/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py b/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py
index 31c05ca2a..836f305b9 100644
--- a/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py
+++ b/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py
@@ -21,11 +21,8 @@
 from modalities.config.pydantic_if_types import PydanticLLMDataLoaderIFType
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.logging_broker.messages import Message
-from tests.end2end_tests.custom_components import (
-    MultiProcessingCudaEnv,
-    SaveAllResultSubscriber,
-    SaveAllResultSubscriberConfig,
-)
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
+from tests.end2end_tests.custom_components import SaveAllResultSubscriber, SaveAllResultSubscriberConfig
 from tests.utility import monitor_child_processes
 
 working_dir = Path(os.path.dirname(__file__))
diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py
index 9534153a3..9c0c068ee 100644
--- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py
+++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py
@@ -17,8 +17,8 @@
     PydanticPipelineType,
 )
 from modalities.models.parallelism.pipeline_parallelism import Pipeline
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
 class ComponentsInstantiationPPModel(BaseModel):
diff --git a/tests/fsdp2_parallelization/test_full_and_hybrid_sharding.py b/tests/fsdp2_parallelization/test_full_and_hybrid_sharding.py
index 06b600cab..702950b25 100644
--- a/tests/fsdp2_parallelization/test_full_and_hybrid_sharding.py
+++ b/tests/fsdp2_parallelization/test_full_and_hybrid_sharding.py
@@ -11,8 +11,8 @@
 from modalities.__main__ import Main
 from modalities.config.config import ProcessGroupBackendType
 from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.util import get_local_number_of_trainable_parameters, get_total_number_of_trainable_parameters
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
 @pytest.fixture
diff --git a/tests/fsdp2_parallelization/test_tensor_parallelism.py b/tests/fsdp2_parallelization/test_tensor_parallelism.py
index f611b7164..facac217b 100644
--- a/tests/fsdp2_parallelization/test_tensor_parallelism.py
+++ b/tests/fsdp2_parallelization/test_tensor_parallelism.py
@@ -16,7 +16,7 @@
 from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType
 from modalities.models.gpt2.gpt2_model import TransformerMLP
 from modalities.models.model import SwiGLU
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 
 
 def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path:
diff --git a/tests/instruction_tuning/test_e2e_instruction_tuning.py b/tests/instruction_tuning/test_e2e_instruction_tuning.py
index 4b398c7e0..e185f1988 100644
--- a/tests/instruction_tuning/test_e2e_instruction_tuning.py
+++ b/tests/instruction_tuning/test_e2e_instruction_tuning.py
@@ -20,9 +20,9 @@
     create_partitioned_instruction_tuning_index_and_pbin_files,
 )
 from modalities.dataloader.dataset_factory import DatasetFactory
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer
 from tests.conftest import _ROOT_DIR
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
 @pytest.mark.skipif(
diff --git a/tests/test_initialization_fsdpx.py b/tests/test_initialization_fsdpx.py
index 9d9bac435..673b80b6e 100644
--- a/tests/test_initialization_fsdpx.py
+++ b/tests/test_initialization_fsdpx.py
@@ -20,7 +20,7 @@
 from modalities.__main__ import Main
 from modalities.config.config import ProcessGroupBackendType
 from modalities.config.pydantic_if_types import PydanticFSDP1ModuleType, PydanticFSDP2ModuleType
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 
 
 @dataclass
diff --git a/tests/test_optimizer_factory.py b/tests/test_optimizer_factory.py
index d6c9a7b37..aeb2f6f28 100644
--- a/tests/test_optimizer_factory.py
+++ b/tests/test_optimizer_factory.py
@@ -17,9 +17,9 @@
 from modalities.optimizers.optimizer_factory import get_optimizer_groups
 from modalities.registry.components import COMPONENTS
 from modalities.registry.registry import Registry
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.running_env.env_utils import MixedPrecisionSettings
 from tests.conftest import _ROOT_DIR
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 from tests.utility import find_free_port
 
 # number of parameters for each optimizer group
diff --git a/tests/test_util.py b/tests/test_util.py
index f71450982..dece6fb0e 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -11,9 +11,9 @@
 from modalities.__main__ import Main
 from modalities.config.config import ProcessGroupBackendType
 from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticDeviceMeshIFType
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.util import get_local_number_of_trainable_parameters, get_total_number_of_trainable_parameters
 from modalities.utils.typing_utils import FSDPX
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 from tests.utility import find_free_port
 
 
diff --git a/tests/training/test_activation_checkpointing.py b/tests/training/test_activation_checkpointing.py
index 9d1146a2b..851bd0dfd 100644
--- a/tests/training/test_activation_checkpointing.py
+++ b/tests/training/test_activation_checkpointing.py
@@ -13,7 +13,7 @@
 from modalities.config.config import ProcessGroupBackendType
 from modalities.config.pydantic_if_types import PydanticPytorchModuleType
 from modalities.models.gpt2.gpt2_model import GPT2Block
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 
 working_dir = Path(os.path.dirname(__file__))
diff --git a/tests/utility.py b/tests/utility.py
index e3e08245b..106e58832 100644
--- a/tests/utility.py
+++ b/tests/utility.py
@@ -77,6 +77,7 @@ def monitor_child_processes(
     manager: SyncManager,
     error_queue: Queue,
     proc_ctx: ProcessContext,
+    shutdown_manager: bool = True,
 ) -> None:
     # Normalize the return value from mp.spawn. When join=False it often
     # returns a ProcessContext-like object that may expose a `processes`
@@ -173,7 +174,8 @@ def monitor_child_processes(
             time.sleep(0.05)
     finally:
         try:
-            manager.shutdown()
+            if shutdown_manager:
+                manager.shutdown()
         except Exception:
             pass
diff --git a/tests/utils/test_communication_test.py b/tests/utils/test_communication_test.py
index 8b91df64c..02bfb8f94 100644
--- a/tests/utils/test_communication_test.py
+++ b/tests/utils/test_communication_test.py
@@ -3,8 +3,8 @@
 import torch.multiprocessing as mp
 
 from modalities.config.config import ProcessGroupBackendType
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.utils.communication_test import run_communication_test
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
 @pytest.mark.skipif(
diff --git a/tests/utils/test_experiment_id_generation.py b/tests/utils/test_experiment_id_generation.py
index 75dbf8c8e..25f0a6267 100644
--- a/tests/utils/test_experiment_id_generation.py
+++ b/tests/utils/test_experiment_id_generation.py
@@ -7,8 +7,8 @@
 import torch.multiprocessing as mp
 
 from modalities.config.config import ProcessGroupBackendType
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.util import get_synced_experiment_id_of_run
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
 @pytest.fixture
diff --git a/tests/utils/test_mfu.py b/tests/utils/test_mfu.py
index adac8f100..2f1e7b84c 100644
--- a/tests/utils/test_mfu.py
+++ b/tests/utils/test_mfu.py
@@ -18,9 +18,9 @@
     PydanticFSDP2ModuleType,
     PydanticMFUCalculatorABCType,
 )
+from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.running_env.env_utils import MixedPrecisionSettings, PyTorchDtypes
 from modalities.utils.mfu import GPT2MFUCalculator, MFUCalculatorABC
-from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
 @pytest.fixture