diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml index 7f4e5fad5..7cd58a6d5 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -194,7 +194,7 @@ device_mesh: config: device_type: cuda data_parallel_replicate_degree: 1 - pipeline_parallel_degree: 2 + pipeline_parallel_degree: 4 data_parallel_shard_degree: -1 world_size: ${settings.cuda_env.world_size} @@ -251,7 +251,7 @@ scheduled_pipeline: loss_fn: instance_key: loss_fn pass_type: BY_REFERENCE - pp_schedule_name: gpipe + pp_schedule_name: Interleaved1F1B batch_size: ${settings.step_profile.local_train_micro_batch_size} microbatch_size: 2 pp_degree: ${device_mesh.config.pipeline_parallel_degree} @@ -318,7 +318,7 @@ staged_pipeline: instance_key: device_mesh pass_type: BY_REFERENCE local_rank: ${settings.cuda_env.local_rank} - pp_schedule_name: gpipe + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} num_layers_per_stage: 2 model_raw: @@ -332,7 +332,7 @@ model_raw: sequence_length: ${settings.step_profile.sequence_length} prediction_key: ${loss_fn.config.prediction_key} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: 2 + n_layer: 6 n_head_q: 8 n_head_kv: 4 ffn_hidden: 128 diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml index d30c6152f..b4982044c 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml @@ -308,7 +308,7 @@ staged_pipeline: instance_key: device_mesh pass_type: BY_REFERENCE local_rank: ${settings.cuda_env.local_rank} - pp_schedule_name: gpipe + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} num_layers_per_stage: 2 model_raw: diff --git a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py index b44367d06..65d053970 100644 --- a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py +++ b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py @@ -89,7 +89,8 @@ def _save_checkpoint(self, app_state: AppState, training_progress: TrainingProgr # saving the model via FULL_STATE_DICT and checkpoint via FULL_OPTIM_STATE_DICT model_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) optim_save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) - model = app_state.model + assert len(app_state.model_parts) == 1, "FSDP1CheckpointSaving only supports a single model part." 
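Note (illustrative sketch, not part of the change; the numbers are made up, not taken from the config): switching the schedule from `gpipe` (one stage per rank) to `Interleaved1F1B` (a multi-stage schedule) means a pipeline rank may own several virtual stages. The "loop"-style assignment mirrors `PipelineFactory._get_stage_ids_of_pp_rank` introduced further down in this diff:

```python
def loop_style_stage_ids(pp_rank: int, pp_size: int, num_stages: int) -> list[int]:
    # Each rank owns every pp_size-th stage, starting at its own rank index.
    stages_per_rank = num_stages // pp_size
    return [pp_rank + s * pp_size for s in range(stages_per_rank)]

# e.g. 8 virtual stages spread over 4 pipeline ranks:
assert loop_style_stage_ids(pp_rank=1, pp_size=4, num_stages=8) == [1, 5]
```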
+ model = app_state.model_parts[0] optimizer = app_state.optimizer with FSDP.state_dict_type( module=model, diff --git a/src/modalities/checkpointing/stateful/app_state.py b/src/modalities/checkpointing/stateful/app_state.py index 6f42074cf..57377bed7 100644 --- a/src/modalities/checkpointing/stateful/app_state.py +++ b/src/modalities/checkpointing/stateful/app_state.py @@ -15,6 +15,8 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +from modalities.optimizers.optimizer_list import OptimizersList + class StatefulComponents(Enum): MODEL = "model" @@ -34,15 +36,18 @@ class AppState(Stateful): https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html """ - def __init__(self, model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None): + def __init__( + self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None + ): """Initializes the AppState object. Args: - model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model. + model (nn.Module | list[nn.Module]): The model or model parts can be either + a non-sharded model, FSDP1 or FSDP2 model. optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer. lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None. """ - self._model = model + self._model_parts = list(model) if isinstance(model, list) else [model] self._optimizer = optimizer self._lr_scheduler = lr_scheduler self._is_loaded = False @@ -56,8 +61,8 @@ def is_loaded(self) -> bool: return self._is_loaded @property - def model(self) -> nn.Module: - return self._model + def model_parts(self) -> list[nn.Module]: + return self._model_parts @property def optimizer(self) -> Optimizer: @@ -153,7 +158,7 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]: class ModelStateRetriever(StateRetrieverIF): @staticmethod def get_state_dict(app_state: AppState) -> dict[str, Any]: - """Returns the state dict of the model in the AppState object. + """Returns the flattened state dicts of the model parts in the AppState object. Args: app_state (AppState): The app_state object containing the model. @@ -161,7 +166,7 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]: Returns: dict[str, Any]: The state dict of the model in the AppState object. """ - return get_model_state_dict(model=app_state.model) + return {k: v for sd in map(get_model_state_dict, app_state.model_parts) for k, v in sd.items()} @staticmethod def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: @@ -171,7 +176,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: app_state (AppState): The app_state object containing the model. state_dict (dict[str, Any]): The state dict to load into the model. """ - set_model_state_dict(model=app_state.model, model_state_dict=state_dict, options=StateDictOptions(strict=False)) + for model in app_state.model_parts: + set_model_state_dict(model=model, model_state_dict=state_dict, options=StateDictOptions(strict=False)) class OptimizerStateRetriever(StateRetrieverIF): @@ -185,13 +191,17 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]: Returns: dict[str, Any]: The state dict of the optimizer in the AppState object. """ - sd = get_optimizer_state_dict( - model=app_state.model, - optimizers=app_state.optimizer, - # NOTE: Flattening is required for pipeline parallelism to work correctly. 
- # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214 - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) + if isinstance(app_state.optimizer, OptimizersList): + sd = app_state.optimizer.state_dict() + else: + assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." + sd = get_optimizer_state_dict( + model=app_state.model_parts[0], + optimizers=app_state.optimizer, + # NOTE: Flattening is required for pipeline parallelism to work correctly. + # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214 + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) return sd @staticmethod @@ -202,12 +212,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: app_state (AppState): The app_state object containing the optimizer. state_dict (dict[str, Any]): The state dict to load into the optimizer. """ - set_optimizer_state_dict( - model=app_state.model, - optimizers=app_state.optimizer, - optim_state_dict=state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) + if isinstance(app_state.optimizer, OptimizersList): + app_state.optimizer.load_state_dict(state_dict) + else: + assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." + set_optimizer_state_dict( + model=app_state.model_parts[0], + optimizers=app_state.optimizer, + optim_state_dict=state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) class LRSchedulerStateRetriever(StateRetrieverIF): diff --git a/src/modalities/checkpointing/stateful/app_state_factory.py b/src/modalities/checkpointing/stateful/app_state_factory.py index bad48d44c..8f6e63d8a 100644 --- a/src/modalities/checkpointing/stateful/app_state_factory.py +++ b/src/modalities/checkpointing/stateful/app_state_factory.py @@ -15,13 +15,14 @@ class AppStateFactory: @staticmethod def get_raw_app_state( - model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None + model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None ) -> AppState: """Creates a new (non-checkpoint loaded) AppState object from an instantiated model, optimizer, and optional learning rate scheduler. Args: - model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model. + model (nn.Module | list[nn.Module]): The model (parts) can be either + a non-sharded model, FSDP1 or FSDP2 model. optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer. lr_scheduler (Optional[LRScheduler], optional): Lr scheduler used during training. Defaults to None. 
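Note (illustrative sketch with hypothetical FQNs and values): the retrievers above flatten the per-part state dicts into a single mapping. This is collision-free because each pipeline stage owns a disjoint set of parameter FQNs.

```python
# Each stage holds different FQNs, so merging the per-part dicts loses nothing.
part_state_dicts = [
    {"transformer.wte.weight": 0, "transformer.h.0.attn.weight": 1},      # stage 0
    {"transformer.h.1.attn.weight": 2, "transformer.lm_head.weight": 3},  # stage 1
]
merged = {k: v for sd in part_state_dicts for k, v in sd.items()}
assert len(merged) == sum(len(sd) for sd in part_state_dicts)  # no key collisions
```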
diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py index e74ad9729..4da0157bd 100644 --- a/src/modalities/config/component_factory.py +++ b/src/modalities/config/component_factory.py @@ -1,6 +1,7 @@ from typing import Any, Type, TypeVar -from pydantic import BaseModel +from pydantic import AliasChoices, BaseModel +from pydantic.fields import FieldInfo from modalities.registry.registry import Registry from modalities.util import print_rank_0 @@ -164,30 +165,53 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co config_dict=config_dict, component_config_type=component_config_type, ) - comp_config = component_config_type(**config_dict, strict=True) + comp_config = component_config_type.model_validate(config_dict, extra="forbid") return comp_config def _assert_valid_config_keys( self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild] ) -> None: - required_keys = [] - optional_keys = [] - for key, field in component_config_type.model_fields.items(): + # Collect required and optional keys, including aliases if defined. + required_keys: list[str] = [] + optional_keys: list[str] = [] + # Map aliases to canonical field names for clearer error messages. + alias_to_field: dict[str, str] = {} + + for field_name, field in component_config_type.model_fields.items(): + names_for_field = self._parse_str_aliases(alias_to_field, field_name, field) if field.is_required(): - required_keys.append(key) + required_keys.extend(names_for_field) else: - optional_keys.append(key) + optional_keys.extend(names_for_field) - invalid_keys = [] - for key in config_dict.keys(): - if key not in required_keys and key not in optional_keys: - invalid_keys.append(key) + all_valid_keys = set(required_keys) | set(optional_keys) + + invalid_keys = [key for key in config_dict.keys() if key not in all_valid_keys] if len(invalid_keys) > 0: message = f"Invalid keys {invalid_keys} for config `{component_key}.{variant_key}`" message += f" of type {component_config_type}:\n{config_dict}\n" - message += f"Required keys: {required_keys}\nOptional keys: {optional_keys}" + if alias_to_field: + message += f"Alias to field mapping: {alias_to_field}\n" + message += f"Required keys (including aliases): {required_keys}\n" + message += f"Optional keys (including aliases): {optional_keys}\n" raise ValueError(message) + def _parse_str_aliases(self, alias_to_field: dict[str, str], field_name: str, field: FieldInfo) -> set[str]: + names_for_field = {field_name} + if field.alias and field.alias != field_name: + names_for_field.add(field.alias) + alias_to_field[field.alias] = field_name + if field.validation_alias and field.validation_alias != field_name: + if isinstance(field.validation_alias, str): + names_for_field.add(field.validation_alias) + alias_to_field[field.validation_alias] = field_name + elif isinstance(field.validation_alias, AliasChoices): + for alias in field.validation_alias.choices: + if isinstance(alias, str): + names_for_field.add(alias) + alias_to_field[alias] = field_name + return names_for_field + def _instantiate_component(self, component_key: str, variant_key: str, component_config: BaseModel) -> Any: component_type: Type = self.registry.get_component(component_key, variant_key) component_config_dict = self._base_model_to_dict(component_config) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 5ae8c0822..37e52a8fa 100644 --- a/src/modalities/config/config.py +++ 
b/src/modalities/config/config.py @@ -27,6 +27,7 @@ PydanticModelInitializationIFType, PydanticOptimizerIFType, PydanticPytorchDeviceType, + PydanticPytorchModuleOrListType, PydanticPytorchModuleType, PydanticSamplerIFType, PydanticTokenizerIFType, @@ -43,6 +44,7 @@ ActivationCheckpointingVariants, ) from modalities.util import parse_enum_by_name +from modalities.utils.deprecated_alias import add_deprecated_alias class ProcessGroupBackendType(LookupEnum): @@ -145,7 +147,7 @@ class CheckpointSavingConfig(BaseModel): class AdamOptimizerConfig(BaseModel): lr: float - wrapped_model: PydanticPytorchModuleType + wrapped_model: PydanticPytorchModuleOrListType betas: tuple[float, float] eps: float weight_decay: float @@ -154,7 +156,7 @@ class AdamOptimizerConfig(BaseModel): class AdamWOptimizerConfig(BaseModel): lr: float - wrapped_model: PydanticPytorchModuleType + wrapped_model: PydanticPytorchModuleOrListType betas: tuple[float, float] eps: float weight_decay: float @@ -264,7 +266,7 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy: class FSDP2WrappedModelConfig(BaseModel): - model: PydanticPytorchModuleType + model: PydanticPytorchModuleOrListType block_names: list[str] mixed_precision_settings: FSDP2MixedPrecisionSettings reshard_after_forward: bool = True @@ -289,7 +291,7 @@ def validate_dp_mesh_existence(self): class DebuggingEnrichedModelConfig(BaseModel): - model: PydanticPytorchModuleType + model: PydanticPytorchModuleOrListType logging_dir_path: Path tracked_ranks: Optional[Set[int]] = None log_interval_steps: Optional[int] = 1 @@ -302,7 +304,7 @@ def convert_list_to_set(cls, v: Iterable[int] | None) -> Set[int] | None: class GPT2ModelTPConfig(BaseModel): - model: PydanticPytorchModuleType # TODO set proper type + model: PydanticPytorchModuleOrListType # TODO set proper type device_mesh: PydanticDeviceMeshIFType @model_validator(mode="after") @@ -325,7 +327,7 @@ class CompiledModelConfig(BaseModel): class WeightInitializedModelConfig(BaseModel): - model: PydanticPytorchModuleType + model: PydanticPytorchModuleOrListType model_initializer: PydanticModelInitializationIFType # avoid warning about protected namespace 'model_', see @@ -350,12 +352,12 @@ class SelectiveOpACParams(BaseModel): ac_variant: ActivationCheckpointingVariants layers_fqn: str - model: PydanticPytorchModuleType + model: PydanticPytorchModuleOrListType ac_fun_params: FullACParams | SelectiveLayerACParams | SelectiveOpACParams class RawAppStateConfig(BaseModel): - model: PydanticPytorchModuleType + model: PydanticPytorchModuleOrListType optimizer: PydanticOptimizerIFType lr_scheduler: Optional[PydanticLRSchedulerIFType] = None @@ -480,12 +482,13 @@ class RichResultSubscriberConfig(BaseModel): global_rank: int +@add_deprecated_alias("model_parts", "wrapped_model") class GPT2MFUCalculatorConfig(BaseModel): n_layer: Annotated[int, Field(strict=True, gt=0)] sequence_length: Annotated[int, Field(strict=True, gt=0)] n_embd: Annotated[int, Field(strict=True, gt=0)] world_size: Annotated[int, Field(strict=True, gt=0)] - wrapped_model: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType + model_parts: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType] device_mesh: Optional[PydanticDeviceMeshIFType] = None diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index 35610246c..c465c6671 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -66,6 +66,7 @@ def 
__get_pydantic_core_schema__( CheckpointSavingExecutionABC, PydanticThirdPartyTypeIF(CheckpointSavingExecutionABC) ] PydanticPytorchModuleType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)] +PydanticPytorchModuleOrListType = PydanticPytorchModuleType | list[PydanticPytorchModuleType] PydanticFSDP1ModuleType = Annotated[FSDP1, PydanticThirdPartyTypeIF(FSDP1)] PydanticFSDP2ModuleType = Annotated[FSDP2, PydanticThirdPartyTypeIF(FSDP2)] PydanticTokenizerIFType = Annotated[TokenizerWrapper, PydanticThirdPartyTypeIF(TokenizerWrapper)] diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 49bdaa11d..cd1470773 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -35,7 +35,7 @@ def __init__( def evaluate_batch( self, batch: DatasetBatch, - model: nn.Module, + model: list[nn.Module], loss_fun: Callable[[InferenceResultBatch], torch.Tensor], scheduled_pipeline: Pipeline | None = None, ) -> torch.Tensor | None: @@ -43,7 +43,7 @@ def evaluate_batch( Args: batch (DatasetBatch): The batch to evaluate - model (nn.Module): The model to evaluate + model (list[nn.Module]): The model (parts) to evaluate loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to operate the model. Defaults to None. @@ -57,27 +57,27 @@ def evaluate_batch( pp_schedule = scheduled_pipeline.pp_schedule targets, losses = ( (batch.targets[loss_fun.target_key].contiguous(), []) - if scheduled_pipeline.is_last_pp_stage + if scheduled_pipeline.has_last_pp_stage else (None, None) ) - if scheduled_pipeline.is_first_pp_stage: - pp_schedule.eval(batch.samples[model.sample_key].contiguous(), target=targets, losses=losses) + if scheduled_pipeline.has_first_pp_stage: + pp_schedule.eval(batch.samples[model[0].sample_key].contiguous(), target=targets, losses=losses) else: pp_schedule.eval(target=targets, losses=losses) loss = ( torch.mean(torch.stack(losses)).to(losses[0].device) - if scheduled_pipeline.is_last_pp_stage + if scheduled_pipeline.has_last_pp_stage else None ) else: - result_batch = model_predict_batch(model=model, batch=batch) + result_batch = model_predict_batch(model=model[0], batch=batch) loss = loss_fun(result_batch) return loss def evaluate( self, - model: nn.Module, + model: list[nn.Module] | nn.Module, data_loaders: list[LLMDataLoader], loss_fun: Callable[[InferenceResultBatch], torch.Tensor], num_train_steps_done: int, @@ -86,7 +86,7 @@ def evaluate( """Evaluate the model on a set of datasets. Args: - model (nn.Module): The model to evaluate + model (list[nn.Module] | nn.Module): The model or model parts to evaluate data_loaders (list[LLMDataLoader]): List of dataloaders to evaluate the model on loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss num_train_steps_done (int): The number of training steps done so far for logging purposes @@ -97,7 +97,11 @@ def evaluate( dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader """ result_dict: dict[str, EvaluationResultBatch] = {} - model.eval() + if not isinstance(model, list): + assert scheduled_pipeline is None, "A non-scheduled pipeline should be processed with a single model." 
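Note (illustrative sketch, helper name hypothetical): the evaluator, trainer and AppState all normalize a single module or a list of pipeline-stage modules to a list before iterating, along the lines of:

```python
import torch.nn as nn

def as_model_parts(model: nn.Module | list[nn.Module]) -> list[nn.Module]:
    # A single module is treated as a one-element list of model parts.
    return list(model) if isinstance(model, list) else [model]

parts = as_model_parts(nn.Linear(4, 4))
for part in parts:
    part.eval()  # evaluation switches every part to eval mode and back afterwards
```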
+ model = [model] + for m in model: + m.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -165,7 +169,8 @@ def evaluate( ) result_dict[data_loader.dataloader_tag] = evaluation_result - model.train() + for m in model: + m.train() return result_dict diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 5ecec0d96..de56918c2 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -58,7 +58,7 @@ def run( """ evaluation_callback: Callable[[int], None] = partial( self._run_evaluation, - model=app_state.model, + model=app_state.model_parts, evaluation_data_loaders=evaluation_data_loaders, evaluation_interval_in_steps=evaluation_interval_in_steps, scheduled_pipeline=scheduled_pipeline, @@ -103,13 +103,13 @@ def _run_checkpointing( def _run_evaluation( self, - model: nn.Module, + model: list[nn.Module] | nn.Module, num_train_steps_done: int, evaluation_data_loaders: list[LLMDataLoader], evaluation_interval_in_steps: int, scheduled_pipeline: Pipeline | None = None, ): - if num_train_steps_done % evaluation_interval_in_steps == 0: + if num_train_steps_done % evaluation_interval_in_steps == 0 and num_train_steps_done > 0: self.evaluator.evaluate( model=model, data_loaders=evaluation_data_loaders, diff --git a/src/modalities/main.py b/src/modalities/main.py index 85a80731e..babc47db0 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -195,7 +195,7 @@ def run(self, components: TrainingComponentsInstantiationModel): loss_fun=components.loss_fn, num_ranks=components.settings.cuda_env.world_size, ) - num_params = get_total_number_of_trainable_parameters(components.app_state.model, components.device_mesh) + num_params = get_total_number_of_trainable_parameters(components.app_state.model_parts, components.device_mesh) components.evaluation_subscriber.consume_dict({"No. 
parameters": num_params}) logger.info(f"Training model with {num_params} parameters.") diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 3acb17f95..684ce6323 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -189,7 +189,7 @@ def get_fsdp2_wrapped_model( f"{get_local_number_of_trainable_parameters(model)}" ) # map the block names to the actual block class (e.b., GPT2Block) - block_types = tuple([get_module_class_from_name(model, b) for b in block_names]) + block_types = tuple([t for b in block_names if (t := get_module_class_from_name(model, b)) is not None]) mp_policy = MixedPrecisionPolicy( param_dtype=mixed_precision_settings.param_dtype.value, diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index 1e86a8dcd..784dddab9 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -5,13 +5,19 @@ import copy import re from enum import Enum -from typing import Any, Optional, Type, cast +from typing import Any, Iterable, Optional, Type, cast import torch import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class +from torch.distributed.pipelining.schedules import ( + PipelineScheduleMulti, + PipelineScheduleSingle, + ScheduleDualPipeV, + ScheduleZBVZeroBubble, + get_schedule_class, +) from modalities.loss_functions import Loss from modalities.models.model import NNModel @@ -25,36 +31,36 @@ class Pipeline: def __init__( self, - pp_stage: PipelineStage, - model_part: nn.Module, - pp_schedule: Optional[PipelineScheduleSingle] = None, + pp_stages: Iterable[PipelineStage], + model_parts: Iterable[nn.Module], + pp_schedule: Optional[PipelineScheduleSingle | PipelineScheduleMulti] = None, ): - self._pp_stage = pp_stage - self._model_part = model_part + self._pp_stages = list(pp_stages) + self._model_parts = list(model_parts) self._pp_schedule = pp_schedule @property - def is_first_pp_stage(self) -> bool: - return self._pp_stage.is_first + def has_first_pp_stage(self) -> bool: + return any(stage.is_first for stage in self._pp_stages) @property - def is_last_pp_stage(self) -> bool: - return self._pp_stage.is_last + def has_last_pp_stage(self) -> bool: + return any(stage.is_last for stage in self._pp_stages) @property - def pp_stage(self) -> PipelineStage: - return self._pp_stage + def pp_stages(self) -> list[PipelineStage]: + return self._pp_stages @property - def model_part(self) -> nn.Module: - return self._model_part + def model_parts(self) -> list[nn.Module]: + return self._model_parts @property - def pp_schedule(self) -> Optional[PipelineScheduleSingle]: + def pp_schedule(self) -> Optional[PipelineScheduleSingle | PipelineScheduleMulti]: return self._pp_schedule @pp_schedule.setter - def pp_schedule(self, schedule: PipelineScheduleSingle): + def pp_schedule(self, schedule: PipelineScheduleSingle | PipelineScheduleMulti): self._pp_schedule = schedule @@ -68,12 +74,14 @@ class PipelineSelectionTypes(Enum): class ComponentSelectorFromPipeline: @staticmethod - def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: + def select( + pipeline: Pipeline, selection_type: PipelineSelectionTypes + ) -> list[PipelineStage] | list[nn.Module] | PipelineScheduleSingle | 
PipelineScheduleMulti | None: """Selects a component from the pipeline based on the selection type.""" if selection_type == PipelineSelectionTypes.PP_STAGE: - return pipeline.pp_stage + return pipeline.pp_stages elif selection_type == PipelineSelectionTypes.MODEL_PART: - return pipeline.model_part + return pipeline.model_parts elif selection_type == PipelineSelectionTypes.PP_SCHEDULE: return pipeline.pp_schedule else: @@ -85,9 +93,9 @@ class PipelineFactory: @staticmethod def get_pipeline( - pp_stage: PipelineStage, model_part: NNModel, pp_schedule: Optional[PipelineScheduleSingle] = None + pp_stages: list[PipelineStage], model_parts: list[NNModel], pp_schedule: Optional[PipelineScheduleSingle] = None ) -> Pipeline: - return Pipeline(pp_stage=pp_stage, model_part=model_part, pp_schedule=pp_schedule) + return Pipeline(pp_stages=pp_stages, model_parts=model_parts, pp_schedule=pp_schedule) @staticmethod def get_staged_pipeline( @@ -107,17 +115,9 @@ def get_staged_pipeline( ) pp_mesh = device_mesh[ParallelismDegrees.PP.value] - schedule_class = get_schedule_class(pp_schedule_name) - is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) - if not is_single_stage_schedule: - raise ValueError( - f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules." - ) - # torchtitan returns tuple of stages and models as depending on the schedule - # we might have multiple stages and model parts per rank. - # So far we don't support multi-stage schedules, which is why instead of tuples - # we work directly with the stage and model. - pp_stage, model_part = PipelineFactory._get_split_model( + schedule_class: Type[PipelineScheduleSingle | PipelineScheduleMulti] = get_schedule_class(pp_schedule_name) + + pp_stages, model_parts = PipelineFactory._get_split_model( whole_model=whole_model, schedule_class=schedule_class, pp_mesh=pp_mesh, @@ -125,122 +125,55 @@ def get_staged_pipeline( fqns_per_stage=fqns_per_stage, ) - pipeline = Pipeline(pp_stage=pp_stage, model_part=model_part) + pipeline = Pipeline(pp_stages=pp_stages, model_parts=model_parts) return pipeline @staticmethod def _get_split_model( whole_model: NNModel, - schedule_class: Type[PipelineScheduleSingle], + schedule_class: Type[PipelineScheduleSingle | PipelineScheduleMulti], pp_mesh: DeviceMesh, device: torch.device, fqns_per_stage: list[list[str]], - ) -> tuple[PipelineStage, NNModel]: - def get_stage_id_of_pp_rank(pp_mesh: DeviceMesh): - # NOTE: torch titan a more complicated way to get the stage id of pp rank - # since they also allow for multi-stage schedules - pp_rank = pp_mesh.get_local_rank() - return pp_rank - - @staticmethod - def _get_fqn_tree(fqns: list[str]) -> dict[str, Any]: - fqn_tree = {} - fqns = set(fqns) # Ensure unique FQNs - for fqn in fqns: - parts = fqn.split(".") - current_level = fqn_tree - for part in parts[:-1]: - if part not in current_level: - current_level[part] = {} - elif len(current_level) == 0: - raise ValueError(f"Part {part} of {fqn} already exists " "in the tree as a leaf node.") - current_level = current_level[part] - if parts[-1] in current_level: - raise ValueError( - f" Leaf of {fqn} has already been defined in the tree as an intermediate node or leaf! " - "Cannot replace the existing node as a leaf." 
- ) - current_level[parts[-1]] = {} - - return fqn_tree - - def _build_stage_from_modules( - fqn_tree: dict[str, Any], module: nn.Module, module_name: Optional[str] = None - ) -> nn.Module: - if isinstance(module, nn.ModuleDict): - if module_name not in fqn_tree: - dict_modules = nn.ModuleDict({}) - else: - if len(fqn_tree) == 0: - # If the module is a leaf node, we can directly use it - dict_modules = module - else: - # If the module is not a leaf node, we need to build a staged module - # recursively from the FQN tree - dict_modules = {} - dict_module_names = [name for name in module.keys() if name in fqn_tree[module_name]] - for dict_module_name in dict_module_names: - dict_modules[dict_module_name] = _build_stage_from_modules( - fqn_tree=fqn_tree[module_name], - module=module[dict_module_name], - module_name=dict_module_name, - ) - dict_modules = nn.ModuleDict(dict_modules) - # setattr(module, module_name, dict_modules) - return dict_modules - - elif isinstance(module, nn.ModuleList): - if module_name not in fqn_tree: - list_modules = nn.ModuleList([]) - else: - if len(fqn_tree[module_name]) == 0: - # If the module is a leaf node, we can directly use it - list_modules = module - else: - # If the module is not a leaf node, we need to build a staged module - # recursively from the FQN tree - list_modules = [] - list_indices = [i for i in range(len(module)) if str(i) in fqn_tree[module_name]] - for idx in list_indices: - list_modules.append( - _build_stage_from_modules( - fqn_tree=fqn_tree[module_name], module=module[idx], module_name=str(idx) - ) - ) - list_modules = nn.ModuleList(list_modules) - # setattr(module, module_name, list_modules) - return list_modules - - else: # normal nn.Module - if module_name is not None and module_name not in fqn_tree: - # If the module is not in the FQN tree, set it to None - return None - elif module_name is not None and len(fqn_tree[module_name]) == 0: - # If the module is a leaf node, we can directly use it - return module - else: - # If the module is in the FQN tree, we need to build a staged module - # recursively from the FQN tree - for module_name, module_value in module.named_children(): - # If the module is not a leaf node, we need to build a staged module - # recursively from the FQN tree - staged_module = _build_stage_from_modules( - fqn_tree=fqn_tree, module=module_value, module_name=module_name - ) - setattr(module, module_name, staged_module) - - return module + ) -> tuple[list[PipelineStage], list[NNModel]]: + num_stages = len(fqns_per_stage) + stage_indices = PipelineFactory._get_stage_ids_of_pp_rank(pp_mesh, num_stages, schedule_class) + stages, stage_modules = zip( + *( + PipelineFactory._build_model_part_for_stage(whole_model, pp_mesh, device, fqns_per_stage, stage_idx) + for stage_idx in stage_indices + ) + ) + return list(stages), list(stage_modules) - if not issubclass(schedule_class, PipelineScheduleSingle): - raise NotImplementedError("Only single-stage schedules are supported for pipeline parallelism.") + @staticmethod + def _get_stage_ids_of_pp_rank( + pp_mesh: DeviceMesh, + num_stages: int, + schedule_class: Type[PipelineScheduleSingle | PipelineScheduleMulti], + ) -> list[int]: + style = "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" + pp_size = pp_mesh.size() + pp_rank = pp_mesh.get_local_rank() + stages_per_rank = num_stages // pp_size + if style == "loop": + return [pp_rank + s * pp_size for s in range(stages_per_rank)] + elif style == "v": + if stages_per_rank != 2: + raise ValueError(f"v 
schedules assume 2 stages per rank but got {stages_per_rank}.") + stage_v_pairs = list(zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))) + return list(stage_v_pairs[pp_rank]) + else: + raise ValueError(f"Unsupported schedule style: {style}") - # NOTE: For multi-stage schedule, e.g., Interleaved 1F1B, we have multiple stages per pp rank. - # This would need to be adapted accordingly in this case. - stage_idx = get_stage_id_of_pp_rank(pp_mesh) + @staticmethod + def _build_model_part_for_stage( + whole_model: NNModel, pp_mesh: DeviceMesh, device: torch.device, fqns_per_stage: list[list[str]], stage_idx: int + ) -> tuple[PipelineStage, NNModel]: module_names = fqns_per_stage[stage_idx] whole_model = copy.deepcopy(whole_model) - fqn_tree = _get_fqn_tree(module_names) - stage_modules = _build_stage_from_modules(fqn_tree, whole_model) + fqn_tree = PipelineFactory._get_fqn_tree(module_names) + stage_modules = PipelineFactory._build_stage_from_modules(fqn_tree, whole_model) stage_modules = cast(NNModel, stage_modules) PipelineFactory._filter_weight_decay_groups_(stage_modules) stage = PipelineStage( @@ -250,8 +183,99 @@ def _build_stage_from_modules( device=device, group=pp_mesh.get_group("pp"), ) + return stage, stage_modules + @staticmethod + def _get_fqn_tree(fqns: list[str]) -> dict[str, Any]: + fqn_tree: dict[str, Any] = {} + fqns = set(fqns) # Ensure unique FQNs + for fqn in fqns: + parts = fqn.split(".") + current_level = fqn_tree + for part in parts[:-1]: + if part not in current_level: + current_level[part] = {} + elif len(current_level) == 0: + raise ValueError(f"Part {part} of {fqn} already exists " "in the tree as a leaf node.") + current_level = current_level[part] + if parts[-1] in current_level: + raise ValueError( + f" Leaf of {fqn} has already been defined in the tree as an intermediate node or leaf! " + "Cannot replace the existing node as a leaf." 
+ ) + current_level[parts[-1]] = {} + + return fqn_tree + + @staticmethod + def _build_stage_from_modules( + fqn_tree: dict[str, Any], module: nn.Module, module_name: Optional[str] = None + ) -> nn.Module: + if isinstance(module, nn.ModuleDict): + if module_name not in fqn_tree: + dict_modules = nn.ModuleDict({}) + else: + if len(fqn_tree) == 0: + # If the module is a leaf node, we can directly use it + dict_modules = module + else: + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + dict_modules = {} + dict_module_names = [name for name in module.keys() if name in fqn_tree[module_name]] + for dict_module_name in dict_module_names: + dict_modules[dict_module_name] = PipelineFactory._build_stage_from_modules( + fqn_tree=fqn_tree[module_name], + module=module[dict_module_name], + module_name=dict_module_name, + ) + dict_modules = nn.ModuleDict(dict_modules) + # setattr(module, module_name, dict_modules) + return dict_modules + + elif isinstance(module, nn.ModuleList): + if module_name not in fqn_tree: + list_modules = nn.ModuleList([]) + else: + if len(fqn_tree[module_name]) == 0: + # If the module is a leaf node, we can directly use it + list_modules = module + else: + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + list_modules = [] + list_indices = [i for i in range(len(module)) if str(i) in fqn_tree[module_name]] + for idx in list_indices: + list_modules.append( + PipelineFactory._build_stage_from_modules( + fqn_tree=fqn_tree[module_name], module=module[idx], module_name=str(idx) + ) + ) + list_modules = nn.ModuleList(list_modules) + # setattr(module, module_name, list_modules) + return list_modules + + else: # normal nn.Module + if module_name is not None and module_name not in fqn_tree: + # If the module is not in the FQN tree, set it to None + return None + elif module_name is not None and len(fqn_tree[module_name]) == 0: + # If the module is a leaf node, we can directly use it + return module + else: + # If the module is in the FQN tree, we need to build a staged module + # recursively from the FQN tree + for module_name, module_value in module.named_children(): + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + staged_module = PipelineFactory._build_stage_from_modules( + fqn_tree=fqn_tree, module=module_value, module_name=module_name + ) + setattr(module, module_name, staged_module) + + return module + @staticmethod def _filter_weight_decay_groups_(stage_modules: NNModel): params = {name for name, parameter in stage_modules.named_parameters() if parameter.requires_grad} @@ -274,15 +298,41 @@ def get_scheduled_pipeline( # TODO: Addd validation in config that batch_size is divisible by microbatch_size # and n_microbatches must be >= pp_degree n_microbatches = batch_size // microbatch_size - num_total_stages = pp_degree - pp_schedule_class = get_schedule_class(pp_schedule_name) - pp_schedule = pp_schedule_class( - stage=pipeline.pp_stage, - n_microbatches=n_microbatches, - loss_fn=loss_fn, - ) + num_total_stages = pp_degree * len(pipeline.pp_stages) + pp_schedule = PipelineFactory._build_pp_schedule(loss_fn, pp_schedule_name, n_microbatches, pipeline.pp_stages) logger.info( f"Using pipeline schedule {pp_schedule} with {n_microbatches} microbatches and {num_total_stages} stages." 
) pipeline.pp_schedule = pp_schedule return pipeline + + @staticmethod + def _build_pp_schedule( + loss_fn: Loss, + pp_schedule_name: str, + n_microbatches: int, + pp_stage_or_stages: PipelineStage | list[PipelineStage], + ) -> PipelineScheduleSingle | PipelineScheduleMulti: + pp_schedule_class: Type[PipelineScheduleSingle | PipelineScheduleMulti] = get_schedule_class(pp_schedule_name) + if issubclass(pp_schedule_class, PipelineScheduleSingle): + if isinstance(pp_stage_or_stages, list): + assert len(pp_stage_or_stages) == 1, ( + f"Expected a single PipelineStage for single-stage schedule " + f"but got {len(pp_stage_or_stages)} stages." + ) + pp_stage_or_stages = pp_stage_or_stages[0] + pp_schedule = pp_schedule_class( + stage=pp_stage_or_stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + ) + elif issubclass(pp_schedule_class, PipelineScheduleMulti): + assert isinstance(pp_stage_or_stages, list), "Expected a list of PipelineStages for multi-stage schedule." + pp_schedule = pp_schedule_class( + stages=pp_stage_or_stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + ) + else: + raise ValueError(f"Unsupported pipeline schedule class: {pp_schedule_class}.") + return pp_schedule diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py index 831a6e15e..ec16cdac0 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism_configs.py +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -11,6 +11,7 @@ PydanticStagesGeneratorType, ) from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes +from modalities.utils.deprecated_alias import add_deprecated_alias class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate @@ -40,7 +41,9 @@ class ComponentSelectorFromPipelineConfig(BaseModel): selection_type: PipelineSelectionTypes +@add_deprecated_alias("pp_stages", "pp_stage") +@add_deprecated_alias("model_parts", "model_part") class PipelineConfig(BaseModel): - pp_stage: PydanticPipelineStageType - model_part: PydanticPytorchModuleType + pp_stages: list[PydanticPipelineStageType] + model_parts: list[PydanticPytorchModuleType] pp_schedule: PydanticPipelineType | None = None diff --git a/src/modalities/models/parallelism/stages_generator.py b/src/modalities/models/parallelism/stages_generator.py index 0a212672a..38a75051b 100644 --- a/src/modalities/models/parallelism/stages_generator.py +++ b/src/modalities/models/parallelism/stages_generator.py @@ -14,7 +14,7 @@ def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, outp def get_stages(self, num_layers_per_stage: int, pp_dims: int) -> list[list[str]]: """ - Generate FQNs for each stage in a GPT-2 model. + Generate FQNs for each stage in a model. Args: num_layers_per_stage (int): Number of layers per stage. @@ -36,13 +36,6 @@ def get_stages(self, num_layers_per_stage: int, pp_dims: int) -> list[list[str]] f"{self._output_layer_equivalence=} {num_layers_per_stage=}" ) - stages_per_rank = num_virtual_stages // pp_dims - if stages_per_rank != 1: - raise ValueError( - f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " - f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." - ) - # Potential split points for GPT-2 model with each potential split point # listing the FQNs of the modules in that stage and the computational weight. 
# The computational weight of the input and output modules are estimated @@ -112,7 +105,10 @@ def _get_potential_split_points( # The computational weight of the input and output modules are estimated # based on the number of layers they correspond to. potential_split_points = [ - (["transformer.wte", "transformer.wpe", "transformer.drop"], self._input_layer_equivalence), + ( # FIXME wpe and drop probably should not get the higher weight + ["transformer.wte", "transformer.wpe", "transformer.drop"], + self._input_layer_equivalence, + ), *[([f"transformer.h.{i}"], 1) for i in range(self._num_model_layers)], (["transformer.lm_head_norm", "transformer.lm_head"], self._output_layer_equivalence), ] diff --git a/src/modalities/optimizers/optimizer_list.py b/src/modalities/optimizers/optimizer_list.py new file mode 100644 index 000000000..bbe51834f --- /dev/null +++ b/src/modalities/optimizers/optimizer_list.py @@ -0,0 +1,52 @@ +# This file contains code adapted from: +# https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/optimizer.py +# which is licensed under the BSD 3-Clause "New" or "Revised" License: +# https://github.com/pytorch/torchtitan/blob/main/LICENSE + +import functools +from typing import Any, Iterable + +from torch import nn +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict, set_optimizer_state_dict +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim import Optimizer +from torch.optim.optimizer import ParamsT + + +class OptimizersList(Optimizer, Stateful, list[Optimizer]): + """Class to handle multiple optimizers for different model parts. + Particular relevant for pipeline parallelism, where each stage has its own optimizer. + This class wraps a list of optimizers and provides a unified interface to step, zero_grad, + state_dict and load_state_dict. 
+ """ + + def __init__(self, model_parts: Iterable[nn.Module], optimizers: Iterable[Optimizer]): + list.__init__(self, optimizers) + self._model_parts = list(model_parts) + assert len(self) > 0, "OptimizersList requires at least one optimizer" + assert len(self._model_parts) == len(self), "Number of model parts must match number of optimizers" + all_params: ParamsT = [p for model in self._model_parts for p in model.parameters() if p.requires_grad] + Optimizer.__init__(self, all_params, dict()) + + def step(self, *args, **kwargs): + for optimizer in self: + optimizer.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs): + for optimizer in self: + optimizer.zero_grad(*args, **kwargs) + + def state_dict(self) -> list[dict[str, Any]]: + func = functools.partial( + get_optimizer_state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + return {k: v for sd in map(func, self._model_parts, self) for k, v in sd.items()} + + def load_state_dict(self, state_dict: dict[str, Any]): + func = functools.partial( + set_optimizer_state_dict, + optim_state_dict=state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + list(map(func, self._model_parts, self)) diff --git a/src/modalities/optimizers/scheduler_list.py b/src/modalities/optimizers/scheduler_list.py new file mode 100644 index 000000000..86468e14a --- /dev/null +++ b/src/modalities/optimizers/scheduler_list.py @@ -0,0 +1,46 @@ +# This file contains code adapted from: +# https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/lr_scheduler.py +# which is licensed under the BSD 3-Clause "New" or "Revised" License: +# https://github.com/pytorch/torchtitan/blob/main/LICENSE + +import copy +from typing import Any, Iterable + +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim.lr_scheduler import LRScheduler + + +class SchedulerList(LRScheduler, Stateful, list[LRScheduler]): + """A list of learning rate schedulers that can be treated as a single scheduler. + Each scheduler in the list should correspond to an optimizer in a multi-optimizer setup. + NOTE: Similar to torchtitan, this class assumes that all schedulers have the same state. 
+ """ + + def __init__(self, schedulers: Iterable[LRScheduler]): + list.__init__(self, schedulers) + assert len(self) > 0, "SchedulerList requires at least one scheduler" + + def state_dict(self) -> dict[str, Any]: + return self[0].state_dict() + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + for scheduler in self: + scheduler.load_state_dict(copy.deepcopy(state_dict)) + + def get_last_lr(self): + return self[0].get_last_lr() + + def get_lr(self): + return self[0].get_lr() + + def step(self, epoch: int | None = None): + for scheduler in self: + scheduler.step(epoch) + + @property + def base_lrs(self): + return self[0].base_lrs + + @property + def last_epoch(self): + return self[0].last_epoch diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 1eeecc841..a629e84d9 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -1,9 +1,11 @@ from dataclasses import dataclass -from typing import Callable, Type +from typing import Any, Callable, Type import torch import torch.nn as nn from pydantic import BaseModel +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import BatchSampler, DistributedSampler, SequentialSampler from modalities.checkpointing.checkpoint_saving import CheckpointSaving @@ -107,6 +109,8 @@ ) from modalities.optimizers.lr_schedulers import DummyLRScheduler from modalities.optimizers.optimizer_factory import OptimizerFactory +from modalities.optimizers.optimizer_list import OptimizersList +from modalities.optimizers.scheduler_list import SchedulerList from modalities.running_env.fsdp.device_mesh import DeviceMeshConfig, get_device_mesh, get_parallel_degree from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer, PreTrainedSPTokenizer from modalities.training.gradient_clipping.fsdp_gradient_clipper import ( @@ -123,6 +127,7 @@ ) from modalities.utils.debug_components import Debugging, HookRegistration from modalities.utils.debugging_configs import DebuggingConfig, NaNHookConfig, PrintForwardHookConfig +from modalities.utils.maybe_list_parameter import MaybeListDecorator, maybe_list_parameter from modalities.utils.mfu import GPT2MFUCalculator from modalities.utils.number_conversion import ( LocalNumBatchesFromNumSamplesConfig, @@ -140,6 +145,14 @@ from modalities.utils.profilers.steppable_component_configs import SteppableForwardPassConfig from modalities.utils.profilers.steppable_components import SteppableForwardPass +maybe_model_list: MaybeListDecorator[nn.Module, ..., Any, None] = maybe_list_parameter("model") +maybe_model_list_for_optimizer: MaybeListDecorator[nn.Module, ..., Optimizer, OptimizersList] = maybe_list_parameter( + "wrapped_model", apply_to_list_input_and_result=OptimizersList +) +maybe_optimizer_list: MaybeListDecorator[Optimizer, ..., LRScheduler, SchedulerList] = maybe_list_parameter( + "optimizer", apply_to_list_result=SchedulerList +) + @dataclass class ComponentEntity: @@ -157,14 +170,16 @@ class ComponentEntity: component_key: str variant_key: str - component_type: Type | Callable + component_type: Type[Any] | Callable[..., Any] component_config_type: Type[BaseModel] COMPONENTS = [ # models ComponentEntity("model", "gpt2", GPT2ModelFactory.get_gpt2_model, GPT2LLMConfig), - ComponentEntity("model", "gpt2_tp", GPT2ModelFactory.get_gpt2_tensor_parallelized_model, GPT2ModelTPConfig), + ComponentEntity( + "model", "gpt2_tp", 
maybe_model_list(GPT2ModelFactory.get_gpt2_tensor_parallelized_model), GPT2ModelTPConfig + ), ComponentEntity( "model", "huggingface_pretrained_model", HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig ), @@ -172,9 +187,14 @@ class ComponentEntity: "model", "fsdp1_checkpointed", ModelFactory.get_fsdp1_checkpointed_model, FSDP1CheckpointedModelConfig ), ComponentEntity("model", "fsdp1_wrapped", ModelFactory.get_fsdp1_wrapped_model, FSDPWrappedModelConfig), - ComponentEntity("model", "fsdp2_wrapped", ModelFactory.get_fsdp2_wrapped_model, FSDP2WrappedModelConfig), ComponentEntity( - "model", "model_initialized", ModelFactory.get_weight_initialized_model, WeightInitializedModelConfig + "model", "fsdp2_wrapped", maybe_model_list(ModelFactory.get_fsdp2_wrapped_model), FSDP2WrappedModelConfig + ), + ComponentEntity( + "model", + "model_initialized", + maybe_model_list(ModelFactory.get_weight_initialized_model), + WeightInitializedModelConfig, ), ComponentEntity( "model", @@ -185,13 +205,16 @@ class ComponentEntity: ComponentEntity( "model", "activation_checkpointed", - ModelFactory.get_activation_checkpointed_fsdp2_model_, + maybe_model_list(ModelFactory.get_activation_checkpointed_fsdp2_model_), ActivationCheckpointedModelConfig, ), ComponentEntity("model", "compiled", ModelFactory.get_compiled_model, CompiledModelConfig), ComponentEntity("model", "coca", CoCa, CoCaConfig), ComponentEntity( - "model", "debugging_enriched", ModelFactory.get_debugging_enriched_model, DebuggingEnrichedModelConfig + "model", + "debugging_enriched", + maybe_model_list(ModelFactory.get_debugging_enriched_model), + DebuggingEnrichedModelConfig, ), ComponentEntity("pipeline", "staged", PipelineFactory.get_staged_pipeline, StagedPipelineConfig), ComponentEntity("pipeline", "scheduled", PipelineFactory.get_scheduled_pipeline, ScheduledPipelineConfig), @@ -211,9 +234,13 @@ class ComponentEntity: ), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), - # optmizers - ComponentEntity("optimizer", "adam", OptimizerFactory.get_adam, AdamOptimizerConfig), - ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), + # optimizers + ComponentEntity( + "optimizer", "adam", maybe_model_list_for_optimizer(OptimizerFactory.get_adam), AdamOptimizerConfig + ), + ComponentEntity( + "optimizer", "adam_w", maybe_model_list_for_optimizer(OptimizerFactory.get_adam_w), AdamWOptimizerConfig + ), ComponentEntity( "optimizer", "fsdp1_checkpointed", @@ -224,13 +251,24 @@ class ComponentEntity: ComponentEntity("app_state", "raw", AppStateFactory.get_raw_app_state, RawAppStateConfig), ComponentEntity("app_state", "dcp", AppStateFactory.get_dcp_checkpointed_app_state_, DCPAppStateConfig), # schedulers - ComponentEntity("scheduler", "dummy_lr", DummyLRScheduler, DummyLRSchedulerConfig), - ComponentEntity("scheduler", "step_lr", torch.optim.lr_scheduler.StepLR, StepLRSchedulerConfig), - ComponentEntity("scheduler", "constant_lr", torch.optim.lr_scheduler.ConstantLR, ConstantLRSchedulerConfig), - ComponentEntity("scheduler", "linear_lr", torch.optim.lr_scheduler.LinearLR, LinearLRSchedulerConfig), - ComponentEntity("scheduler", "onecycle_lr", torch.optim.lr_scheduler.OneCycleLR, OneCycleLRSchedulerConfig), + ComponentEntity("scheduler", "dummy_lr", maybe_optimizer_list(DummyLRScheduler), DummyLRSchedulerConfig), + ComponentEntity( + "scheduler", "step_lr", maybe_optimizer_list(torch.optim.lr_scheduler.StepLR), StepLRSchedulerConfig + ), ComponentEntity( 
- "scheduler", "cosine_annealing_lr", torch.optim.lr_scheduler.CosineAnnealingLR, CosineAnnealingLRSchedulerConfig + "scheduler", "constant_lr", maybe_optimizer_list(torch.optim.lr_scheduler.ConstantLR), ConstantLRSchedulerConfig + ), + ComponentEntity( + "scheduler", "linear_lr", maybe_optimizer_list(torch.optim.lr_scheduler.LinearLR), LinearLRSchedulerConfig + ), + ComponentEntity( + "scheduler", "onecycle_lr", maybe_optimizer_list(torch.optim.lr_scheduler.OneCycleLR), OneCycleLRSchedulerConfig + ), + ComponentEntity( + "scheduler", + "cosine_annealing_lr", + maybe_optimizer_list(torch.optim.lr_scheduler.CosineAnnealingLR), + CosineAnnealingLRSchedulerConfig, ), # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), @@ -435,11 +473,13 @@ class ComponentEntity: ), # Debugging components ComponentEntity("debugging", "settings", Debugging, DebuggingConfig), - ComponentEntity("model_debugging_hook", "nan_hook", HookRegistration.register_nan_hooks, NaNHookConfig), + ComponentEntity( + "model_debugging_hook", "nan_hook", maybe_model_list(HookRegistration.register_nan_hooks), NaNHookConfig + ), ComponentEntity( "model_debugging_hook", "print_forward_hook", - HookRegistration.register_print_forward_hooks, + maybe_model_list(HookRegistration.register_print_forward_hooks), PrintForwardHookConfig, ), ] diff --git a/src/modalities/running_env/cuda_env.py b/src/modalities/running_env/cuda_env.py index b58efc7f4..27a0c3307 100644 --- a/src/modalities/running_env/cuda_env.py +++ b/src/modalities/running_env/cuda_env.py @@ -1,4 +1,5 @@ import os +import traceback from datetime import timedelta from typing import Any @@ -6,6 +7,9 @@ import torch.distributed as dist from modalities.config.config import ProcessGroupBackendType +from modalities.utils.logger_utils import get_logger + +logger = get_logger(__name__) class CudaEnv: @@ -47,11 +51,17 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """ local_rank = int(os.getenv("LOCAL_RANK", "-1")) if exc_type is torch.cuda.OutOfMemoryError: - print(f"[Rank {local_rank}] CUDA OOM during block, emptying cache.") + logger.error(f"[Rank {local_rank}] CUDA OOM during block, emptying cache.") torch.cuda.empty_cache() + if exc_type is not None: + logger.error(f"[Rank {local_rank}] Exception of type {exc_type} occurred: {exc_val}") + tb_str = "".join(traceback.format_exception(exc_type, exc_val, exc_tb)) + logger.error(f"[Rank {local_rank}] Traceback:\n{tb_str}") try: if dist.is_initialized(): dist.destroy_process_group() except Exception as e: - print(f"[Rank {local_rank}] Error during process group cleanup: {e}") + logger.error(f"[Rank {local_rank}] Error during process group cleanup: {e}") + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + logger.error(f"[Rank {local_rank}] Traceback during cleanup:\n{tb_str}") diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 731802c4c..f2253e7b6 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -5,7 +5,6 @@ import torch import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -23,6 +22,7 @@ from modalities.training.training_progress import TrainingProgress from modalities.util import Aggregator, TimeRecorder, print_rank_0 from modalities.utils.mfu import MFUCalculatorABC +from 
modalities.utils.typing_utils import FSDPX class ThroughputAggregationKeys(Enum): @@ -104,7 +104,7 @@ def _get_num_train_steps_done(micro_batch_id: int, gradient_acc_steps: int) -> i def _train_batch( self, batch: DatasetBatch, - model: FSDP, + model_parts: list[FSDPX], optimizer: Optimizer, scheduler: LRScheduler, loss_fun: Loss, @@ -116,7 +116,7 @@ def _train_batch( Args: batch (DatasetBatch): The input batch of data. - model (FSDP): The model to train. + model_parts (list[FSDPX]): The model parts to train. optimizer (Optimizer): The optimizer used for training. scheduler (LRScheduler): The learning rate scheduler. loss_fun (Loss): The loss function used for training. @@ -140,18 +140,20 @@ def _train_batch( # with self.train_context(optional_context_parallel_ctx): targets, losses = ( (batch.targets[loss_fun.target_key].contiguous(), []) - if scheduled_pipeline.is_last_pp_stage + if scheduled_pipeline.has_last_pp_stage else (None, None) ) - if scheduled_pipeline.is_first_pp_stage: - pp_schedule.step(batch.samples[model.sample_key].contiguous(), target=targets, losses=losses) + if scheduled_pipeline.has_first_pp_stage: + pp_schedule.step(batch.samples[model_parts[0].sample_key].contiguous(), target=targets, losses=losses) else: pp_schedule.step(target=targets, losses=losses) - loss = torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.is_last_pp_stage else None + loss = ( + torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.has_last_pp_stage else None + ) else: # else continue with loss calculation - result_batch = model_predict_batch(model=model, batch=batch) + result_batch = model_predict_batch(model=model_parts[0], batch=batch) loss = loss_fun(result_batch) (loss / self.gradient_acc_steps).backward() @@ -196,10 +198,13 @@ def train( Returns: None """ - model = app_state.model + model_parts = app_state.model_parts optimizer = app_state.optimizer lr_scheduler = app_state.lr_scheduler - model.train() + if scheduled_pipeline is None: + assert len(model_parts) == 1, "Expected a single model part when no scheduled pipeline is provided." + for m in model_parts: + m.train() cumulated_losses = self._reset_tracked_losses() @@ -239,7 +244,7 @@ def train( gradient_norm_score, ) = self._train_batch( batch=batch, - model=model, + model_parts=model_parts, optimizer=optimizer, scheduler=lr_scheduler, loss_fun=loss_fun, diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py index d13fdd999..a5c0d2cfe 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py @@ -102,7 +102,7 @@ class FSDP2LoggingOnlyGradientClipper(GradientClipperIF): def __init__( self, - wrapped_model: FSDP2, + model_parts: FSDP2 | list[FSDP2], norm_type: GradientClippingMode, device_mesh: Optional[DeviceMesh] = None, error_if_nonfinite: bool = False, @@ -112,7 +112,7 @@ def __init__( Initialize the FSDP2LoggingOnlyGradientClipper. Args: - wrapped_model (FSDP2): The wrapped FSDP2 model. + model_parts (FSDP2 | list[FSDP2]): The wrapped FSDP2 model or list of models. norm_type (GradientClippingMode): The type of gradient clipping. device_mesh (DeviceMesh, optional): The device mesh used for distributed training. Defaults to None. 
error_if_nonfinite (bool): if True, an error is thrown if the total @@ -126,7 +126,7 @@ def __init__( Returns: None """ - self.wrapped_model = wrapped_model + self.models = model_parts if isinstance(model_parts, list) else [model_parts] self.norm_type = norm_type self.device_mesh = device_mesh self.error_if_nonfinite = error_if_nonfinite @@ -140,7 +140,7 @@ def clip_gradients(self) -> torch.Tensor: Returns: torch.Tensor: The gradient norms. """ - grads = [p.grad for p in self.wrapped_model.parameters() if p.grad is not None] + grads = [p.grad for model in self.models for p in model.parameters() if p.grad is not None] total_norm = torch.nn.utils.get_total_norm( tensors=grads, norm_type=self.norm_type.value, @@ -176,7 +176,7 @@ class FSDP2GradientClipper(FSDP2LoggingOnlyGradientClipper): def __init__( self, - wrapped_model: FSDP2, + model_parts: FSDP2 | list[FSDP2], max_norm: float, norm_type: GradientClippingMode, device_mesh: Optional[DeviceMesh] = None, @@ -187,7 +187,7 @@ def __init__( Initialize the FSDP2GradientClipper object. Args: - wrapped_model (FSDP2): The wrapped FSDP2 model. + model_parts (FSDP2 | list[FSDP2]): The wrapped FSDP2 model or list of model parts. max_norm (float): The maximum norm value for gradient clipping. norm_type (GradientClippingMode): The type of gradient clipping. device_mesh (DeviceMesh, optional): The device mesh used for distributed training. Defaults to None. @@ -203,7 +203,7 @@ def __init__( None """ super().__init__( - wrapped_model=wrapped_model, + model_parts=model_parts, norm_type=norm_type, device_mesh=device_mesh, error_if_nonfinite=error_if_nonfinite, @@ -220,10 +220,11 @@ def clip_gradients(self) -> torch.Tensor: torch.Tensor: The gradient norm after clipping. """ total_norm = super().clip_gradients() - torch.nn.utils.clip_grads_with_norm_( - parameters=self.wrapped_model.parameters(), - max_norm=self.max_norm, - total_norm=total_norm, - foreach=self.foreach, - ) + for model in self.models: + torch.nn.utils.clip_grads_with_norm_( + parameters=model.parameters(), + max_norm=self.max_norm, + total_norm=total_norm, + foreach=self.foreach, + ) return total_norm diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py index fde4c3f1b..f101abc5c 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py @@ -2,8 +2,16 @@ from pydantic import BaseModel, Field -from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticPytorchModuleType +from modalities.config.pydantic_if_types import ( + PydanticDeviceMeshIFType, + PydanticPytorchModuleOrListType, + PydanticPytorchModuleType, +) from modalities.training.gradient_clipping.fsdp_gradient_clipper import GradientClippingMode +from modalities.utils.deprecated_alias import add_deprecated_alias +from modalities.utils.logger_utils import get_logger + +logger = get_logger("fsdp_gradient_clipper_config") class FSDP1GradientClipperConfig(BaseModel): @@ -26,6 +34,7 @@ class FSDP1GradientClipperConfig(BaseModel): wrapped_model: PydanticPytorchModuleType +@add_deprecated_alias("model_parts", "wrapped_model") class FSDP2GradientClipperConfig(BaseModel): """ Configuration class for FSDP gradient clipper. @@ -33,19 +42,19 @@ class FSDP2GradientClipperConfig(BaseModel): Args: max_norm (float): The maximum norm value for gradient clipping. 
norm_type (GradientClippingMode): The type of gradient clipping to be applied. - wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + model_parts (PydanticPytorchModuleOrListType): The wrapped PyTorch model (parts). device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. Attributes: max_norm (float): The maximum norm value for gradient clipping. norm_type (GradientClippingMode): The type of gradient clipping to be applied. - wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + model_parts (PydanticPytorchModuleOrListType): The wrapped PyTorch model (parts). device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. """ max_norm: Annotated[float, Field(strict=True, gt=0)] norm_type: GradientClippingMode - wrapped_model: PydanticPytorchModuleType + model_parts: PydanticPytorchModuleOrListType device_mesh: PydanticDeviceMeshIFType @@ -66,21 +75,22 @@ class FSDP1DummyGradientClipperConfig(BaseModel): norm_type: GradientClippingMode +@add_deprecated_alias("model_parts", "wrapped_model") class FSDP2DummyGradientClipperConfig(BaseModel): """ Configuration class for FSDP dummy gradient clipper. Args: - wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + model_parts (PydanticPytorchModuleOrListType): The wrapped PyTorch model (parts). norm_type (GradientClippingMode): The type of gradient clipping to be applied. device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. Attributes: - wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + model_parts (PydanticPytorchModuleOrListType): The wrapped PyTorch model (parts). norm_type (GradientClippingMode): The type of gradient clipping to be applied. device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. """ - wrapped_model: PydanticPytorchModuleType + model_parts: PydanticPytorchModuleOrListType norm_type: GradientClippingMode device_mesh: PydanticDeviceMeshIFType diff --git a/src/modalities/util.py b/src/modalities/util.py index 4bff43859..a4cd32f12 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -20,6 +20,7 @@ from modalities.exceptions import TimeRecorderStateError from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, has_parallelism_method from modalities.running_env.fsdp.reducer import Reducer +from modalities.utils.maybe_list_parameter import maybe_list_parameter from modalities.utils.typing_utils import FSDPX @@ -167,6 +168,7 @@ def get_local_number_of_trainable_parameters(model: nn.Module) -> int: return num_params +@maybe_list_parameter("model", apply_to_list_result=sum) def get_total_number_of_trainable_parameters(model: FSDPX, device_mesh: DeviceMesh | None) -> Number: """Returns the total number of trainable parameters across all ranks. The model must be sharded with FSDP1 or FSDP2. 
@@ -331,9 +333,10 @@ def get_module_class_from_name(module: torch.nn.Module, name: str) -> Type[torch if module.__class__.__name__ == name: return module.__class__ elif len(modules_children) == 0: - return + return None else: for child_module in modules_children: module_class = get_module_class_from_name(child_module, name) if module_class is not None: return module_class + return None diff --git a/src/modalities/utils/deprecated_alias.py b/src/modalities/utils/deprecated_alias.py new file mode 100644 index 000000000..83c58d73d --- /dev/null +++ b/src/modalities/utils/deprecated_alias.py @@ -0,0 +1,96 @@ +import copy +import warnings +from typing import Any, Callable + +from pydantic import AliasPath, BaseModel, model_validator +from pydantic.aliases import AliasChoices +from pydantic.fields import FieldInfo + + +def add_deprecated_alias( + field_name: str, alias: str, warning_message: str | None = None +) -> Callable[[type[BaseModel]], type[BaseModel]]: + """ + Decorator to add a deprecated alias to a specific field in a Pydantic BaseModel. + Issues a deprecation warning when the alias is used. + + Args: + field_name (str): The name of the field to add an alias for + alias (str): The deprecated alias name to register + warning_message (str | None): Custom warning message (optional) + Returns: + Callable[[type[BaseModel]], type[BaseModel]]: Decorator function + """ + + def decorator(cls: type[BaseModel]) -> type[BaseModel]: + if not issubclass(cls, BaseModel): + raise TypeError("Decorator can only be applied to Pydantic BaseModel subclasses") + if field_name not in cls.model_fields: + raise ValueError(f"While adding alias to BaseModel: Field '{field_name}' not found in model") + + new_field = _build_new_field_with_alias(cls, field_name, alias) + cls = _add_field_and_deprecation_validator_to_class(cls, new_field, field_name, alias, warning_message) + + return cls + + return decorator + + +def _build_new_field_with_alias(cls: type[BaseModel], field_name: str, alias: str) -> FieldInfo: + field_info = cls.model_fields[field_name] + + aliases = _build_alias_list(alias, field_info) + + new_field = copy.deepcopy(field_info) + new_field.validation_alias = AliasChoices(*aliases) + if new_field.alias_priority is None: + # deprecated alias should have lower priority than original field + new_field.alias_priority = 2 + return new_field + + +def _build_alias_list(alias: str, field_info: FieldInfo) -> list[str | AliasPath]: + aliases: list[str | AliasPath] = [alias] + + # Handle existing aliases + existing_alias = field_info.validation_alias + if existing_alias: + if isinstance(existing_alias, AliasChoices): + aliases.extend(existing_alias.choices) + else: + aliases.append(existing_alias) + + return aliases + + +def _add_field_and_deprecation_validator_to_class( + cls: type[BaseModel], new_field: FieldInfo, field_name: str, alias: str, warning_message: str | None +) -> type[BaseModel]: + cls.model_fields[field_name] = new_field + cls = _add_deprecation_warning(cls, field_name, alias, warning_message) + res = cls.model_rebuild(force=True) + if res is None or not res: + raise RuntimeError("Failed to rebuild the Pydantic model after adding deprecated alias") + return cls + + +def _add_deprecation_warning( + cls: type[BaseModel], field_name: str, alias: str, warning_message: str | None +) -> type[BaseModel]: + # Store deprecated aliases info for the validator + if not hasattr(cls, "_deprecated_aliases"): + setattr(cls, "_deprecated_aliases", {}) + cls._deprecated_aliases[alias] = { + "field_name": 
field_name, + "warning_message": warning_message or f"Alias '{alias}' is deprecated. Use '{field_name}' instead.", + } + + @model_validator(mode="before") + def check_deprecated_aliases(cls: type[BaseModel], data: Any) -> Any: + if isinstance(data, dict): + for deprecated_alias, info in getattr(cls, "_deprecated_aliases", {}).items(): + if deprecated_alias in data: + warnings.warn(info["warning_message"], DeprecationWarning, stacklevel=3) + return data + + return type(cls.__name__, (cls,), {"check_deprecated_aliases": check_deprecated_aliases}) diff --git a/src/modalities/utils/maybe_list_parameter.py b/src/modalities/utils/maybe_list_parameter.py new file mode 100644 index 000000000..975f947bb --- /dev/null +++ b/src/modalities/utils/maybe_list_parameter.py @@ -0,0 +1,85 @@ +import inspect +from functools import wraps +from typing import Callable, Concatenate, ParamSpec, TypeAlias, TypeVar, cast + +T = TypeVar("T") # Represents the parameter we want to wrap into a list. +P = ParamSpec("P") # Represents the other parameters of the function. +R1 = TypeVar("R1") # Represents the return type of the base function. +R2 = TypeVar("R2") # Represents the return type of the additional reducers. + +ListResultsReducer: TypeAlias = Callable[[list[R1]], R2] +ListInputAndResultsReducer: TypeAlias = Callable[[list[T], list[R1]], R2] + +BaseFunc: TypeAlias = Callable[Concatenate[T, P], R1] +MaybeListWrappedFunc: TypeAlias = Callable[Concatenate[T | list[T], P], R1 | list[R1] | R2] + +MaybeListDecorator: TypeAlias = Callable[[BaseFunc[T, P, R1]], MaybeListWrappedFunc[T, P, R1, R2]] + + +def maybe_list_parameter( + parameter_name: str, + apply_to_list_result: ListResultsReducer[R1, R2] | None = None, + apply_to_list_input_and_result: ListInputAndResultsReducer[T, R1, R2] | None = None, +) -> MaybeListDecorator[T, P, R1, R2]: + """Decorator factory allowing a specific parameter to be a single item or a list. + If a list is provided, the wrapped function is called once per element and a list + of results is returned; otherwise the single result is returned. + + Args: + parameter_name (str): The name of the parameter to treat as a list or single item. 
+ apply_to_list_result (ListResultsReducer | None): Reduces list of results -> single R + apply_to_list_input_and_result (ListInputAndResultsReducer | None): + Takes (original list input, list of results) -> single R + (mutually exclusive) + """ + + def decorator(func: BaseFunc[T, P, R1]) -> MaybeListWrappedFunc[T, P, R1, R2]: + sig = inspect.signature(func) + + # Find positional index of the target parameter (if present) + param_pos_index: int | None = None + for idx, (name, _) in enumerate(sig.parameters.items()): + if name == parameter_name: + param_pos_index = idx + break + if param_pos_index is None: + raise ValueError(f"Parameter '{parameter_name}' not found in function '{func.__name__}' signature.") + + @wraps(func) + def maybe_list_parameter_wrapper(*args: P.args, **kwargs: P.kwargs) -> R1 | list[R1] | R2: + # Obtain value (positional or kw) + if parameter_name in kwargs: + param_value: T | list[T] = kwargs[parameter_name] + elif param_pos_index < len(args): + param_value = args[param_pos_index] + else: + # Parameter not supplied; just call through + return func(*args, **kwargs) + + # If not a list, call directly + if not isinstance(param_value, list): + return func(*args, **kwargs) + + # Process each element + results: list[R1] = [] + for item in param_value: + if parameter_name in kwargs: + new_kwargs = dict(kwargs) + new_kwargs[parameter_name] = item + results.append(func(*args, **new_kwargs)) + else: + new_args = list(args) + new_args[param_pos_index] = item + results.append(func(*new_args, **kwargs)) + + if apply_to_list_result is not None: + if apply_to_list_input_and_result is not None: + raise ValueError("Cannot provide both apply_to_list_result and apply_to_list_input_and_result.") + return apply_to_list_result(results) + if apply_to_list_input_and_result is not None: + return apply_to_list_input_and_result(param_value, results) + return results + + return cast(MaybeListWrappedFunc, maybe_list_parameter_wrapper) + + return decorator diff --git a/src/modalities/utils/mfu.py b/src/modalities/utils/mfu.py index d096a4061..76c63d122 100644 --- a/src/modalities/utils/mfu.py +++ b/src/modalities/utils/mfu.py @@ -89,32 +89,33 @@ def _get_theoretical_gpu_peak_performance_single(precision: torch.dtype, gpu_typ return None @staticmethod - def _get_theoretical_gpu_peak_performance(wrapped_model: FSDPX, world_size: int) -> Optional[Number]: + def _get_theoretical_gpu_peak_performance(model_parts: FSDPX | list[FSDP2], world_size: int) -> Optional[Number]: """ Calculates the accumulated theoretical peak performance based on all GPUs, i.e., #GPU=world_size, in units FLOPs / s for given gpu type. Args: - model (FSDPX): The model for which to calculate the theoretical peak performance. + model_parts (FSDPX | list[FSDP2]): The model or model parts for which + to calculate the theoretical peak performance. world_size (int): The number of GPUs used in parallel. Returns: (Number, optional): The accumulated theoretical peak performance of all GPUs, or None if it cannot be calculated. 
""" - if isinstance(wrapped_model, FSDP1): - precision = wrapped_model.mixed_precision.param_dtype + if isinstance(model_parts, FSDP1): + precision = model_parts.mixed_precision.param_dtype if ( - wrapped_model.mixed_precision.reduce_dtype != precision - or wrapped_model.mixed_precision.buffer_dtype != precision + model_parts.mixed_precision.reduce_dtype != precision + or model_parts.mixed_precision.buffer_dtype != precision ): warnings.warn(f"Could not get theoretical GPU peak performance for mixed precision type = {precision}.") return None - elif isinstance(wrapped_model, FSDP2): + elif isinstance(model_parts, FSDP2) or isinstance(model_parts, list): warnings.warn("MFU is computed based on the assumption that bf16 precision is used.") precision = torch.bfloat16 else: - raise TypeError(f"Model should be of type FSDPX, but is {type(wrapped_model)} instead.") + raise TypeError(f"Model should be of type FSDPX, but is {type(model_parts)} instead.") device_name = torch.cuda.get_device_name() if device_name.startswith("NVIDIA A100"): @@ -153,10 +154,10 @@ def __init__( sequence_length: int, n_embd: int, world_size: int, - wrapped_model: FSDPX, + model_parts: FSDPX | list[FSDP2], device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, ): - self._num_params = get_total_number_of_trainable_parameters(model=wrapped_model, device_mesh=device_mesh) + self._num_params = get_total_number_of_trainable_parameters(model=model_parts, device_mesh=device_mesh) self._n_layer = n_layer self._sequence_length = sequence_length self._n_embd = n_embd @@ -167,7 +168,7 @@ def __init__( n_embd=self._n_embd, ) self._theoretical_gpu_peak_performance = MFUCalculatorABC._get_theoretical_gpu_peak_performance( - wrapped_model, world_size + model_parts, world_size ) @staticmethod diff --git a/tests/checkpointing/checkpointing_test_utils.py b/tests/checkpointing/checkpointing_test_utils.py index c350ccbc8..a9da2c1a9 100644 --- a/tests/checkpointing/checkpointing_test_utils.py +++ b/tests/checkpointing/checkpointing_test_utils.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from pydantic import BaseModel from torch.distributed.tensor import DTensor @@ -7,6 +9,7 @@ from modalities.config.component_factory import ComponentFactory from modalities.config.config import PydanticPytorchModuleType from modalities.models.gpt2.gpt2_model import GPT2LLM +from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry from modalities.utils.typing_utils import FSDPX @@ -14,7 +17,7 @@ class CheckpointingTestUtils: @staticmethod - def generate_batch(gpt2_model_config: dict): + def generate_batch(gpt2_model_config: dict[str, Any]) -> tuple[dict[str, torch.Tensor], torch.Tensor]: # prepare input and targets if "settings" in gpt2_model_config: batch_size = gpt2_model_config["settings"]["step_profile"]["local_train_micro_batch_size"] @@ -36,12 +39,15 @@ def generate_batch(gpt2_model_config: dict): @staticmethod def forward_backward_pass( prediction_key: str, - model: FSDPX, + model: FSDPX | list[FSDPX], optimizer: Optimizer, - batch_input_ids_dict: dict, + batch_input_ids_dict: dict[str, torch.Tensor], batch_target_ids: torch.Tensor, ) -> torch.Tensor: ce_loss = CrossEntropyLoss() + if isinstance(model, list): + assert len(model) == 1, "Only single model part supported in this utility function." 
+ model = model[0] # clear the gradients optimizer.zero_grad() @@ -59,25 +65,25 @@ def forward_backward_pass( @staticmethod def forward_backward_pp_pass( - scheduled_pipeline, + scheduled_pipeline: Pipeline, optimizer: Optimizer, - batch_input_ids_dict: dict, + batch_input_ids_dict: dict[str, torch.Tensor], batch_target_ids: torch.Tensor, ): pp_schedule = scheduled_pipeline.pp_schedule # Pipeline Parallel forward / backward inside step() call # with self.train_context(optional_context_parallel_ctx): - targets, losses = (batch_target_ids.contiguous(), []) if scheduled_pipeline.is_last_pp_stage else (None, None) + targets, losses = (batch_target_ids.contiguous(), []) if scheduled_pipeline.has_last_pp_stage else (None, None) - if scheduled_pipeline.is_first_pp_stage: + if scheduled_pipeline.has_first_pp_stage: pp_schedule.step( - batch_input_ids_dict[scheduled_pipeline.model_part.sample_key].contiguous(), + batch_input_ids_dict[scheduled_pipeline.model_parts[0].sample_key].contiguous(), target=targets, losses=losses, ) else: pp_schedule.step(target=targets, losses=losses) - loss = torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.is_last_pp_stage else None + loss = torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.has_last_pp_stage else None optimizer.step() # clear the gradients optimizer.zero_grad() @@ -85,7 +91,7 @@ def forward_backward_pp_pass( return loss @staticmethod - def get_gpt2_model_from_config(gpt2_model_config_dict: dict) -> GPT2LLM: + def get_gpt2_model_from_config(gpt2_model_config_dict: dict[str, Any]) -> GPT2LLM: class GPT2InstantationModel(BaseModel): model: PydanticPytorchModuleType @@ -100,25 +106,55 @@ class GPT2InstantationModel(BaseModel): return model @staticmethod - def clone_parameters(fsdp_wrapped_model: FSDPX): - return [p.clone() for p in fsdp_wrapped_model.parameters() if p.requires_grad and p.numel() > 0] + def clone_parameters(fsdp_wrapped_model: FSDPX | list[FSDPX]) -> list[list[torch.Tensor]]: + if not isinstance(fsdp_wrapped_model, list): + fsdp_wrapped_model = [fsdp_wrapped_model] + return [[p.clone() for p in m.parameters() if p.requires_grad and p.numel() > 0] for m in fsdp_wrapped_model] @staticmethod def assert_equality_optimizer_param_group( - optimizer_1_state_dict: dict, optimizer_2_state_dict: dict, must_be_equal: bool + optimizer_1_state_dict: dict[str, Any], optimizer_2_state_dict: dict[str, Any], must_be_equal: bool ): + # Need to differentiate between normal optimizer state dicts and flattened ones. + if "param_groups" in optimizer_1_state_dict: + assert "param_groups" in optimizer_2_state_dict + optimizer_1_state_dict = optimizer_1_state_dict["param_groups"] + optimizer_2_state_dict = optimizer_2_state_dict["param_groups"] + else: + optimizer_1_state_dict = {k: v for k, v in optimizer_1_state_dict.items() if k.startswith("param_groups")} + optimizer_2_state_dict = {k: v for k, v in optimizer_2_state_dict.items() if k.startswith("param_groups")} + assert len(optimizer_1_state_dict) > 0, "No param_groups found in flattened optimizer state dict." 
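+        # Both arguments now hold only their param_groups entries, so the equality
+        # check below applies to flattened and non-flattened state dicts alike.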
if must_be_equal: assert ( - optimizer_1_state_dict["param_groups"] == optimizer_2_state_dict["param_groups"] + optimizer_1_state_dict == optimizer_2_state_dict ), "_assert_equality_optimizer_param_group failed (must_be_equal = True)" else: assert not ( - optimizer_1_state_dict["param_groups"] == optimizer_2_state_dict["param_groups"] + optimizer_1_state_dict == optimizer_2_state_dict ), "_assert_equality_optimizer_param_group failed (must_be_equal = False)" @staticmethod def assert_equality_optimizer_state( - optimizer_1_state_dict: dict, optimizer_2_state_dict: dict, must_be_equal: bool + optimizer_1_state_dict: dict[str, Any], optimizer_2_state_dict: dict[str, Any], must_be_equal: bool + ): + # Need to differentiate between normal optimizer state dicts and flattened ones. + if "state" in optimizer_1_state_dict: + assert "state" in optimizer_2_state_dict + CheckpointingTestUtils.assert_equality_non_flattened_optimizer_state( + optimizer_1_state_dict=optimizer_1_state_dict, + optimizer_2_state_dict=optimizer_2_state_dict, + must_be_equal=must_be_equal, + ) + else: + CheckpointingTestUtils.assert_equality_flattened_optimizer_state( + optimizer_1_state_dict=optimizer_1_state_dict, + optimizer_2_state_dict=optimizer_2_state_dict, + must_be_equal=must_be_equal, + ) + + @staticmethod + def assert_equality_non_flattened_optimizer_state( + optimizer_1_state_dict: dict[str, Any], optimizer_2_state_dict: dict[str, Any], must_be_equal: bool ): optimizer_1_state = optimizer_1_state_dict["state"] optimizer_2_state = optimizer_2_state_dict["state"] @@ -137,15 +173,39 @@ def assert_equality_optimizer_state( ) @staticmethod - def assert_equality_two_models(params_1: list[torch.Tensor], params_2: list[torch.Tensor], must_be_equal: bool): - for p1, p2 in zip(params_1, params_2): + def assert_equality_flattened_optimizer_state( + optimizer_1_state_dict: dict[str, Any], optimizer_2_state_dict: dict[str, Any], must_be_equal: bool + ): + optimizer_1_state = {k: v for k, v in optimizer_1_state_dict.items() if k.startswith("state")} + optimizer_2_state = {k: v for k, v in optimizer_2_state_dict.items() if k.startswith("state")} + assert len(optimizer_1_state) > 0, "No state found in flattened optimizer state dict." 
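+        # Flattened state dicts expose the per-parameter optimizer state under keys
+        # prefixed with "state", so the two dicts are compared entry by entry below.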
+ assert set(optimizer_1_state.keys()) == set(optimizer_2_state.keys()) + for state_key in optimizer_1_state.keys(): CheckpointingTestUtils.assert_equality_two_tensors( - tensor_1=p1, - tensor_2=p2, + tensor_1=optimizer_1_state[state_key], + tensor_2=optimizer_2_state[state_key], must_be_equal=must_be_equal, - msg_on_failure="_assert_equality_two_models failed", + msg_on_failure="_assert_equality_optimizer_state failed", ) + @staticmethod + def assert_equality_two_models( + params_1: list[torch.Tensor] | list[list[torch.Tensor]], + params_2: list[torch.Tensor] | list[list[torch.Tensor]], + must_be_equal: bool, + ): + for p1, p2 in zip(params_1, params_2): + if isinstance(p1, list): + assert isinstance(p2, list), "_assert_equality_two_models failed (type mismatch with list)" + CheckpointingTestUtils.assert_equality_two_models(params_1=p1, params_2=p2, must_be_equal=must_be_equal) + else: + CheckpointingTestUtils.assert_equality_two_tensors( + tensor_1=p1, + tensor_2=p2, + must_be_equal=must_be_equal, + msg_on_failure="_assert_equality_two_models failed", + ) + @staticmethod def assert_equality_two_tensors( tensor_1: torch.Tensor, tensor_2: torch.Tensor, must_be_equal: bool, msg_on_failure: str = "" diff --git a/tests/checkpointing/test_fsdp1_to_disc_checkpointing.py b/tests/checkpointing/test_fsdp1_to_disc_checkpointing.py index 5b5f0428b..934e13245 100644 --- a/tests/checkpointing/test_fsdp1_to_disc_checkpointing.py +++ b/tests/checkpointing/test_fsdp1_to_disc_checkpointing.py @@ -136,7 +136,7 @@ def _test_save_checkpoint_after_backward_pass_impl( sharding_strategy=ShardingStrategy.FULL_SHARD, ) - untrained_model_parameters = CheckpointingTestUtils.clone_parameters(fsdp1_wrapped_model) + untrained_model_parameters = CheckpointingTestUtils.clone_parameters(fsdp1_wrapped_model)[0] untrained_optimizer_state_dict = deepcopy(optimizer.state_dict()) # run backward pass @@ -148,7 +148,7 @@ def _test_save_checkpoint_after_backward_pass_impl( batch_input_ids_dict=batch_input_ids_dict, batch_target_ids=batch_target_ids, ) - updated_model_parameters = CheckpointingTestUtils.clone_parameters(fsdp1_wrapped_model) + updated_model_parameters = CheckpointingTestUtils.clone_parameters(fsdp1_wrapped_model)[0] updated_optimizer_state_dict = deepcopy(optimizer.state_dict()) # save model and optimizer before backward pass @@ -197,7 +197,7 @@ def _test_save_checkpoint_after_backward_pass_impl( optimizer=optimizer_2, model=fsdp1_wrapped_model_2, file_path=optimizer_checkpointing_path ) - loaded_and_updated_model_parameters = CheckpointingTestUtils.clone_parameters(fsdp1_wrapped_model) + loaded_and_updated_model_parameters = CheckpointingTestUtils.clone_parameters(fsdp1_wrapped_model)[0] loaded_and_updated_optimizer_state_dict = deepcopy(optimizer_2.state_dict()) # make sure that after the update all weights are DIFFERENT from the original ones diff --git a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py index 6ecd65751..ce4195b7f 100644 --- a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py +++ b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py @@ -180,7 +180,7 @@ def 
_test_save_checkpoint_after_backward_pass_impl( prediction_key = gpt2_model_config_dict["model_raw"]["config"]["prediction_key"] # save the initial model and optimizer state dicts - untrained_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model) + untrained_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model_parts) untrained_optimizer_state_dict = deepcopy(app_state1.optimizer.state_dict()) # run backward pass @@ -195,14 +195,14 @@ def _test_save_checkpoint_after_backward_pass_impl( else: loss_0 = CheckpointingTestUtils.forward_backward_pass( prediction_key=prediction_key, - model=app_state1.model, + model=app_state1.model_parts, optimizer=app_state1.optimizer, batch_input_ids_dict=batch_input_ids_dict, batch_target_ids=batch_target_ids, ) # save the updated model and optimizer states for later comparisons - updated_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model) + updated_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model_parts) updated_optimizer_state_dict = deepcopy(app_state1.optimizer.state_dict()) # checkpoint the model and optimizer before backward pass @@ -248,7 +248,7 @@ def _test_save_checkpoint_after_backward_pass_impl( checkpoint_dir_path=dcp_checkpoint_folder_path, ) - loaded_and_updated_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model) + loaded_and_updated_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model_parts) loaded_and_updated_optimizer_state_dict = deepcopy(app_state1.optimizer.state_dict()) # perform another forward pass and backward pass for the previous and the loaded model @@ -273,7 +273,7 @@ def _test_save_checkpoint_after_backward_pass_impl( else: loss_1 = CheckpointingTestUtils.forward_backward_pass( prediction_key=prediction_key, - model=app_state1.model, + model=app_state1.model_parts, optimizer=app_state1.optimizer, batch_input_ids_dict=batch_input_ids_dict, batch_target_ids=batch_target_ids, @@ -281,7 +281,7 @@ def _test_save_checkpoint_after_backward_pass_impl( loss_2 = CheckpointingTestUtils.forward_backward_pass( prediction_key=prediction_key, - model=app_state2.model, + model=app_state2.model_parts, optimizer=app_state2.optimizer, batch_input_ids_dict=batch_input_ids_dict, batch_target_ids=batch_target_ids, @@ -291,6 +291,29 @@ def _test_save_checkpoint_after_backward_pass_impl( assert loss_1 < loss_0, f"loss_1 = {loss_1} is not less than loss_0 = {loss_0}" # check that the model and optimizer states after each backward pass are as expected + TestFSDP2DCPCheckpointing._check_states_match_as_expected( + app_state1, + app_state2, + untrained_model_parameters, + untrained_optimizer_state_dict, + updated_model_parameters, + updated_optimizer_state_dict, + loaded_and_updated_model_parameters, + loaded_and_updated_optimizer_state_dict, + ) + + @staticmethod + def _check_states_match_as_expected( + app_state1: AppState, + app_state2: AppState, + untrained_model_parameters: list[torch.Tensor], + untrained_optimizer_state_dict: dict, + updated_model_parameters: list[torch.Tensor], + updated_optimizer_state_dict: dict, + loaded_and_updated_model_parameters: list[torch.Tensor], + loaded_and_updated_optimizer_state_dict: dict, + past_backward_pass: bool = False, + ) -> None: # model weights CheckpointingTestUtils.assert_equality_two_models( untrained_model_parameters, updated_model_parameters, must_be_equal=False @@ -298,12 +321,12 @@ def _test_save_checkpoint_after_backward_pass_impl( 
CheckpointingTestUtils.assert_equality_two_models( loaded_and_updated_model_parameters, updated_model_parameters, must_be_equal=True ) - CheckpointingTestUtils.assert_equality_two_models( - app_state1.model.parameters(), app_state2.model.parameters(), must_be_equal=True - ) - CheckpointingTestUtils.assert_equality_two_models( - app_state1.model.parameters(), updated_model_parameters, must_be_equal=False - ) + for m1, m2 in zip(app_state1.model_parts, app_state2.model_parts): + CheckpointingTestUtils.assert_equality_two_models(m1.parameters(), m2.parameters(), must_be_equal=True) + if past_backward_pass: + CheckpointingTestUtils.assert_equality_two_models( + [m.parameters() for m in app_state1.model_parts], updated_model_parameters, must_be_equal=False + ) # param groups CheckpointingTestUtils.assert_equality_optimizer_param_group( @@ -322,6 +345,7 @@ def _test_save_checkpoint_after_backward_pass_impl( CheckpointingTestUtils.assert_equality_optimizer_state( app_state1.optimizer.state_dict(), app_state2.optimizer.state_dict(), must_be_equal=True ) - CheckpointingTestUtils.assert_equality_optimizer_state( - app_state1.optimizer.state_dict(), updated_optimizer_state_dict, must_be_equal=False - ) + if past_backward_pass: + CheckpointingTestUtils.assert_equality_optimizer_state( + app_state1.optimizer.state_dict(), updated_optimizer_state_dict, must_be_equal=False + ) diff --git a/tests/conftest.py b/tests/conftest.py index 1315057c9..bfb9a93f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -125,22 +125,22 @@ def wrapped_gpt2_tokenizer() -> PreTrainedHFTokenizer: @pytest.fixture(scope="function") -def checkpoint_saving_mock(): +def checkpoint_saving_mock() -> CheckpointSaving: return MagicMock(spec=CheckpointSaving) @pytest.fixture(scope="function") -def evaluator_mock(): +def evaluator_mock() -> Evaluator: return MagicMock(spec=Evaluator) @pytest.fixture(scope="function") -def nn_model_mock(): +def nn_model_mock() -> NNModel: return MagicMock(spec=NNModel) @pytest.fixture(scope="function") -def optimizer_mock(): +def optimizer_mock() -> Optimizer: return MagicMock(spec=Optimizer) @@ -166,41 +166,43 @@ def custom_step_function(lr_decay_factor): @pytest.fixture(scope="function") -def scheduler_mock(): +def scheduler_mock() -> LRScheduler: mocked_lr_scheduler = MagicMock(spec=LRScheduler) mocked_lr_scheduler.get_last_lr = lambda: [0.0] return mocked_lr_scheduler @pytest.fixture(scope="function") -def app_state_mock(): - return MagicMock(spec=AppState) +def app_state_mock() -> AppState: + app_state = MagicMock(spec=AppState) + app_state.model_parts = [MagicMock()] + return app_state @pytest.fixture(scope="function") -def gradient_clipper_mock(): +def gradient_clipper_mock() -> GradientClipperIF: gradient_clipper = MagicMock(spec=GradientClipperIF) gradient_clipper.clip_gradients = lambda: torch.Tensor([0.0]) return gradient_clipper @pytest.fixture(scope="function") -def loss_mock(): +def loss_mock() -> Loss: return MagicMock(spec=Loss, return_value=torch.rand(1, requires_grad=True)) @pytest.fixture(scope="function") -def llm_data_loader_mock(): +def llm_data_loader_mock() -> LLMDataLoader: return MagicMock(spec=LLMDataLoader) @pytest.fixture(scope="function") -def progress_publisher_mock(): +def progress_publisher_mock() -> MessagePublisher: return MagicMock(spec=MessagePublisher) @pytest.fixture(scope="function") -def trainer(progress_publisher_mock, gradient_clipper_mock): +def trainer(progress_publisher_mock: MessagePublisher, gradient_clipper_mock: GradientClipperIF) -> Trainer: 
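+    # RANK is provided per process by the distributed launcher (e.g. torchrun).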
return Trainer( global_rank=int(os.getenv("RANK")), progress_publisher=progress_publisher_mock, diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml new file mode 100644 index 000000000..696d9e479 --- /dev/null +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml @@ -0,0 +1,369 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: tmp/checkpoints + train_dataset_path: tests/end2end_tests/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 4 + evaluation_interval_in_steps: 1 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 2 + sequence_length: 256 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: #7 # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + 
skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: [] + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + tensor_parallel_degree: 1 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: # get the parallel degree from the device mesh + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: Interleaved1F1B + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 1 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. 
+ pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + seed: 42 + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 8 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: save_all + config: {} diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml index 930d710bb..dbd97f9e5 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml @@ -196,7 +196,7 @@ scheduled_pipeline: loss_fn: instance_key: loss_fn pass_type: BY_REFERENCE - pp_schedule_name: gpipe + pp_schedule_name: Interleaved1F1B batch_size: ${settings.step_profile.local_train_micro_batch_size} microbatch_size: 1 pp_degree: ${device_mesh.config.pipeline_parallel_degree} @@ -273,7 +273,7 @@ staged_pipeline: instance_key: device_mesh pass_type: BY_REFERENCE local_rank: ${settings.cuda_env.local_rank} - pp_schedule_name: gpipe + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} num_layers_per_stage: 2 model_raw: diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml new file mode 100644 index 000000000..ebf915db8 --- /dev/null +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml @@ -0,0 +1,315 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + 
checkpoint_saving_path: tmp/checkpoints + train_dataset_path: tests/end2end_tests/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 4 + evaluation_interval_in_steps: 1 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 2 + sequence_length: 256 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: [] + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + 
config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 1 + tensor_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: # get the parallel degree from the device mesh + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_raw: + component_key: model + variant_key: gpt2 + config: + seed: 42 + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 8 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: save_all + config: {} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml new file mode 100644 index 000000000..b8e85bb5d --- /dev/null +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml @@ -0,0 +1,322 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: tmp/checkpoints + train_dataset_path: tests/end2end_tests/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 2 + evaluation_interval_in_steps: 1 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 2 + local_train_micro_batch_size: 1 + sequence_length: 256 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: global_num_target_tokens_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + num_target_steps: 
# for the batch progress subscriber + component_key: number_conversion + variant_key: num_target_steps_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + training_progress: + global_num_seen_tokens: # used below + component_key: number_conversion + variant_key: global_num_seen_tokens_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + num_seen_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_seen_steps_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + num_seen_samples: + component_key: number_conversion + variant_key: num_samples_from_num_tokens + config: + num_tokens: ${settings.training_progress.global_num_seen_tokens} + sequence_length: ${settings.step_profile.sequence_length} + last_step: # for the scheduler + component_key: number_conversion + variant_key: last_step_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + warmstart_checkpoint_paths: + # we pass in the checkpoint paths as filenames such that the num_target_tokens and num_target_steps can be calculated and correctly passed to the training loop + # Within the test is replaced with the actual path to the checkpoint. + model_checkpoint_path: eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168 + optimizer_checkpoint_path: eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: [] + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: 
${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 1 + tensor_parallel_degree: 1 + data_parallel_shard_degree: 2 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: dcp + config: + raw_app_state: + instance_key: app_state_raw + pass_type: BY_REFERENCE + checkpoint_dir_path: checkpoint/path/will/be/set/in/code + +app_state_raw: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_raw: + component_key: model + variant_key: gpt2 + config: + seed: 42 + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 8 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + # last_epoch: ${settings.training_progress.last_step} # Not required. App state will take care of the correct initialization. 
+ +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: save_all + config: {} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml index 81e9a566b..b298c8a3a 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml @@ -215,7 +215,7 @@ scheduled_pipeline: loss_fn: instance_key: loss_fn pass_type: BY_REFERENCE - pp_schedule_name: gpipe + pp_schedule_name: Interleaved1F1B batch_size: ${settings.step_profile.local_train_micro_batch_size} microbatch_size: 1 pp_degree: ${device_mesh.config.pipeline_parallel_degree} @@ -292,7 +292,7 @@ staged_pipeline: instance_key: device_mesh pass_type: BY_REFERENCE local_rank: ${settings.cuda_env.local_rank} - pp_schedule_name: gpipe + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} num_layers_per_stage: 2 model_raw: diff --git a/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py b/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py index 31c05ca2a..0020fc087 100644 --- a/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py +++ b/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py @@ -32,6 +32,9 @@ tmp_folder = working_dir / "../tmp/fsdp2_warmstart_pp_tp" working_dir = working_dir / "configs" +num_steps = 7 +num_tokens = 1024 * num_steps + class TrainDataloaderInstantiationModel(BaseModel): settings: TrainingComponentsInstantiationModel.Settings @@ -48,6 +51,9 @@ class TestWarmstart: [ ("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_pp_tp.yaml", 8, 8), ("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2), + # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), + # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2), + # ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), ("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_grad_accu.yaml", 8, 1), ("gpt2_train_num_steps_7_grad_accu.yaml", "gpt2_warm_start_from_step_4_pp_tp.yaml", 1, 8), ], @@ -158,10 +164,14 @@ def _first_training_impl(process_id: int, first_config: str, checkpoint_root_pat assert checkpoint_info_file_path.exists(), "Missing last_checkpoint_info.json after first training." 
with open(checkpoint_info_file_path, "r") as f: checkpoint_info = json.load(f) - expected_cp_suffix = "eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168" - assert checkpoint_info["checkpoint_folder_path"].endswith( - expected_cp_suffix - ), "Checkpoint info file does not point to expected step 4 folder." + expected_cp_suffix = ( + f"eid_0-seen_steps_4-seen_tokens_4096-target_steps_{num_steps}-target_tokens_{num_tokens}" + ) + assert checkpoint_info["checkpoint_folder_path"].endswith(expected_cp_suffix), ( + "Checkpoint info file does not point to expected step 4 folder.\n" + f" Expected suffix: {expected_cp_suffix}\n" + f" Got: {checkpoint_info['checkpoint_folder_path']}" + ) assert Path(checkpoint_info["checkpoint_folder_path"]).exists(), "Checkpoint folder path does not exist." # enumerate checkpoint paths and ensure max seen matches info @@ -230,7 +240,9 @@ def _second_training_impl(process_id: int, second_config: str, checkpoint_root_p checkpoint_path = str(checkpoint_root_path) # path to checkpoint from first training (step 4) warmstart_checkpoint_dir = ( - checkpoint_root_path / "0" / "eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168" + checkpoint_root_path + / "0" + / f"eid_0-seen_steps_4-seen_tokens_4096-target_steps_{num_steps}-target_tokens_{num_tokens}" ) gpt2_warm_start_config_dict["app_state"]["config"]["checkpoint_dir_path"] = str(warmstart_checkpoint_dir) gpt2_warm_start_config_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][ @@ -270,18 +282,24 @@ def _second_training_impl(process_id: int, second_config: str, checkpoint_root_p with open(checkpoint_root_path / "experiment_0_loss_scores.txt", "r") as f: loaded_loss_values_0 = json.load(f) - assert loaded_loss_values_0[4:] == pytest.approx( - loss_scores_1, abs=1e-16 - ), "Warmstart loss trajectory mismatch with from-scratch continuation." + assert loaded_loss_values_0[4:] == pytest.approx(loss_scores_1, rel=1e-2), ( + "Warmstart loss trajectory mismatch with from-scratch continuation.\n" + f"Expected {loaded_loss_values_0[4:]}, got {loss_scores_1}." + ) # Additionally assert checkpoint info integrity from first run still present checkpoint_info_file_path = checkpoint_root_path / "0" / "last_checkpoint_info.json" assert checkpoint_info_file_path.exists(), "Missing last_checkpoint_info.json from first training." with open(checkpoint_info_file_path, "r") as f: checkpoint_info = json.load(f) - assert checkpoint_info["checkpoint_folder_path"].endswith( - "eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168" - ), "Incorrect checkpoint folder path recorded." 
+ expected_cp_suffix = ( + f"eid_0-seen_steps_4-seen_tokens_4096-target_steps_{num_steps}-target_tokens_{num_tokens}" + ) + assert checkpoint_info["checkpoint_folder_path"].endswith(expected_cp_suffix), ( + "Incorrect checkpoint folder path recorded.\n" + f" Expected suffix: {expected_cp_suffix}\n" + f" Got: {checkpoint_info['checkpoint_folder_path']}" + ) # Compare final scheduler state with open(scheduler_info_path, "r") as f: diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml index f41e912bc..de0a070f6 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -38,7 +38,7 @@ scheduled_pipeline: loss_fn: instance_key: loss_fn pass_type: BY_REFERENCE - pp_schedule_name: gpipe + pp_schedule_name: Interleaved1F1B batch_size: ${settings.step_profile.local_train_micro_batch_size} microbatch_size: 2 pp_degree: ${device_mesh.config.pipeline_parallel_degree} @@ -100,7 +100,7 @@ staged_pipeline: instance_key: device_mesh pass_type: BY_REFERENCE local_rank: ${settings.cuda_env.local_rank} - pp_schedule_name: gpipe + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} num_layers_per_stage: 4 initialized_model: diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml index fb8ee5f7d..cd822525c 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml @@ -38,7 +38,7 @@ scheduled_pipeline: loss_fn: instance_key: loss_fn pass_type: BY_REFERENCE - pp_schedule_name: gpipe + pp_schedule_name: Interleaved1F1B batch_size: ${settings.step_profile.local_train_micro_batch_size} microbatch_size: 2 pp_degree: ${device_mesh.config.pipeline_parallel_degree} @@ -111,7 +111,7 @@ staged_pipeline: instance_key: device_mesh pass_type: BY_REFERENCE local_rank: ${settings.cuda_env.local_rank} - pp_schedule_name: gpipe + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} num_layers_per_stage: 4 initialized_model: diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index 9534153a3..76bf64691 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -106,13 +106,13 @@ def _forward_step_with_pp( print(f"Exception in _forward_step_with_pp: {e}") traceback.print_exc() raise e - return scheduled_pipeline.is_last_pp_stage, loss_pp + return scheduled_pipeline.has_last_pp_stage, loss_pp def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """Runs a forward step on the model.""" pp_schedule = scheduled_pipeline.pp_schedule - targets, losses = (targets, []) if scheduled_pipeline.is_last_pp_stage else (None, None) - if scheduled_pipeline.is_first_pp_stage: + targets, losses = (targets, []) if 
scheduled_pipeline.has_last_pp_stage else (None, None) + if scheduled_pipeline.has_first_pp_stage: pp_schedule.step(inputs, target=targets, losses=losses) else: pp_schedule.step(target=targets, losses=losses) @@ -120,7 +120,7 @@ def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targ # accumulate losses across pipeline microbatches return ( torch.mean(torch.stack(losses)).to(losses[0].device) - if scheduled_pipeline.is_last_pp_stage + if scheduled_pipeline.has_last_pp_stage else torch.tensor([-1.0], device=inputs.device) ) diff --git a/tests/fsdp2_parallelization/test_tensor_parallelism.py b/tests/fsdp2_parallelization/test_tensor_parallelism.py index f611b7164..6ae48fa04 100644 --- a/tests/fsdp2_parallelization/test_tensor_parallelism.py +++ b/tests/fsdp2_parallelization/test_tensor_parallelism.py @@ -17,6 +17,7 @@ from modalities.models.gpt2.gpt2_model import TransformerMLP from modalities.models.model import SwiGLU from tests.end2end_tests.custom_components import MultiProcessingCudaEnv +from tests.utility import find_free_port def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path: @@ -34,7 +35,7 @@ def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: @pytest.fixture -def tmp_config_dir(tmp_path_factory) -> Path: +def tmp_config_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: return tmp_path_factory.mktemp("patched_configs") @@ -55,19 +56,17 @@ class ComponentsInstantiationModel(BaseModel): return components.model, components.device_mesh @pytest.mark.parametrize( - "activation_type, fsdp2_config_path, tp_config_path, port", + "activation_type, fsdp2_config_path, tp_config_path", [ ( "gelu", Path("tests/fsdp2_parallelization/tp_test_configs/fsdp2_config.yaml"), Path("tests/fsdp2_parallelization/tp_test_configs/tp_config.yaml"), - 22235, ), ( "swiglu", Path("tests/fsdp2_parallelization/tp_test_configs/fsdp2_config.yaml"), Path("tests/fsdp2_parallelization/tp_test_configs/tp_config.yaml"), - 22246, ), ], ) @@ -77,9 +76,9 @@ def test_tp_sharding( fsdp2_config_path: Path, tp_config_path: Path, tmp_config_dir: Path, - port: int, ): world_size = 4 + port = find_free_port() mp.spawn( self._test_tp_sharding_impl, args=(activation_type, fsdp2_config_path, tp_config_path, world_size, tmp_config_dir, port), diff --git a/tests/test_gym.py b/tests/test_gym.py index 82d073261..f6bbcf92b 100644 --- a/tests/test_gym.py +++ b/tests/test_gym.py @@ -1,18 +1,26 @@ from unittest.mock import call +from pytest import MonkeyPatch + +from modalities.checkpointing.checkpoint_saving import CheckpointSaving +from modalities.checkpointing.stateful.app_state import AppState +from modalities.dataloader.dataloader import LLMDataLoader +from modalities.evaluator import Evaluator from modalities.gym import Gym +from modalities.loss_functions import Loss +from modalities.trainer import Trainer from tests.utility import configure_dataloader_mock def test_run_cpu_only( - monkeypatch, - checkpoint_saving_mock, - evaluator_mock, - app_state_mock, - loss_mock, - llm_data_loader_mock, set_env_cpu, - trainer, + monkeypatch: MonkeyPatch, + checkpoint_saving_mock: CheckpointSaving, + evaluator_mock: Evaluator, + app_state_mock: AppState, + loss_mock: Loss, + llm_data_loader_mock: LLMDataLoader, + trainer: Trainer, ): num_batches = 4 num_ranks = 1 @@ -36,5 +44,5 @@ def test_run_cpu_only( evaluation_data_loaders=[], checkpoint_saving=checkpoint_saving_mock, ) - app_state_mock.model.assert_has_calls([call(b.samples) for b in batches]) 
+ app_state_mock.model_parts[0].assert_has_calls([call(b.samples) for b in batches]) app_state_mock.optimizer.step.assert_called() diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 46bfc4b81..58793e7ae 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -1,20 +1,26 @@ -from unittest.mock import call +from unittest.mock import MagicMock, call import numpy as np +from modalities.checkpointing.checkpoint_saving import CheckpointSaving +from modalities.checkpointing.stateful.app_state import AppState +from modalities.dataloader.dataloader import LLMDataLoader +from modalities.evaluator import Evaluator from modalities.gym import Gym +from modalities.loss_functions import Loss from modalities.optimizers.lr_schedulers import DummyLRScheduler +from modalities.trainer import Trainer from tests.utility import configure_dataloader_mock def test_run_scheduler( set_env_cpu, - checkpoint_saving_mock, - evaluator_mock, - app_state_mock, - loss_mock, - llm_data_loader_mock, - trainer, + checkpoint_saving_mock: CheckpointSaving, + evaluator_mock: Evaluator, + app_state_mock: AppState, + loss_mock: Loss, + llm_data_loader_mock: LLMDataLoader, + trainer: Trainer, ): num_batches = 4 num_ranks = 1 @@ -38,11 +44,11 @@ def test_run_scheduler( checkpointing_interval_in_steps=1, evaluation_interval_in_steps=1, ) - app_state_mock.model.assert_has_calls([call(b.samples) for b in batches]) + app_state_mock.model_parts[0].assert_has_calls([call(b.samples) for b in batches]) app_state_mock.lr_scheduler.step.assert_called() -def test_dummy_lr_scheduler(optimizer_with_param_groups_mock): +def test_dummy_lr_scheduler(optimizer_with_param_groups_mock: MagicMock): # we test that the optimizer step function reduces the lr by 0.01 for each param group. # we also test that the scheduler step function does not change the lr. 
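Note: the test updates above (app_state_mock.model_parts[0] in place of app_state_mock.model) reflect that the app state now carries a list of model parts, with exactly one entry when no pipeline parallelism is used. The following is a hypothetical helper sketch, not part of this diff, showing how whole-model operations iterate that list; the library's get_total_number_of_trainable_parameters used in tests/test_util.py additionally accounts for sharding via the device mesh, which this sketch ignores.

import torch.nn as nn


def count_local_trainable_parameters(model_parts: list[nn.Module]) -> int:
    # One entry per pipeline stage; plain FSDP setups keep a single part,
    # so model_parts[0] plays the role of the former app_state.model.
    return sum(
        param.numel()
        for part in model_parts
        for param in part.parameters()
        if param.requires_grad
    )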
diff --git a/tests/test_util.py b/tests/test_util.py index f71450982..c82521b17 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -140,16 +140,16 @@ class CustomComponentInstantiationModel(BaseModel): components: CustomComponentInstantiationModel = main_obj.build_components( components_model_type=CustomComponentInstantiationModel ) - wrapped_model = components.app_state.model + wrapped_model = components.app_state.model_parts TestUtils._assert_correct_total_number_of_trainable_parameters( wrapped_model=wrapped_model, - device_mesh=components.device_mesh, expected_nr_parameters=expected_nr_parameters, + device_mesh=components.device_mesh, ) def _assert_correct_total_number_of_trainable_parameters( - wrapped_model: FSDPX, expected_nr_parameters: int, device_mesh: PydanticDeviceMeshIFType | None + wrapped_model: FSDPX | list[FSDPX], expected_nr_parameters: int, device_mesh: PydanticDeviceMeshIFType | None ): nr_parameters = get_total_number_of_trainable_parameters(model=wrapped_model, device_mesh=device_mesh) assert nr_parameters == expected_nr_parameters diff --git a/tests/training/gradient_clipping/test_fsdp_gradient_clipper.py b/tests/training/gradient_clipping/test_fsdp_gradient_clipper.py index edc797151..b1cc9d1ea 100644 --- a/tests/training/gradient_clipping/test_fsdp_gradient_clipper.py +++ b/tests/training/gradient_clipping/test_fsdp_gradient_clipper.py @@ -123,7 +123,7 @@ def test_fsdp2_gradient_clipper(): max_norm = 1.0 norm_type = GradientClippingMode.P2_NORM - clipper = FSDP2GradientClipper(wrapped_model=mock_model, max_norm=max_norm, norm_type=norm_type) + clipper = FSDP2GradientClipper(model_parts=mock_model, max_norm=max_norm, norm_type=norm_type) # Call clip_gradients norm = clipper.clip_gradients() @@ -144,7 +144,7 @@ def test_fsdp2_logging_only_gradient_clipper(): mock_model = MockFSDPModel() norm_type = GradientClippingMode.P2_NORM - clipper = FSDP2LoggingOnlyGradientClipper(wrapped_model=mock_model, norm_type=norm_type) + clipper = FSDP2LoggingOnlyGradientClipper(model_parts=mock_model, norm_type=norm_type) # Call clip_gradients norm = clipper.clip_gradients() @@ -195,7 +195,7 @@ def test_pipeline_parallelized_clipping_equivalent_to_single_stage_clipping(): # perform clipping on the full model (single-stage) FSDP2GradientClipper( - wrapped_model=full, + model_parts=full, max_norm=max_norm, norm_type=GradientClippingMode.P2_NORM, device_mesh=None, @@ -298,7 +298,7 @@ def __getitem__(self, name: str): # call the clipping function which will perform all_reduce across the pp group FSDP2GradientClipper( - wrapped_model=part, + model_parts=part, max_norm=max_norm, norm_type=GradientClippingMode.P2_NORM, device_mesh=mesh, diff --git a/tests/utils/test_maybe_list_parameter.py b/tests/utils/test_maybe_list_parameter.py new file mode 100644 index 000000000..74cc0fef5 --- /dev/null +++ b/tests/utils/test_maybe_list_parameter.py @@ -0,0 +1,179 @@ +from modalities.utils.maybe_list_parameter import maybe_list_parameter + + +def test_maybe_list_parameter_works_on_single_item(): + @maybe_list_parameter("x") + def square(x: int) -> int: + return x * x + + result = square(x=3) + assert result == 9 + + +def test_maybe_list_parameter_works_on_list(): + @maybe_list_parameter("x") + def square(x: int) -> int: + return x * x + + result = square(x=[1, 2, 3, 4]) + assert result == [1, 4, 9, 16] + + +def test_maybe_list_parameter_works_on_positional_args(): + @maybe_list_parameter("x") + def square(x: int) -> int: + return x * x + + result = square(5) + assert result == 25 + + result = 
square([2, 3, 4]) + assert result == [4, 9, 16] + + +def test_maybe_list_parameter_raises_on_missing_parameter(): + try: + + @maybe_list_parameter("y") + def square(x: int) -> int: + return x * x + + except ValueError as e: + assert str(e) == "Parameter 'y' not found in function 'square' signature." + else: + assert False, "Expected ValueError was not raised" + + +def test_maybe_list_parameter_works_on_mixed_args(): + @maybe_list_parameter("x") + def add_and_square(x: int, y: int) -> int: + return (x + y) * (x + y) + + result = add_and_square(2, y=3) + assert result == 25 # (2 + 3)^2 + + result = add_and_square([1, 2], y=3) + assert result == [16, 25] # (1 + 3)^2, (2 + 3)^2 + + +def test_maybe_list_parameter_works_on_mixed_args_list_positional(): + @maybe_list_parameter("x") + def add_and_square(x: int, y: int) -> int: + return (x + y) * (x + y) + + result = add_and_square(4, y=5) + assert result == 81 # (4 + 5)^2 + + result = add_and_square([2, 3], y=5) + assert result == [49, 64] # (2 + 5)^2, (3 + 5)^2 + + +def test_maybe_list_parameter_works_on_mixed_args_list_positional_no_kw(): + @maybe_list_parameter("x") + def add_and_square(x: int, y: int) -> int: + return (x + y) * (x + y) + + result = add_and_square(6, 7) + assert result == 169 # (6 + 7)^2 + + result = add_and_square([3, 4], 7) + assert result == [100, 121] # (3 + 7)^2, (4 + 7)^2 + + +def test_maybe_list_parameter_works_on_empty_list(): + @maybe_list_parameter("x") + def square(x: int) -> int: + return x * x + + result = square(x=[]) + assert result == [] + + +def test_maybe_list_parameter_works_on_no_value(): + @maybe_list_parameter("x") + def square(x: int = 10) -> int: + return x * x + + result = square() + assert result == 100 + + +def test_maybe_list_parameter_works_on_no_value_list(): + @maybe_list_parameter("x") + def square(x: int = 10) -> int: + return x * x + + result = square(x=[]) + assert result == [] + + +def test_maybe_list_parameter_works_on_no_value_with_other_args(): + @maybe_list_parameter("x") + def add_and_square(x: int = 10, y: int = 5) -> int: + return (x + y) * (x + y) + + result = add_and_square(y=3) + assert result == 169 # (10 + 3)^2 + + +def test_maybe_list_parameter_works_on_no_value_with_other_args_list(): + @maybe_list_parameter("x") + def add_and_square(x: int = 10, y: int = 5) -> int: + return (x + y) * (x + y) + + result = add_and_square(x=[], y=3) + assert result == [] + + +def test_maybe_list_parameter_works_on_multiple_parameters(): + @maybe_list_parameter("x") + @maybe_list_parameter("y") + def add(x: int, y: int) -> int: + return x + y + + result = add(x=[1, 2], y=[10, 20]) + assert result == [[11, 21], [12, 22]] + + +def test_maybe_list_parameter_works_on_multiple_parameters_mixed(): + @maybe_list_parameter("x") + @maybe_list_parameter("y") + def add(x: int, y: int) -> int: + return x + y + + result = add(x=[1, 2], y=30) + assert result == [31, 32] + + result = add(x=7, y=[70, 80]) + assert result == [77, 87] + + +def test_maybe_list_parameter_works_on_multiple_parameters_single(): + @maybe_list_parameter("x") + @maybe_list_parameter("y") + def add(x: int, y: int) -> int: + return x + y + + result = add(x=4, y=6) + assert result == 10 + + +def test_maybe_list_parameter_calls_apply_to_list_result(): + @maybe_list_parameter("x", apply_to_list_result=sum) + def square(x: int) -> int: + return x * x + + result = square(x=[1, 2, 3]) + assert result == 14 # 1^2 + 2^2 + 3^2 = 14 + + +def test_maybe_list_parameter_calls_apply_to_list_input_and_result(): + def apply_func(inputs: list[int], results: 
list[int]) -> list[int]: + return [i + r for i, r in zip(inputs, results)] + + @maybe_list_parameter("x", apply_to_list_input_and_result=apply_func) + def square(x: int) -> int: + return x * x + + result = square(x=[1, 2, 3]) + assert result == [2, 6, 12] # [1+1^2, 2+2^2, 3+3^2]
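Note: this part of the diff adds only the tests for modalities.utils.maybe_list_parameter; the decorator implementation itself is not shown here. A minimal sketch that would satisfy the tests above (an assumption for illustration, not the committed module):

import inspect
from functools import wraps
from typing import Any, Callable, Optional


def maybe_list_parameter(
    param_name: str,
    apply_to_list_result: Optional[Callable[[list], Any]] = None,
    apply_to_list_input_and_result: Optional[Callable[[list, list], Any]] = None,
) -> Callable:
    def decorator(func: Callable) -> Callable:
        sig = inspect.signature(func)
        if param_name not in sig.parameters:
            raise ValueError(f"Parameter '{param_name}' not found in function '{func.__name__}' signature.")

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            # Parameter left at its default or passed as a scalar: plain call.
            if param_name not in bound.arguments or not isinstance(bound.arguments[param_name], list):
                return func(*args, **kwargs)
            values = bound.arguments[param_name]
            results = []
            for value in values:
                bound.arguments[param_name] = value  # substitute one list element at a time
                results.append(func(*bound.args, **bound.kwargs))
            if apply_to_list_input_and_result is not None:
                return apply_to_list_input_and_result(values, results)
            if apply_to_list_result is not None:
                return apply_to_list_result(results)
            return results

        return wrapper

    return decorator

Because @wraps preserves the original signature, such a decorator can be stacked for several parameters, which is the behavior the multiple-parameter tests above exercise.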