From aae544390f446002423fe5b07409bcd556b9a2f2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 30 Jan 2026 15:55:50 +0800 Subject: [PATCH 01/43] update --- swift/megatron/init.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index d934e20292..fb2d1ff909 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -876,7 +876,7 @@ def new_save_sharded_modelopt_state(model, *args, **kwargs): checkpointing.save_sharded_modelopt_state = new_save_sharded_modelopt_state -def _patch_megatron(): +def init_megatron_env(): os.environ.pop('VLLM_USE_MODELSCOPE', None) logging_level = logging.root.level _patch_flash_attn() @@ -917,16 +917,3 @@ def _patch_megatron(): import megatron.core logger.info(f'megatron.core.__version__: {megatron.core.__version__}') - - -def init_megatron_env() -> None: - if 'MEGATRON_LM_PATH' not in os.environ: - # TODO: Synchronization issues may occur in DDP scenarios - # if the distributed environment has not been initialized. - os.environ['MEGATRON_LM_PATH'] = git_clone_github( - 'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.15.0') - with safe_ddp_context(hash_id='megatron-lm'): - if not is_megatron_available(): - subprocess_run([sys.executable, '-m', 'pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']]) - sys.path.insert(0, os.environ['MEGATRON_LM_PATH']) - _patch_megatron() From 4939ccbb1c2a89c592032c3b7d662a41c8d4add0 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 2 Feb 2026 20:02:57 +0800 Subject: [PATCH 02/43] remove megatron.training --- swift/megatron/convert.py | 3 --- swift/megatron/init.py | 24 +++++------------------ swift/megatron/model/gpt_bridge.py | 1 - swift/megatron/model/gpt_model.py | 1 - swift/megatron/model/gpts/minimax_m2.py | 1 - swift/megatron/model/gpts/qwen3_next.py | 1 - swift/megatron/model/mm_gpt_model.py | 1 - swift/megatron/model/mm_gpts/kimi_vl.py | 1 - swift/megatron/model/mm_gpts/llama4.py | 1 - swift/megatron/model/mm_gpts/qwen3_vl.py | 1 - swift/megatron/model/mm_gpts/utils.py | 1 - swift/megatron/model/model_provider.py | 6 +++--- swift/megatron/model/rope.py | 1 - swift/megatron/pipelines/export/export.py | 6 +++--- swift/megatron/trainers/base.py | 16 +++++++-------- swift/megatron/utils/utils.py | 10 +++------- 16 files changed, 22 insertions(+), 53 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 64c5c38718..a64e4aedb0 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -6,9 +6,6 @@ from dataclasses import fields import torch -from megatron.training.checkpointing import load_checkpoint -from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint -from megatron.training.initialize import initialize_megatron from transformers.utils import strtobool from swift.arguments import ExportArguments diff --git a/swift/megatron/init.py b/swift/megatron/init.py index fb2d1ff909..8913289dd7 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -20,8 +20,7 @@ from tqdm import tqdm from transformers.utils import is_torch_npu_available -from swift.utils import (get_logger, git_clone_github, is_flash_attn_3_available, is_megatron_available, - safe_ddp_context, split_list, subprocess_run) +from swift.utils import get_logger, is_flash_attn_3_available, is_last_rank, split_list logger = get_logger() @@ -384,15 +383,6 @@ def sharded_state_dict( TEGroupedLinear.sharded_state_dict = sharded_state_dict -def _patch_megatron_tokenizer(): - from 
megatron.training import global_vars - - def build_tokenizer(args): - return 'dummy_tokenizer' - - global_vars.build_tokenizer = build_tokenizer - - def _patch_mtp(): from megatron.core import InferenceParams from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer @@ -527,7 +517,6 @@ def sharded_state_dict( def _patch_TransformerLayer(): import megatron.core - from megatron.training import get_args from megatron.core.transformer import TransformerLayer _origin_forward = TransformerLayer.forward mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -673,7 +662,6 @@ def _patch_mrope(): import megatron.core from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd from megatron.core.models.common.embeddings import rope_utils - from megatron.training import get_args mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -811,7 +799,6 @@ def _new_load_inline(*args, **kwargs): def _patch_megatron_timeout(): - from megatron.training import get_args from megatron.core import parallel_state create_group_origin = parallel_state.create_group @@ -826,7 +813,7 @@ def create_group(ranks=None, timeout=None, *_args, **kwargs): def _patch_megatron_swanlab(): - from megatron.training import global_vars, is_last_rank, wandb_utils, get_args + from megatron.training import global_vars, wandb_utils, get_args def _set_wandb_writer(*_args, **kwargs): args = get_args() @@ -888,14 +875,13 @@ def init_megatron_env(): _patch_TEGroupedLinear() _patch_TransformerLayer() _patch_compile_helpers() - _patch_build_train_valid_test_datasets() + # _patch_build_train_valid_test_datasets() _patch_mrope() _patch__write_item() - _patch_megatron_tokenizer() _patch_mtp() _patch_megatron_timeout() - _patch_megatron_swanlab() - _patch_modelopt() + # _patch_megatron_swanlab() + # _patch_modelopt() logging.root.setLevel(logging_level) # revert logger level from swift.megatron import tuners # patch lora try: diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index c1954069dc..84bd2c1bef 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -9,7 +9,6 @@ import torch.nn.functional as F import transformers from megatron.core import mpu -from megatron.training import get_args from packaging import version from peft.utils import ModulesToSaveWrapper from tqdm import tqdm diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 16e6cb378e..d05ee8f860 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -23,7 +23,6 @@ from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params -from megatron.training import get_args from packaging import version from swift.utils import get_logger diff --git a/swift/megatron/model/gpts/minimax_m2.py b/swift/megatron/model/gpts/minimax_m2.py index a628cd6c46..1ec21df584 100644 --- a/swift/megatron/model/gpts/minimax_m2.py +++ b/swift/megatron/model/gpts/minimax_m2.py @@ -8,7 +8,6 @@ from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.spec_utils import build_module from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training import get_args from packaging import version from swift.model import ModelType diff --git 
a/swift/megatron/model/gpts/qwen3_next.py b/swift/megatron/model/gpts/qwen3_next.py index 36beebbf45..5bc665b49f 100644 --- a/swift/megatron/model/gpts/qwen3_next.py +++ b/swift/megatron/model/gpts/qwen3_next.py @@ -17,7 +17,6 @@ from megatron.core.transformer.transformer_block import TransformerBlockSubmodules from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import deprecate_inference_params, is_fa_min_version -from megatron.training import get_args from packaging import version from swift.megatron.utils import get_local_layer_specs diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index 86ff306491..ec0de8033b 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -9,7 +9,6 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training import get_args from packaging import version from .gpt_model import GPTModel diff --git a/swift/megatron/model/mm_gpts/kimi_vl.py b/swift/megatron/model/mm_gpts/kimi_vl.py index 0f9dc43edd..0eb9ddb1c3 100644 --- a/swift/megatron/model/mm_gpts/kimi_vl.py +++ b/swift/megatron/model/mm_gpts/kimi_vl.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import torch -from megatron.training import get_args from PIL import Image from transformers.dynamic_module_utils import get_class_from_dynamic_module diff --git a/swift/megatron/model/mm_gpts/llama4.py b/swift/megatron/model/mm_gpts/llama4.py index 8237157103..2950d9b800 100644 --- a/swift/megatron/model/mm_gpts/llama4.py +++ b/swift/megatron/model/mm_gpts/llama4.py @@ -6,7 +6,6 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.transformer_layer import get_transformer_layer_offset -from megatron.training import get_args from packaging import version from swift.model import ModelType diff --git a/swift/megatron/model/mm_gpts/qwen3_vl.py b/swift/megatron/model/mm_gpts/qwen3_vl.py index 8b150fd3f7..0607bdd11e 100644 --- a/swift/megatron/model/mm_gpts/qwen3_vl.py +++ b/swift/megatron/model/mm_gpts/qwen3_vl.py @@ -10,7 +10,6 @@ from megatron.core.models.gpt import gpt_model from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor -from megatron.training import get_args from PIL import Image from swift.model import ModelType diff --git a/swift/megatron/model/mm_gpts/utils.py b/swift/megatron/model/mm_gpts/utils.py index d9ced80c27..1e6151ad4c 100644 --- a/swift/megatron/model/mm_gpts/utils.py +++ b/swift/megatron/model/mm_gpts/utils.py @@ -4,7 +4,6 @@ import torch from megatron.core.models.huggingface import HuggingFaceModule as _HuggingFaceModule -from megatron.training import get_args from transformers import PreTrainedModel from transformers.utils import ContextManagers diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 5c32fe60b4..2c73edbbd8 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -10,9 +10,6 @@ get_gpt_mtp_block_spec) from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import get_gpt_heterogeneous_layer_spec from megatron.core.transformer.spec_utils import 
import_module -from megatron.training import get_args, print_rank_0 -from megatron.training.arguments import core_transformer_config_from_args -from megatron.training.yaml_arguments import core_transformer_config_from_yaml from packaging import version mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -74,6 +71,9 @@ def model_provider(pre_process=True, Returns: Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model """ + from megatron.training import get_args, print_rank_0 + from megatron.training.arguments import core_transformer_config_from_args + from megatron.training.yaml_arguments import core_transformer_config_from_yaml from .register import get_megatron_model_meta args = get_args() use_te = args.transformer_impl == 'transformer_engine' diff --git a/swift/megatron/model/rope.py b/swift/megatron/model/rope.py index 6e9cb5c8bf..241893b009 100644 --- a/swift/megatron/model/rope.py +++ b/swift/megatron/model/rope.py @@ -3,7 +3,6 @@ import torch import transformers -from megatron.training import get_args from packaging import version from transformers import PretrainedConfig diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index 41e0f6f699..f2c2504139 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -5,9 +5,9 @@ import torch.distributed as dist from megatron.core import mpu -from megatron.training import initialize_megatron -from megatron.training.checkpointing import load_checkpoint -from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint +# from megatron.training import initialize_megatron +# from megatron.training.checkpointing import load_checkpoint +# from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from transformers.utils import strtobool from swift.megatron.arguments import MegatronExportArguments diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 5e445f342a..3d9cd06a5b 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -25,14 +25,14 @@ from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from megatron.core.utils import StragglerDetector -from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, - get_wandb_writer, initialize, is_last_rank, one_logger_utils, pretrain, print_rank_0, - print_rank_last, training) -from megatron.training.checkpointing import check_checkpoint_args, load_checkpoint, set_checkpoint_version -from megatron.training.dist_signal_handler import DistributedSignalHandler -from megatron.training.theoretical_memory_usage import report_theoretical_memory -from megatron.training.training import num_floating_point_operations -from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model +# from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, +# get_wandb_writer, initialize, is_last_rank, one_logger_utils, pretrain, print_rank_0, +# print_rank_last, training) +# from megatron.training.checkpointing import check_checkpoint_args, load_checkpoint, set_checkpoint_version +# from megatron.training.dist_signal_handler import DistributedSignalHandler +# from megatron.training.theoretical_memory_usage import 
report_theoretical_memory +# from megatron.training.training import num_floating_point_operations +# from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model from modelscope import check_local_model_is_latest from packaging import version from tqdm.auto import tqdm diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 84bb997dd4..85bdbfcba2 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -13,7 +13,6 @@ from megatron.core.transformer.transformer_block import get_num_layers_to_build from megatron.core.transformer.transformer_layer import get_transformer_layer_offset from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default -from megatron.training import checkpointing, get_args from packaging import version from peft.tuners.lora import Linear as LoraLinear from peft.utils.other import ModulesToSaveWrapper @@ -165,9 +164,8 @@ def new_deepcopy(x, *args, **kwargs): copy.deepcopy = _origin_deepcopy -def prepare_adapter(model): +def prepare_adapter(args, model): from swift.megatron.tuners import LoraParallelLinear - args = get_args() set_linear_is_expert(model) target_modules = get_target_modules(args, model) modules_to_save = get_modules_to_save(args, model) @@ -202,8 +200,7 @@ def prepare_adapter(model): return model -def prepare_mcore_model(model): - args = get_args() +def prepare_mcore_model(args, model): if args.tuner_type == 'full': freeze_parameters(model, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) if args.trainable_parameters or args.trainable_parameters_regex: @@ -302,8 +299,7 @@ def copy_ref_adapter_weight(model, ref_adapter_name: str): sub_module[ref_adapter_name].load_state_dict(sub_module['default'].state_dict()) -def forward_step_helper(model, inputs, dtype=None): - args = get_args() +def forward_step_helper(args, model, inputs, dtype=None): if mpu.is_pipeline_first_stage(): micro_batch_size = 1 # use qkv_format 'thd' if not args.padding_free: From bcfe5123991f8fcba582f922168305250e969bec Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 2 Feb 2026 22:41:04 +0800 Subject: [PATCH 03/43] update --- swift/megatron/arguments/megatron_args.py | 39 +------- .../megatron/arguments/megatron_base_args.py | 1 - swift/megatron/convert.py | 11 +-- swift/megatron/init.py | 17 +--- swift/megatron/model/model_provider.py | 5 +- swift/megatron/model/register.py | 2 - swift/megatron/pipelines/export/export.py | 6 +- swift/megatron/trainers/base.py | 10 +- swift/megatron/utils/__init__.py | 1 + swift/megatron/utils/convert_utils.py | 5 +- swift/megatron/utils/megatron_lm_utils.py | 93 +++++++++++++++++++ swift/megatron/utils/patcher.py | 3 +- 12 files changed, 112 insertions(+), 81 deletions(-) create mode 100644 swift/megatron/utils/megatron_lm_utils.py diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 8c52c377c5..135633946c 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -612,9 +612,6 @@ class MegatronArguments(ExtraMegatronArguments): num_workers: int = 4 no_data_sharding: bool = False - # extra_args for megatron - megatron_extra_kwargs: Optional[Union[dict, str]] = None - def _set_default(self): if self.mlp_padding_free and (self.sequence_parallel or self.context_parallel_size > 1): raise ValueError('mlp_padding_free is not compatible with sequence parallel or context parallel.') 
@@ -734,6 +731,7 @@ def __post_init__(self): if self.apply_wd_to_qk_layernorm and self.hf_model_type != 'qwen3_next': raise ValueError('apply_wd_to_qk_layernorm is only supported for qwen3_next') self._set_default() + self._init_vpp_size() self.model_info, self.model_meta = get_model_info_meta( self.model, model_type=self.model_type, use_hf=self.use_hf, hub_token=self.hub_token) self.model_type = self.model_info.model_type @@ -780,7 +778,6 @@ def __post_init__(self): self._init_moe() self._init_mixed_precision() - self.megatron_extra_kwargs = json_parse_to_dict(self.megatron_extra_kwargs) self._init_no_rope_fusion() def _init_no_rope_fusion(self): @@ -793,37 +790,9 @@ def _init_no_rope_fusion(self): self.no_rope_fusion = False logger.info(f'Setting args.no_rope_fusion: {self.no_rope_fusion}.') - def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]: - new_args = [] - args_dict = asdict(self) - extra_args = {} - extra_args['model_dir'] = self.model_info.model_dir - extra_args['is_multimodal'] = self.model_meta.is_multimodal - # model_type may be overridden by megatron - extra_args['hf_model_type'] = self.model_type - megatron_extra_kwargs = args_dict.pop('megatron_extra_kwargs') - args_dict.update(megatron_extra_kwargs) - for k, value in args_dict.items(): - if k not in MegatronArguments.__annotations__ and k not in megatron_extra_kwargs: - extra_args[k] = value - continue - if value is None or value is False: - continue - new_args.append(f"--{k.replace('_', '-')}") - if isinstance(value, list): - new_args += [str(v) for v in value] - elif value is not True: - new_args.append(str(value)) - - return new_args, extra_args - - def parse_to_megatron(self): - new_args, extra_args = self._args_to_argv() - sys._old_argv = sys.argv - sys.argv = sys.argv[:1] + new_args - # parameter conflict - extra_args.pop('loss_scale', None) - return extra_args + def _init_vpp_size(self): + # TODO + self.virtual_pipeline_model_parallel_size = None def _load_adapter_config(self): assert len(self.adapters) == 1, 'Currently only support one adapter' diff --git a/swift/megatron/arguments/megatron_base_args.py b/swift/megatron/arguments/megatron_base_args.py index 549c26d55f..98e3bb7886 100644 --- a/swift/megatron/arguments/megatron_base_args.py +++ b/swift/megatron/arguments/megatron_base_args.py @@ -50,4 +50,3 @@ def init_model_args(self, tokenizer, config): if getattr(self, k) is None: setattr(self, k, v) MegatronArguments.__post_init__(self) - self.extra_args = self.parse_to_megatron() diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index a64e4aedb0..729e9583b3 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -13,7 +13,8 @@ from swift.utils import get_logger, get_n_params_grads, is_master from .arguments import MegatronArguments from .model import get_megatron_model_meta -from .utils import convert_hf_config, patch_load_base_checkpoint, patch_torch_dist_shard, test_convert_precision +from .utils import (convert_hf_config, initialize_megatron, patch_load_base_checkpoint, patch_torch_dist_shard, + test_convert_precision) logger = get_logger() @@ -58,9 +59,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype) - extra_args = megatron_args.parse_to_megatron() - extra_args_provider = megatron_model_meta.extra_args_provider - initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args) + initialize_megatron(megatron_args) mg_model = 
megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') @@ -105,9 +104,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: **current_convert_kwargs, save=args.output_dir if args.to_mcore else None, torch_dtype=args.torch_dtype) - extra_args = megatron_args.parse_to_megatron() - extra_args_provider = megatron_model_meta.extra_args_provider - initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args) + initialize_megatron(megatron_args) mg_model = megatron_model_meta.model_provider() if megatron_args.load is None: diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 8913289dd7..38bc96a606 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -798,20 +798,6 @@ def _new_load_inline(*args, **kwargs): cpp_extension.load_inline = load_inline -def _patch_megatron_timeout(): - from megatron.core import parallel_state - - create_group_origin = parallel_state.create_group - - def create_group(ranks=None, timeout=None, *_args, **kwargs): - args = get_args() - if timeout is None: - timeout = timedelta(minutes=args.distributed_timeout_minutes) - return create_group_origin(ranks, timeout, *_args, **kwargs) - - parallel_state.create_group = create_group - - def _patch_megatron_swanlab(): from megatron.training import global_vars, wandb_utils, get_args @@ -868,7 +854,7 @@ def init_megatron_env(): logging_level = logging.root.level _patch_flash_attn() _patch_transformer_engine() - _patch_unified_memory() + # _patch_unified_memory() _patch_TELinear() _patch__batched_p2p_ops() _patch_mla_attention() @@ -879,7 +865,6 @@ def init_megatron_env(): _patch_mrope() _patch__write_item() _patch_mtp() - _patch_megatron_timeout() # _patch_megatron_swanlab() # _patch_modelopt() logging.root.setLevel(logging_level) # revert logger level diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 2c73edbbd8..86ee49666a 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Optional, Union import megatron.core -import megatron.legacy import torch from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, @@ -56,9 +55,7 @@ def _get_transformer_layer_spec(use_te, config): # Code borrowed from NVIDIA/Megatron-LM -def model_provider(pre_process=True, - post_process=True, - vp_stage: Optional[int] = None) -> Union['GPTModel', megatron.legacy.model.GPTModel]: +def model_provider(pre_process=True, post_process=True, vp_stage: Optional[int] = None) -> 'GPTModel': """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. 
diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 9f996e832a..3f7df1d5ec 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -28,8 +28,6 @@ class MegatronModelMeta: visual_cls: Optional[Type[nn.Module]] = None get_mtp_block_spec: Optional[Callable] = None - extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None - def __post_init__(self): if self.megatron_model_type in MLLMMegatronModelType.__dict__: self.is_multimodal = True diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index f2c2504139..81bd317373 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -39,8 +39,7 @@ def convert_mcore2hf(self) -> None: hf_config = self.processor.model_info.config args.init_model_args(self.tokenizer, hf_config) megatron_model_meta = args.megatron_model_meta - extra_args_provider = megatron_model_meta.extra_args_provider - initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args.extra_args) + initialize_megatron(args) pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() @@ -93,8 +92,7 @@ def convert_hf2mcore(self) -> None: self.processor = template.processor args.init_model_args(self.tokenizer, self.processor.model_info.config) megatron_model_meta = args.megatron_model_meta - extra_args_provider = megatron_model_meta.extra_args_provider - initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args.extra_args) + initialize_megatron(args) pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 3d9cd06a5b..96ead60887 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1222,14 +1222,8 @@ def train(self, train_dataset, val_dataset, data_collator): datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) datasets_provider.is_distributed = True with self.patch_megatron_data_collator(data_collator), self._get_iters(train_dataset, val_dataset): - extra_args_provider = args.megatron_model_meta.extra_args_provider - pretrain( - datasets_provider, - args.megatron_model_meta.model_provider, - ModelType.encoder_or_decoder, - self.forward_step, - extra_args_provider=extra_args_provider, - args_defaults=args.extra_args) + pretrain(datasets_provider, args.megatron_model_meta.model_provider, ModelType.encoder_or_decoder, + self.forward_step) # Code borrowed from NVIDIA/Megatron-LM def build_pretraining_data_loader(self, dataset, consumed_samples, data_collator=None): diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 74eb9a8db5..7ab00db11d 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -2,6 +2,7 @@ from .config import convert_hf_config from .convert_utils import test_convert_precision +from .megatron_lm_utils import initialize_megatron from .patcher import patch_load_base_checkpoint, patch_merge_fn, patch_torch_dist_shard from .utils import (MegatronTrainerState, adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/convert_utils.py b/swift/megatron/utils/convert_utils.py index e10faabfbd..3c4d701001 100644 --- a/swift/megatron/utils/convert_utils.py +++ 
b/swift/megatron/utils/convert_utils.py @@ -8,7 +8,6 @@ import torch.distributed as dist import torch.nn as nn from megatron.core import mpu -from megatron.training import get_args from swift.utils import HfConfigFactory, get_logger, to_device, to_float_dtype from .utils import forward_step_helper, get_padding_to @@ -144,7 +143,8 @@ def get_examples(is_multimodal: bool) -> Dict[str, Any]: return data -def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float32): +def test_convert_precision(args, hf_model, mg_model, template): + torch_dtype = args.test_convert_dtype template.set_mode('train') _test_params_sum(mg_model) @@ -172,7 +172,6 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float hf_logits = hf_logits.to('cuda') hf_model.to('cpu') - args = get_args() template.use_megatron = True inputs = template.encode(get_examples(is_multimodal)) mg_inputs = to_device(template.data_collator([inputs], padding_to=get_padding_to(args)), 'cuda') diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py new file mode 100644 index 0000000000..46e253a6d1 --- /dev/null +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -0,0 +1,93 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Parts of the functions in this file are code borrowed from NVIDIA/Megatron-LM +from contextlib import contextmanager +from datetime import timedelta + +import torch +from megatron.core import mpu, tensor_parallel + +from swift.utils import get_logger, is_master, seed_everything + +logger = get_logger() + + +@contextmanager +def _patch_megatron_timeout(distributed_timeout_minutes): + from megatron.core import parallel_state + + origin_create_group = parallel_state.create_group + + def create_group(ranks=None, timeout=None, *_args, **kwargs): + if timeout is None: + timeout = timedelta(minutes=distributed_timeout_minutes) + return origin_create_group(ranks, timeout, *_args, **kwargs) + + parallel_state.create_group = create_group + try: + yield + finally: + parallel_state.create_group = origin_create_group + + +def _initialize_mpu(args): + """Initialize torch.distributed and core model parallel.""" + if not torch.distributed.is_initialized(): + raise ValueError('torch.distributed is not initialized') + args.rank = torch.distributed.get_rank() + args.world_size = torch.distributed.get_world_size() + + if mpu.model_parallel_is_initialized(): + logger.info('model parallel is already initialized') + else: + with _patch_megatron_timeout(args.distributed_timeout_minutes): + mpu.initialize_model_parallel( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.virtual_pipeline_model_parallel_size, + context_parallel_size=args.context_parallel_size, + expert_model_parallel_size=args.expert_model_parallel_size, + expert_tensor_parallel_size=args.expert_tensor_parallel_size, + distributed_timeout_minutes=args.distributed_timeout_minutes, + ) + if is_master(): + logger.info(f'tp: {args.tensor_model_parallel_size}, pp: {args.pipeline_model_parallel_size}, ' + f'vpp: {args.virtual_pipeline_model_parallel_size}, cp: {args.context_parallel_size}, ' + f'ep: {args.expert_model_parallel_size}, etp: {args.expert_tensor_parallel_size}') + + +def _set_random_seed( + seed_: int, + data_parallel_random_init: bool = True, + te_rng_tracker: bool = False, + inference_rng_tracker: bool = False, + use_cudagraphable_rng: bool = False, +): + """Set random seed for reproducability.""" + if seed_ is not None and seed_ > 0: + # 
Ensure that different pipeline MP stages get different seeds. + seed = seed_ + (1009 * mpu.get_pipeline_model_parallel_rank()) + # Ensure different data parallel ranks get different seeds + if data_parallel_random_init: + seed = seed + (11 * mpu.get_data_parallel_rank()) + seed_everything(seed) + if torch.cuda.device_count() > 0: + tensor_parallel.model_parallel_cuda_manual_seed(seed, te_rng_tracker, inference_rng_tracker, + use_cudagraphable_rng) + else: + raise ValueError('Seed ({}) should be a positive integer.'.format(seed_)) + + +def initialize_megatron(args): + # Pytorch distributed. + _initialize_mpu(args) + + # Random seeds for reproducibility. + logger.info(f'Setting random seeds to {args.seed}.') + _set_random_seed(args.seed) + + # Setup MoE aux loss scale value. + if args.num_experts is not None: + from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler + MoEAuxLossAutoScaler.set_loss_scale(torch.ones(1, device=torch.cuda.current_device())) + + # TODO: tp_comm_overlap, _compile_dependencies diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py index ceaa967a58..b6ccc5b4ae 100644 --- a/swift/megatron/utils/patcher.py +++ b/swift/megatron/utils/patcher.py @@ -4,10 +4,11 @@ import torch from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy -from megatron.training import checkpointing from swift.utils import get_logger +# from megatron.training import checkpointing + logger = get_logger() From 332e1844e3b4fa94c31add4a4f4a1b95e8a49e6a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 2 Feb 2026 23:57:59 +0800 Subject: [PATCH 04/43] update --- swift/megatron/arguments/megatron_args.py | 1 + swift/megatron/model/gpt_model.py | 18 +- swift/megatron/model/model_provider.py | 191 +++++++--------------- swift/megatron/utils/megatron_lm_utils.py | 56 +++++++ 4 files changed, 121 insertions(+), 145 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 135633946c..a51234a1c9 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -421,6 +421,7 @@ class MegatronArguments(ExtraMegatronArguments): no_masked_softmax_fusion: bool = False no_bias_dropout_fusion: Optional[bool] = None no_bias_swiglu_fusion: bool = False + no_bias_gelu_fusion: bool = False no_rope_fusion: Optional[bool] = None no_gradient_accumulation_fusion: bool = False cross_entropy_loss_fusion: bool = False diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index d05ee8f860..1794a0560f 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -62,17 +62,10 @@ def __init__( max_sequence_length: int, pre_process: bool = True, post_process: bool = True, - fp16_lm_cross_entropy: bool = False, - parallel_output: bool = True, share_embeddings_and_output_weights: bool = False, position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'none'] = 'learned_absolute', - rotary_percent: float = 1.0, rotary_base: int = 10000, hf_rope_scaling: Dict[str, Any] = None, - rope_scaling: bool = False, - rope_scaling_factor: float = 8.0, - scatter_embedding_sequence_parallel: bool = True, - seq_len_interpolation_factor: Optional[float] = None, mtp_block_spec: Optional[ModuleSpec] = None, vp_stage: Optional[int] = None, ): @@ -98,28 +91,19 @@ def __init__( max_sequence_length, pre_process=pre_process, 
post_process=post_process, - fp16_lm_cross_entropy=fp16_lm_cross_entropy, - parallel_output=parallel_output, share_embeddings_and_output_weights=share_embeddings_and_output_weights, position_embedding_type=position_embedding_type, - rotary_percent=rotary_percent, rotary_base=rotary_base, rope_scaling=rope_scaling, - rope_scaling_factor=rope_scaling_factor, - scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel, - seq_len_interpolation_factor=seq_len_interpolation_factor, mtp_block_spec=mtp_block_spec, **kwargs, ) if config.multi_latent_attention: self.rotary_pos_emb = RotaryEmbedding( kv_channels=config.qk_pos_emb_head_dim, - rotary_percent=rotary_percent, + rotary_percent=1, rotary_interleaved=config.rotary_interleaved, - seq_len_interpolation_factor=seq_len_interpolation_factor, rotary_base=rotary_base, - rope_scaling=rope_scaling, - rope_scaling_factor=rope_scaling_factor, use_cpu_initialization=config.use_cpu_initialization, ) # save memory diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 86ee49666a..7b4674a16e 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -11,51 +11,31 @@ from megatron.core.transformer.spec_utils import import_module from packaging import version +from swift.megatron.utils import core_transformer_config_from_args +from swift.utils import get_logger + +logger = get_logger() + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') if TYPE_CHECKING: from .gpt_model import GPTModel -def _get_transformer_layer_spec(use_te, config): - """Get transformer layer specification based on configuration. - - Args: - use_te (bool): Whether to use Transformer Engine - args: Training arguments - config: Model configuration - - Returns: - transformer_layer_spec: The transformer layer specification - """ - args = get_args() - if use_te: - if mcore_013: - kwargs = {'qk_l2_norm': args.qk_l2_norm, 'use_kitchen': config.use_kitchen} - else: - kwargs = {} - return get_gpt_layer_with_transformer_engine_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, - **kwargs, - ) - else: - return get_gpt_layer_local_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, - normalization=args.normalization, - use_kitchen=config.use_kitchen, - ) +def _get_transformer_layer_spec(args): + kwargs = {'qk_l2_norm': args.qk_l2_norm} if mcore_013 else {} + return get_gpt_layer_with_transformer_engine_spec( + args.num_experts, + args.moe_grouped_gemm, + args.qk_layernorm, + args.multi_latent_attention, + moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, + **kwargs, + ) # Code borrowed from NVIDIA/Megatron-LM -def model_provider(pre_process=True, post_process=True, vp_stage: Optional[int] = None) -> 'GPTModel': +def model_provider(args, pre_process=True, post_process=True, vp_stage: Optional[int] = None) -> 'GPTModel': """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. 
@@ -68,102 +48,57 @@ def model_provider(pre_process=True, post_process=True, vp_stage: Optional[int] Returns: Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model """ - from megatron.training import get_args, print_rank_0 - from megatron.training.arguments import core_transformer_config_from_args - from megatron.training.yaml_arguments import core_transformer_config_from_yaml from .register import get_megatron_model_meta - args = get_args() - use_te = args.transformer_impl == 'transformer_engine' megatron_model_meta = get_megatron_model_meta(args.hf_model_type) - if args.record_memory_history: - torch.cuda.memory._record_memory_history( - True, - # keep 100,000 alloc/free events from before the snapshot - trace_alloc_max_entries=100000, - # record stack information for the trace events - trace_alloc_record_context=True) - - def oom_observer(device, alloc, device_alloc, device_free): - # snapshot right after an OOM happened - print('saving allocated state during OOM') - snapshot = torch.cuda.memory._snapshot() - from pickle import dump - dump(snapshot, open(f'oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}', 'wb')) - - torch._C._cuda_attach_out_of_memory_observer(oom_observer) - - print_rank_0('building GPT model ...') - # Experimental loading arguments from yaml - if args.yaml_cfg is not None: - config = core_transformer_config_from_yaml(args, 'language_model') - else: - config = core_transformer_config_from_args(args) + logger.info('building GPT model ...') + config = core_transformer_config_from_args(args) config.variable_seq_lengths = True - if args.use_legacy_models: - model = megatron.legacy.model.GPTModel( - config, - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - ) - else: # using core models - if args.spec is not None: - transformer_layer_spec = import_module(args.spec) - elif megatron_model_meta.get_transformer_layer_spec is not None: - transformer_layer_spec = megatron_model_meta.get_transformer_layer_spec(config, vp_stage=vp_stage) + if megatron_model_meta.get_transformer_layer_spec is not None: + transformer_layer_spec = megatron_model_meta.get_transformer_layer_spec(config, vp_stage=vp_stage) + else: + if args.num_experts: + kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} if mcore_013 else {} + # Define the decoder block spec + transformer_layer_spec = get_gpt_decoder_block_spec( + config, use_transformer_engine=True, normalization=args.normalization, **kwargs) + else: + # Define the decoder layer spec + transformer_layer_spec = _get_transformer_layer_spec(args) + mtp_block_spec = None + if args.mtp_num_layers is not None: + if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0: + # Get the decoder layer spec explicitly if no decoder layer in the last stage, + # Only happens with block spec (TransformerBlockSubmodules) when using MoE. 
+ transformer_layer_spec_for_mtp = _get_transformer_layer_spec(args) + else: + transformer_layer_spec_for_mtp = transformer_layer_spec + kwargs = {'vp_stage': vp_stage} if mcore_013 else {} + if megatron_model_meta.get_mtp_block_spec is not None: + get_mtp_block_spec = megatron_model_meta.get_mtp_block_spec else: - if args.num_experts: - kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} if mcore_013 else {} - # Define the decoder block spec - transformer_layer_spec = get_gpt_decoder_block_spec( - config, use_transformer_engine=use_te, normalization=args.normalization, **kwargs) - elif args.heterogeneous_layers_config_path is not None: - transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te) - else: - # Define the decoder layer spec - transformer_layer_spec = _get_transformer_layer_spec(use_te, config) - mtp_block_spec = None - if args.mtp_num_layers is not None: - if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0: - # Get the decoder layer spec explicitly if no decoder layer in the last stage, - # Only happens with block spec (TransformerBlockSubmodules) when using MoE. - transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config) - else: - transformer_layer_spec_for_mtp = transformer_layer_spec - kwargs = {'vp_stage': vp_stage} if mcore_013 else {} - if megatron_model_meta.get_mtp_block_spec is not None: - get_mtp_block_spec = megatron_model_meta.get_mtp_block_spec - else: - get_mtp_block_spec = get_gpt_mtp_block_spec - mtp_block_spec = get_mtp_block_spec( - config, transformer_layer_spec_for_mtp, use_transformer_engine=use_te, **kwargs) - - if args.use_shared_expert_gate and args.num_experts and args.moe_shared_expert_intermediate_size: - for layer_spec in transformer_layer_spec.layer_specs: - if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): - layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} - model = megatron_model_meta.model_cls( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=math.ceil(args.padded_vocab_size / args.tensor_model_parallel_size) - * args.tensor_model_parallel_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - rotary_base=args.rotary_base, - hf_rope_scaling=args.rope_scaling, - rope_scaling=args.use_rope_scaling, - rope_scaling_factor=args.rope_scaling_factor, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor, - mtp_block_spec=mtp_block_spec, - vp_stage=vp_stage, - ) + get_mtp_block_spec = get_gpt_mtp_block_spec + mtp_block_spec = get_mtp_block_spec( + config, transformer_layer_spec_for_mtp, use_transformer_engine=True, **kwargs) + + if args.use_shared_expert_gate and args.num_experts and args.moe_shared_expert_intermediate_size: + for layer_spec in transformer_layer_spec.layer_specs: + if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): + layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} + model = megatron_model_meta.model_cls( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=math.ceil(args.padded_vocab_size / args.tensor_model_parallel_size) + * args.tensor_model_parallel_size, + 
max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_base=args.rotary_base, + hf_rope_scaling=args.rope_scaling, + mtp_block_spec=mtp_block_spec, + vp_stage=vp_stage, + ) return model diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 46e253a6d1..9bde2e3ca8 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -1,10 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Parts of the functions in this file are code borrowed from NVIDIA/Megatron-LM +import dataclasses from contextlib import contextmanager from datetime import timedelta import torch +import torch.nn.functional as F from megatron.core import mpu, tensor_parallel +from megatron.core.fusions.fused_bias_geglu import quick_gelu +from megatron.core.transformer import MLATransformerConfig, TransformerConfig from swift.utils import get_logger, is_master, seed_everything @@ -91,3 +95,55 @@ def initialize_megatron(args): MoEAuxLossAutoScaler.set_loss_scale(torch.ones(1, device=torch.cuda.current_device())) # TODO: tp_comm_overlap, _compile_dependencies + + +def core_transformer_config_from_args(args, config_class=None): + # Config class. + config_class = config_class or TransformerConfig + + if args.multi_latent_attention: + config_class = MLATransformerConfig + + # Translate args to core transformer configuration + kw_args = {} + for f in dataclasses.fields(config_class): + if hasattr(args, f.name): + kw_args[f.name] = getattr(args, f.name) + kw_args['persist_layer_norm'] = True + # TODO: apply_layernorm_1p + kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p + kw_args['layernorm_epsilon'] = args.norm_epsilon + kw_args['deallocate_pipeline_outputs'] = True + kw_args['pipeline_dtype'] = args.torch_dtype + kw_args['batch_p2p_comm'] = True + kw_args['num_moe_experts'] = args.num_experts + kw_args['rotary_interleaved'] = args.rotary_interleaved + kw_args['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers + kw_args['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers + kw_args['fp8_param'] = args.fp8_param_gather + if args.swiglu: + kw_args['activation_func'] = F.silu + kw_args['gated_linear_unit'] = True + kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion + else: + kw_args['bias_activation_fusion'] = args.bias_gelu_fusion + if args.quick_geglu: + assert not args.swiglu + kw_args['gated_linear_unit'] = True + kw_args['activation_func'] = quick_gelu + if args.group_query_attention: + kw_args['num_query_groups'] = args.num_query_groups + else: + kw_args['num_query_groups'] = None + if args.rope_type is None: + # Pop 'rope_type' to let the config class use the default value. + kw_args.pop('rope_type', None) + else: + assert (args.multi_latent_attention or args.rope_type + == 'rope'), (f'Common attention only support rope_type="rope", but got {args.rope_type}.') + + kw_args['cp_comm_type'] = 'p2p' + kw_args['inference_sampling_seed'] = args.seed + + # Return config. 
+ return config_class(**kw_args) From e58dde8f60893b9ad4eef956c332139412a88d84 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 11:03:30 +0800 Subject: [PATCH 05/43] update --- swift/megatron/arguments/megatron_args.py | 13 ++++++---- swift/megatron/arguments/sft_args.py | 4 +-- swift/megatron/convert.py | 8 +++--- swift/megatron/model/gpt_bridge.py | 4 +-- swift/megatron/model/gpt_model.py | 7 +++-- swift/megatron/model/model_provider.py | 1 - swift/megatron/model/register.py | 7 +++-- swift/megatron/model/rope.py | 7 +++-- swift/megatron/pipelines/export/export.py | 8 +++--- swift/megatron/trainers/base.py | 31 +++++++++++------------ swift/megatron/trainers/gkd_trainer.py | 2 +- swift/megatron/trainers/rollout_mixin.py | 2 +- swift/megatron/utils/__init__.py | 2 +- swift/megatron/utils/config.py | 6 ++--- swift/megatron/utils/megatron_lm_utils.py | 16 +++++------- 15 files changed, 58 insertions(+), 60 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index a51234a1c9..29c4f243b1 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -380,6 +380,7 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): linear_conv_kernel_dim: Optional[int] = None layer_types: Optional[List[str]] = None apply_wd_to_qk_layernorm: bool = False + apply_layernorm_1p: bool = False # qwen3_vl, qwen3_omni mrope_interleaved: Optional[bool] = None @@ -418,7 +419,7 @@ class MegatronArguments(ExtraMegatronArguments): train_iters: Optional[int] = None log_interval: int = 5 tensorboard_dir: Optional[str] = None - no_masked_softmax_fusion: bool = False + masked_softmax_fusion: bool = True no_bias_dropout_fusion: Optional[bool] = None no_bias_swiglu_fusion: bool = False no_bias_gelu_fusion: bool = False @@ -469,7 +470,7 @@ class MegatronArguments(ExtraMegatronArguments): no_load_rng: bool = False finetune: bool = False ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist' - no_initialization: bool = True + perform_initialization: bool = False auto_detect_ckpt_format: bool = True exit_on_missing_checkpoint: bool = True async_save: bool = False @@ -522,7 +523,7 @@ class MegatronArguments(ExtraMegatronArguments): activation_func_clamp_value: Optional[float] = None glu_linear_offset: Optional[float] = None untie_embeddings_and_output_weights: Optional[bool] = None - disable_bias_linear: Optional[bool] = None + add_bias_linear: Optional[bool] = None add_qkv_bias: Optional[bool] = None attention_dropout: Optional[float] = None hidden_dropout: float = 0. @@ -647,8 +648,8 @@ def _set_default(self): self.glu_linear_offset = 0. 
if self.add_qkv_bias is None: self.add_qkv_bias = True - if self.disable_bias_linear is None: - self.disable_bias_linear = True + if self.add_bias_linear is None: + self.add_bias_linear = False if self.qk_layernorm is None: self.qk_layernorm = False if self.multi_latent_attention is None: @@ -736,6 +737,8 @@ def __post_init__(self): self.model_info, self.model_meta = get_model_info_meta( self.model, model_type=self.model_type, use_hf=self.use_hf, hub_token=self.hub_token) self.model_type = self.model_info.model_type + self.model_dir = self.model_info.model_dir + self.is_multimodal = self.model_meta.is_multimodal if self.pipeline_model_parallel_size == 1 and (self.decoder_first_pipeline_num_layers is not None or self.decoder_last_pipeline_num_layers is not None): raise ValueError('pipeline_model_parallel_size must be greater than 1 if you want to set ' diff --git a/swift/megatron/arguments/sft_args.py b/swift/megatron/arguments/sft_args.py index 0fc8161fb7..181cb055f8 100644 --- a/swift/megatron/arguments/sft_args.py +++ b/swift/megatron/arguments/sft_args.py @@ -46,6 +46,6 @@ def __post_init__(self): if self.tensorboard_dir is None and self.save is not None: self.tensorboard_dir = f'{self.save}/runs' self.tensorboard_dir = to_abspath(self.tensorboard_dir) - if self.load is None and self.model is None and self.no_initialization: + if self.load is None and self.model is None and not self.perform_initialization: raise ValueError('You did not pass `--load/--model` to read weights, so you need to set ' - '`--no_initialization false` to allow the model to initialize weights properly.') + '`--perform_initialization true` to allow the model to initialize weights properly.') diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 729e9583b3..6151249245 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -61,9 +61,9 @@ def convert_hf2mcore(args: ExportArguments) -> None: torch_dtype=args.torch_dtype) initialize_megatron(megatron_args) - mg_model = megatron_model_meta.model_provider() + mg_model = megatron_model_meta.model_provider(megatron_args) logger.info('Megatron model created successfully.') - bridge = megatron_model_meta.bridge_cls() + bridge = megatron_model_meta.bridge_cls(megatron_args) bridge.load_weights(mg_model, args.model_info.model_dir) logger.info('Successfully transferred HF model weights to MG model.') _test_convert_precision = strtobool(os.getenv('SWIFT_TEST_CONVERT_PRECISION', '0')) @@ -106,7 +106,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: torch_dtype=args.torch_dtype) initialize_megatron(megatron_args) - mg_model = megatron_model_meta.model_provider() + mg_model = megatron_model_meta.model_provider(megatron_args) if megatron_args.load is None: raise ValueError('Please specify `--mcore_model`.') with patch_load_base_checkpoint(): @@ -119,7 +119,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: mg_model = peft_model.merge_and_unload() logger.info('Megatron model created successfully.') if args.to_hf: - bridge = megatron_model_meta.bridge_cls() + bridge = megatron_model_meta.bridge_cls(megatron_args) logger.info('Converting weights and saving the model...') bridge.save_weights([mg_model], args.output_dir, processor=processor, config=hf_config) if is_master(): diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 84bd2c1bef..8f1ab86b78 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -35,9 +35,9 @@ class GPTBridge: hf_score_key = 
'score.weight' hf_state_dict_mapping = {} - def __init__(self, disable_tqmd: bool = False): + def __init__(self, args, disable_tqmd: bool = False): from .register import get_megatron_model_meta - self.args = get_args() + self.args = args self.disable_tqmd = disable_tqmd or not is_last_rank() self._target_device = None self._only_last_rank = False diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 1794a0560f..9000279f50 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -94,7 +94,6 @@ def __init__( share_embeddings_and_output_weights=share_embeddings_and_output_weights, position_embedding_type=position_embedding_type, rotary_base=rotary_base, - rope_scaling=rope_scaling, mtp_block_spec=mtp_block_spec, **kwargs, ) @@ -111,9 +110,9 @@ def __init__( if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'): del self.decoder.layers[i].self_attention.rotary_pos_emb self.attention_scaling = 1. - new_inv_freq, self.attention_scaling = get_rope_inv_freq() + self.args = args = config.args + new_inv_freq, self.attention_scaling = get_rope_inv_freq(args) self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) - args = get_args() if args.task_type == 'seq_cls' and self.post_process: self.output_layer = OutputLayerLinear( config.hidden_size, @@ -343,7 +342,7 @@ def _postprocess( """ if not self.post_process: return hidden_states - args = get_args() + args = self.args labels = labels if args.task_type == 'causal_lm' else None in_inference_mode = inference_context is not None and not self.training if in_inference_mode: diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 7b4674a16e..b80fdfcd3b 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -29,7 +29,6 @@ def _get_transformer_layer_spec(args): args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, - moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, **kwargs, ) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 3f7df1d5ec..e606cbed5b 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from argparse import ArgumentParser from dataclasses import dataclass -from typing import Callable, List, Optional, Type +from typing import TYPE_CHECKING, Callable, List, Optional, Type import torch.nn as nn @@ -12,6 +12,9 @@ from .mm_gpt_model import MultimodalGPTModel from .model_provider import model_provider as model_provider_func +if TYPE_CHECKING: + from swift.megatron.arguments import MegatronArguments + MEGATRON_MODEL_MAPPING = {} @@ -24,7 +27,7 @@ class MegatronModelMeta: bridge_cls: Type[GPTBridge] = GPTBridge model_cls: Optional[Type[nn.Module]] = None get_transformer_layer_spec: Optional[Callable] = None - model_provider: Callable[[], nn.Module] = model_provider_func + model_provider: Callable[['MegatronArguments'], nn.Module] = model_provider_func visual_cls: Optional[Type[nn.Module]] = None get_mtp_block_spec: Optional[Callable] = None diff --git a/swift/megatron/model/rope.py b/swift/megatron/model/rope.py index 241893b009..9943c33b16 100644 --- a/swift/megatron/model/rope.py +++ b/swift/megatron/model/rope.py @@ -107,9 +107,8 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]): return rope_type -def get_rope_inv_freq(seq_len=None): +def get_rope_inv_freq(args, seq_len=None): from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS - args = get_args() ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS) dummy_config = _get_dummy_config(args) rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(args.rope_scaling)] @@ -127,7 +126,7 @@ def longrope_frequency_update(args, model, inv_freq, seq_len: int): original_max_position_embeddings = args.max_position_embeddings if not hasattr(model, 'long_inv_freq'): - model.long_inv_freq, _ = get_rope_inv_freq(seq_len=original_max_position_embeddings + 1) + model.long_inv_freq, _ = get_rope_inv_freq(args, seq_len=original_max_position_embeddings + 1) model.original_inv_freq = inv_freq.clone() if seq_len > original_max_position_embeddings: @@ -144,7 +143,7 @@ def dynamic_frequency_update(args, model, inv_freq, seq_len: int): model.original_inv_freq = inv_freq.clone() attention_scaling = None if seq_len > model.max_seq_len_cached: # growth - new_inv_freq, attention_scaling = get_rope_inv_freq(seq_len=seq_len) + new_inv_freq, attention_scaling = get_rope_inv_freq(args, seq_len=seq_len) inv_freq.data.copy_(new_inv_freq) model.max_seq_len_cached = seq_len diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index 81bd317373..b237250ab6 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -43,8 +43,8 @@ def convert_mcore2hf(self) -> None: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() - mg_model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process) - bridge = megatron_model_meta.bridge_cls() + mg_model = megatron_model_meta.model_provider(args, pre_process=pre_process, post_process=post_process) + bridge = megatron_model_meta.bridge_cls(args) if args.load is not None: with patch_load_base_checkpoint(): load_checkpoint([mg_model], None, None, strict=True) @@ -96,9 +96,9 @@ def convert_hf2mcore(self) -> None: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() - mg_model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process) + mg_model = megatron_model_meta.model_provider(args, pre_process=pre_process, post_process=post_process) logger.info('Megatron model created successfully.') - 
bridge = megatron_model_meta.bridge_cls() + bridge = megatron_model_meta.bridge_cls(args) if args.model is not None: bridge.load_weights(mg_model, args.model_info.model_dir) elif args.load is not None: diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 96ead60887..05e7d57e10 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -99,7 +99,7 @@ def _get_mean_metric(): @property def bridge(self): if self._bridge is None: - self._bridge = self.args.megatron_model_meta.bridge_cls() + self._bridge = self.args.megatron_model_meta.bridge_cls(self.args) return self._bridge @contextmanager @@ -109,7 +109,7 @@ def _get_iters(self, train_dataset, val_dataset): def initialize_megatron(*_args, **kwargs): res = origin_initialize_megatron(*_args, **kwargs) - args = get_args() + args = self.args data_parallel_size = mpu.get_data_parallel_world_size() step_batch_size = args.micro_batch_size * data_parallel_size num_generations = args.num_generations if args.rlhf_type == 'grpo' else 1 @@ -158,7 +158,7 @@ def new_cyclic_iter(self, iterable): yield from self._origin_cyclic_iter(iterable) return - args = get_args() + args = self.args n_epoch = 0 is_finished = False while True: @@ -254,7 +254,7 @@ def _patch_load_state_dict(self, load_base_checkpoint): checkpointing.origin__load_base_checkpoint = checkpointing._load_base_checkpoint checkpointing._load_base_checkpoint = load_base_checkpoint - args = get_args() + args = self.args origin_load_state_dict = torch.nn.Module.load_state_dict origin_no_load_optim = args.no_load_optim origin_no_load_rng = args.no_load_rng @@ -317,7 +317,7 @@ def _get_param_groups( Returns: List of parameter groups. """ - args = get_args() + args = self.args if self.args.vit_lr is not None or self.args.aligner_lr is not None: assert self.args.megatron_model_meta.is_multimodal vit_lr = self.args.vit_lr if self.args.vit_lr is not None else self.args.lr @@ -474,10 +474,8 @@ def _load_iteration(self): def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - args = get_args() - def new_model_provider_func(*_args, **kwargs): - model = model_provider_func(*_args, **kwargs) + model = model_provider_func(self.args, *_args, **kwargs) if args.load is None: self.bridge.load_weights(model, args.model_dir) self.unwrapped_models.append(model) @@ -496,6 +494,7 @@ def new_model_provider_func(*_args, **kwargs): self._init_multimodal_full() # read iteration + args = self.args if not args.finetune: args.iteration, args.num_floating_point_operations_so_far = self._load_iteration() @@ -587,7 +586,7 @@ def evaluate( eval_iters=None, ): """Evaluation.""" - args = get_args() + args = self.args timers = get_timers() timers('evaluate', log_level=0).start(barrier=True) @@ -746,7 +745,7 @@ def evaluate_and_print_results( ): """Helper function to evaluate and dump results on screen.""" - args = get_args() + args = self.args if write_to_tensorboard: writer = get_tensorboard_writer() else: @@ -841,7 +840,7 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval'], iteration= total_loss_dict.update(metrics) self._remove_log(total_loss_dict) if iteration is None: - args = get_args() + args = self.args iteration = args.curr_iteration + 1 if writer: for k, v in metrics.items(): @@ -853,7 +852,7 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval'], iteration= def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, report_memory_flag, 
skipped_iter, grad_norm, params_norm, num_zeros_in_grad): """Log training information such as losses, timing, ....""" - args = get_args() + args = self.args timers = get_timers() writer = get_tensorboard_writer() wandb_writer = get_wandb_writer() @@ -1139,7 +1138,7 @@ def copy_path(src_path: str, tgt_path: str): raise ValueError(f'Source path is neither a file nor a directory: {src_path}') def save_checkpoint(self, iteration, model, *_args, **kwargs): - args = get_args() + args = self.args output_dir = os.path.join(args.save, f'checkpoint-{iteration}') os.makedirs(output_dir, exist_ok=True) origin_save = args.save @@ -1197,7 +1196,7 @@ def _patch_megatron(self): training.save_checkpoint = self.save_checkpoint def _init_multimodal_full(self): - args = get_args() + args = self.args visual_cls = self.args.megatron_model_meta.visual_cls if args.tuner_type == 'full' and args.is_multimodal and visual_cls is not None: vision_tower = [f'visual.{vit}' for vit in getattr(visual_cls, '_vision_tower', [])] @@ -1232,7 +1231,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, data_collator if dataset is None: return None - args = get_args() + args = self.args if args.dataloader_type == 'external': # External dataloaders are passed through. User is expected to provide a # torch-compatible dataloader and define samplers, if needed. @@ -1316,7 +1315,7 @@ def _prepare_batch(self, data, vp_stage=None, num_samples=None): batch = get_batch_on_this_tp_rank(data, vp_stage=vp_stage) if num_samples is None: num_samples = batch.pop('num_samples') - args = get_args() + args = self.args text_position_ids = batch.pop('text_position_ids', None) batch.pop('attention_mask_2d', None) if text_position_ids is None: diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index 22fd1beb36..3d1c02efb8 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -202,7 +202,7 @@ def _load_teacher_model(self, teacher_model_path: str, model_type: str): teacher_models = get_model(teacher_megatron_model_meta.model_provider, model_type, wrap_with_ddp=False) # Create bridge for teacher model (for weight loading) - teacher_bridge = teacher_megatron_model_meta.bridge_cls() + teacher_bridge = teacher_megatron_model_meta.bridge_cls(self.args) # Load teacher weights and set to eval mode for m in teacher_models: diff --git a/swift/megatron/trainers/rollout_mixin.py b/swift/megatron/trainers/rollout_mixin.py index b56c10e26d..d6df2d495e 100644 --- a/swift/megatron/trainers/rollout_mixin.py +++ b/swift/megatron/trainers/rollout_mixin.py @@ -275,7 +275,7 @@ def _prepare_vllm_engine(self): def bridge(self): """Lazy initialization of weight bridge for Megatron-to-vLLM weight transfer.""" if self._bridge is None: - self._bridge = self.args.megatron_model_meta.bridge_cls(disable_tqmd=True) + self._bridge = self.args.megatron_model_meta.bridge_cls(self.args, disable_tqmd=True) return self._bridge @profiling_decorator diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 7ab00db11d..bd4cf9b252 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -2,7 +2,7 @@ from .config import convert_hf_config from .convert_utils import test_convert_precision -from .megatron_lm_utils import initialize_megatron +from .megatron_lm_utils import core_transformer_config_from_args, initialize_megatron from .patcher import patch_load_base_checkpoint, patch_merge_fn, patch_torch_dist_shard from .utils import 
(MegatronTrainerState, adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/config.py b/swift/megatron/utils/config.py index 24720ad3a4..ac282d95ec 100644 --- a/swift/megatron/utils/config.py +++ b/swift/megatron/utils/config.py @@ -19,7 +19,7 @@ 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], 'swiglu': ['hidden_act'], 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], - 'disable_bias_linear': ['mlp_bias'], + 'add_bias_linear': ['mlp_bias'], 'kv_channels': ['head_dim'], 'hf_model_type': ['model_type'], # moe @@ -68,7 +68,7 @@ def _convert_config(config, _internal_call=False) -> Dict[str, Any]: continue if k == 'rotary_base': megatron_config[k] = int(hf_v) - elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: + elif k in {'untie_embeddings_and_output_weights', 'moe_router_pre_softmax'}: megatron_config[k] = not hf_v elif k == 'swiglu': if hf_v == 'silu': @@ -138,7 +138,7 @@ def convert_hf_config(config) -> Dict[str, Any]: elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True elif llm_model_type == 'gpt_oss': - res['disable_bias_linear'] = False + res['add_bias_linear'] = True res['no_bias_dropout_fusion'] = True res['softmax_type'] = 'learnable' res['swiglu'] = False diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 9bde2e3ca8..71c5ab8c81 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -124,9 +124,9 @@ def core_transformer_config_from_args(args, config_class=None): if args.swiglu: kw_args['activation_func'] = F.silu kw_args['gated_linear_unit'] = True - kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion + kw_args['bias_activation_fusion'] = not args.no_bias_swiglu_fusion else: - kw_args['bias_activation_fusion'] = args.bias_gelu_fusion + kw_args['bias_activation_fusion'] = not args.no_bias_gelu_fusion if args.quick_geglu: assert not args.swiglu kw_args['gated_linear_unit'] = True @@ -135,15 +135,11 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['num_query_groups'] = args.num_query_groups else: kw_args['num_query_groups'] = None - if args.rope_type is None: - # Pop 'rope_type' to let the config class use the default value. - kw_args.pop('rope_type', None) - else: - assert (args.multi_latent_attention or args.rope_type - == 'rope'), (f'Common attention only support rope_type="rope", but got {args.rope_type}.') kw_args['cp_comm_type'] = 'p2p' kw_args['inference_sampling_seed'] = args.seed - # Return config. 
- return config_class(**kw_args) + config = config_class(**kw_args) + config.args = args + + return config From 7084f77dc762c920be09089c1053fe6b2e0fb16e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 11:04:05 +0800 Subject: [PATCH 06/43] update --- docs/source/Megatron-SWIFT/Command-line-parameters.md | 6 +++--- docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 6 +++--- examples/ascend/train/qwen3/qwen3_lora_megatron.sh | 2 +- examples/ascend/train/qwen3_next/qwen3_next_megatron.sh | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 1aa370fe4d..c79075f645 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -23,7 +23,7 @@ - 🔥max_epochs: 指定训练的epochs数。当使用非流式数据集时,该参数会为你自动计算train_iters而不需要手动传入`train_iters`。当使用流式数据集时,该参数会在训练到`max_epochs`时强制退出训练,并对权重进行验证和保存。默认为None。 - 🔥log_interval: log的时间间隔(单位:iters),默认为5。 - tensorboard_dir: tensorboard日志写入的目录。默认None,即存储在`f'{save}/runs'`目录下。 -- no_masked_softmax_fusion: 默认为False。用于禁用query_key_value的scaling, masking, and softmax融合。 +- masked_softmax_fusion: 默认为True。用于开启query_key_value的scaling, masking, and softmax融合。 - no_bias_dropout_fusion: 默认为False。用于禁用bias和dropout的融合。 - no_bias_swiglu_fusion: 默认为False。指定`--no_bias_dropout_fusion true`,用于禁止bias和swiglu融合。 - no_rope_fusion: 默认为False。指定`--no_rope_fusion true`用于禁止rope融合。 @@ -96,7 +96,7 @@ - 注意:**断点续训**你需要设置`--load`(lora训练需要额外设置`--adapter_load`),若设置`--finetune true`,将不加载优化器状态和随机种子状态并将迭代数设置为0,不会进行数据集跳过;若设置`--finetune false`,将读取迭代数并跳过之前训练的数据集数量,优化器状态和随机种子状态的读取通过`--no_load_optim`和`--no_load_rng`控制。 - 流式数据集`--streaming`,暂不支持跳过数据集。 - ckpt_format: checkpoint的格式。可选为'torch', 'torch_dist', 'zarr'。默认为'torch_dist'。(暂时权重转换只支持'torch_dist'格式) -- no_initialization: 不对权重进行初始化,默认为True。 +- perform_initialization: 对权重进行初始化,默认为False。 - auto_detect_ckpt_format: 自动检测ckpt format为legacy还是distributed格式。默认为True。 - exit_on_missing_checkpoint: 如果设置了`–-load`,但**找不到检查点,则直接退出**,而不是初始化。默认为True。 - 🔥async_save: 使用异步检查点保存。目前仅适用于`torch_dist`分布式检查点格式。默认为False。 @@ -183,7 +183,7 @@ - activation_func_clamp_value: 限制激活函数中 linear_fc1 的输出值范围。仅在 `activation_func` 为 `quick_gelu` 时使用。默认为None。 - glu_linear_offset: GLU 激活函数中的偏移项:`activation_func(x[0]) * (x[1] + offset)`。仅在 gated_linear_unit 为 True 时使用。默认为0.。 - untie_embeddings_and_output_weights: 解开embedding和输出权重的绑定,默认为True。 -- disable_bias_linear: 禁用linear层的bias。默认为True。 +- add_bias_linear: 开启linear层的bias。默认为True。 - add_qkv_bias: 仅在QKV的linear中增加bias,默认为True。 - attention_dropout: 默认为0.。 - hidden_dropout: 默认为0.。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index ff96ae02b1..2ff8fcb488 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -24,7 +24,7 @@ - 🔥max_epochs: Specifies the number of training epochs. When using a non-streaming dataset, this parameter automatically calculates `train_iters`, eliminating the need to manually provide `train_iters`. When using a streaming dataset, training will be forcibly terminated upon reaching `max_epochs`, and the model weights will be validated and saved. Default is None. - 🔥log_interval: Log interval (unit: iters), default is 5. - tensorboard_dir: Directory where TensorBoard logs are written. Default is None, meaning logs will be stored in the `f'{save}/runs'` directory. 
-- no_masked_softmax_fusion: Default is False. Disables scaling, masking, and softmax fusion for query_key_value. +- masked_softmax_fusion: Defaults to True. Used to enable the fusion of scaling, masking, and softmax for query_key_value. - no_bias_dropout_fusion: Default is False. Disables bias and dropout fusion. - no_bias_swiglu_fusion: Default is False. Specify `--no_bias_dropout_fusion true` to disable bias and swiglu fusion. - no_rope_fusion: Default is False. Specify `--no_rope_fusion true` to disable rope fusion. @@ -100,7 +100,7 @@ - Note: For resuming training from a checkpoint, you should set `--load` (and additionally `--adapter_load` for LoRA training). If `--finetune true` is set, the optimizer and RNG states will not be loaded, the iteration count will be reset to 0, and no dataset skipping will occur. If `--finetune false` is set, the iteration count will be restored, and the corresponding number of previously trained samples will be skipped in the dataset. Loading of the optimizer and RNG states is controlled by `--no_load_optim` and `--no_load_rng`, respectively. - Streaming datasets (`--streaming`) are currently not supported for skipping datasets. - ckpt_format: Format of the checkpoint. Options are 'torch', 'torch_dist', 'zarr'. Default is 'torch_dist'. (Currently, weight conversion only supports the 'torch_dist' format.) -- no_initialization: Do not initialize weights, default is True. +- perform_initialization: Perform weight initialization. Default is False. - auto_detect_ckpt_format: Automatically detect whether the checkpoint format is legacy or distributed. Default is True. - exit_on_missing_checkpoint: If `--load` is set but **no checkpoint is found, exit directly** instead of initializing. Default is True. - 🔥async_save: Use asynchronous checkpoint saving. Currently only applicable to the `torch_dist` distributed checkpoint format. Defaults to False. @@ -194,7 +194,7 @@ For guidance on selecting parallelization strategies, please refer to the [Train - activation_func_clamp_value: Clamp the output value range of linear_fc1 in the activation function. Only used when `activation_func` is `quick_gelu`. Default is None. - glu_linear_offset: Offset term in the GLU activation function: `activation_func(x[0]) * (x[1] + offset)`. Only used when gated_linear_unit is True. Default is 0. - untie_embeddings_and_output_weights: Unties embedding and output weights. Default is True. -- disable_bias_linear: Disables bias in linear layers. Default is True. +- add_bias_linear: Enable bias in linear layers. Default is True - add_qkv_bias: Adds bias only to QKV linear layers. Default is True. - attention_dropout: Default is 0. - hidden_dropout: Default is 0. 
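[Note] To make the flag renames above concrete: the parameter bullets and the Ascend example scripts below switch from negated switches (`--no_masked_softmax_fusion true`) to positive booleans (`--masked_softmax_fusion false`). The following standalone Python sketch is illustrative only — the `OLD_TO_NEW_FLAGS` table and `migrate_flags` helper are hypothetical and not part of ms-swift — but it captures the value inversion performed by this patch for a few of the renamed flags.

```python
# Hypothetical migration helper (not part of ms-swift): each old negated flag
# removed in this patch maps to a new positive flag with the inverted value.
OLD_TO_NEW_FLAGS = {
    'no_masked_softmax_fusion': 'masked_softmax_fusion',
    'no_gradient_accumulation_fusion': 'gradient_accumulation_fusion',
    'disable_bias_linear': 'add_bias_linear',
    'no_initialization': 'perform_initialization',
}


def migrate_flags(old_args: dict) -> dict:
    """Invert old negated boolean flags into their positive counterparts."""
    new_args = {}
    for key, value in old_args.items():
        if key in OLD_TO_NEW_FLAGS:
            new_args[OLD_TO_NEW_FLAGS[key]] = not value
        else:
            new_args[key] = value
    return new_args


# The Ascend scripts below make exactly this inversion on the command line:
# `--no_masked_softmax_fusion true` becomes `--masked_softmax_fusion false`.
print(migrate_flags({'no_masked_softmax_fusion': True, 'lr': 1e-4}))
# -> {'masked_softmax_fusion': False, 'lr': 0.0001}
```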
diff --git a/examples/ascend/train/qwen3/qwen3_lora_megatron.sh b/examples/ascend/train/qwen3/qwen3_lora_megatron.sh index 5e8f9e20f4..bf8a8865bc 100755 --- a/examples/ascend/train/qwen3/qwen3_lora_megatron.sh +++ b/examples/ascend/train/qwen3/qwen3_lora_megatron.sh @@ -33,6 +33,6 @@ megatron sft \ --no_save_rng true \ --dataset_num_proc 4 \ --no_gradient_accumulation_fusion true \ - --no_masked_softmax_fusion true \ + --masked_softmax_fusion false \ --model_author swift \ --model_name swift-robot diff --git a/examples/ascend/train/qwen3_next/qwen3_next_megatron.sh b/examples/ascend/train/qwen3_next/qwen3_next_megatron.sh index 88648319d2..8f5a0c076c 100644 --- a/examples/ascend/train/qwen3_next/qwen3_next_megatron.sh +++ b/examples/ascend/train/qwen3_next/qwen3_next_megatron.sh @@ -35,6 +35,6 @@ megatron sft \ --no_save_rng true \ --dataset_num_proc 4 \ --no_gradient_accumulation_fusion true \ - --no_masked_softmax_fusion true \ + --masked_softmax_fusion false \ --model_author swift \ --model_name swift-robot From f383eb2e8b98bf55cc292d95a627f3e62f99df57 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 11:16:56 +0800 Subject: [PATCH 07/43] update --- .../Instruction/Frequently-asked-questions.md | 2 +- docs/source/Megatron-SWIFT/Ascend.md | 8 ++++---- .../Megatron-SWIFT/Command-line-parameters.md | 6 +++--- docs/source/Megatron-SWIFT/Quick-start.md | 2 +- .../Instruction/Frequently-asked-questions.md | 2 +- docs/source_en/Megatron-SWIFT/Ascend.md | 8 ++++---- .../Megatron-SWIFT/Command-line-parameters.md | 6 +++--- docs/source_en/Megatron-SWIFT/Quick-start.md | 2 +- examples/ascend/multi-node/megatron/node1.sh | 2 +- examples/ascend/multi-node/megatron/node2.sh | 2 +- .../ascend/train/qwen3/qwen3_lora_megatron.sh | 2 +- .../train/qwen3_next/qwen3_next_megatron.sh | 2 +- swift/megatron/arguments/megatron_args.py | 16 ++++++++-------- swift/megatron/utils/config.py | 2 +- swift/megatron/utils/megatron_lm_utils.py | 4 ++-- 15 files changed, 33 insertions(+), 33 deletions(-) diff --git a/docs/source/Instruction/Frequently-asked-questions.md b/docs/source/Instruction/Frequently-asked-questions.md index f9113b7737..a13d500fe7 100644 --- a/docs/source/Instruction/Frequently-asked-questions.md +++ b/docs/source/Instruction/Frequently-asked-questions.md @@ -592,7 +592,7 @@ megatron sft \ ```text RuntimeError: ColumnParallelLinear was called with gradient_accumulation_fusion set to True but the custom CUDA extension fused_weight_gradient_mlp_cuda module is not found. To use gradient_accumulation_fusion you must install APEX with --cpp_ext and --cuda_ext. For example: pip install --global-option="--cpp_ext" --global-option="--cuda_ext ." Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion. ``` -设置一下`--no_gradient_accumulation_fusion true`。 +设置一下`--gradient_accumulation_fusion false`。 ### Q163: moe的lora训练,target_modules参数设置了all-linear,是包括了路由器模块吗? 
看gate是否是nn.Linear实现,如果是nn.Parameter就不训练,详见命令行参数[target_parameters](https://swift.readthedocs.io/zh-cn/latest/Instruction/Command-line-parameters.html#tuner)。 diff --git a/docs/source/Megatron-SWIFT/Ascend.md b/docs/source/Megatron-SWIFT/Ascend.md index 774af28f52..b1136ce32c 100644 --- a/docs/source/Megatron-SWIFT/Ascend.md +++ b/docs/source/Megatron-SWIFT/Ascend.md @@ -182,7 +182,7 @@ def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_par ### 使能 -另外,由于msprobe不支持融合计算,需要在启动脚本添加`--no_bias_dropout_fusion True`、`--no_bias_swiglu_fusion True`、`--cross_entropy_loss_fusion False` +另外,由于msprobe不支持融合计算,需要在启动脚本添加`--bias_dropout_fusion false`、`--bias_swiglu_fusion false`、`--cross_entropy_loss_fusion false` #### 示例 ```shell @@ -196,7 +196,7 @@ megatron sft \ 'swift/self-cognition#500' \ --tensor_model_parallel_size 2 \ ... - --no_bias_dropout_fusion True \ - --no_bias_swiglu_fusion True \ - --cross_entropy_loss_fusion False + --bias_dropout_fusion false \ + --bias_swiglu_fusion false \ + --cross_entropy_loss_fusion false ``` diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index c79075f645..9629958c41 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -24,11 +24,11 @@ - 🔥log_interval: log的时间间隔(单位:iters),默认为5。 - tensorboard_dir: tensorboard日志写入的目录。默认None,即存储在`f'{save}/runs'`目录下。 - masked_softmax_fusion: 默认为True。用于开启query_key_value的scaling, masking, and softmax融合。 -- no_bias_dropout_fusion: 默认为False。用于禁用bias和dropout的融合。 -- no_bias_swiglu_fusion: 默认为False。指定`--no_bias_dropout_fusion true`,用于禁止bias和swiglu融合。 +- bias_dropout_fusion: 默认为True。用于开启bias和dropout的融合。 +- bias_swiglu_fusion: 默认为True。用于开启bias和swiglu融合。 - no_rope_fusion: 默认为False。指定`--no_rope_fusion true`用于禁止rope融合。 - **当使用mrope等不支持rope_fusion的位置编码时,该参数会自动设置为True**。 -- no_gradient_accumulation_fusion: 默认为False。指定`--no_gradient_accumulation_fusion true`用于禁用梯度累加融合。 +- gradient_accumulation_fusion: 默认为True。用于开启梯度累加融合。 - 🔥cross_entropy_loss_fusion: 启动交叉熵损失计算融合。默认为False。 - cross_entropy_fusion_impl: 交叉熵损失融合的实现。可选为'native'和'te'。默认为'native'。 - calculate_per_token_loss: 根据全局批次中的非填充token数量来对交叉熵损失进行缩放。默认为True。 diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md index 9ac9200d42..7566f4ea6d 100644 --- a/docs/source/Megatron-SWIFT/Quick-start.md +++ b/docs/source/Megatron-SWIFT/Quick-start.md @@ -29,7 +29,7 @@ pip install pybind11 pip install --no-build-isolation transformer_engine[pytorch] # apex -# 提示:Megatron-SWIFT可以在不含apex的环境下运行,额外设置`--no_gradient_accumulation_fusion true`即可。 +# 提示:Megatron-SWIFT可以在不含apex的环境下运行,额外设置`--gradient_accumulation_fusion false`即可。 git clone https://github.com/NVIDIA/apex cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ diff --git a/docs/source_en/Instruction/Frequently-asked-questions.md b/docs/source_en/Instruction/Frequently-asked-questions.md index 8a578a522f..64be698fdd 100644 --- a/docs/source_en/Instruction/Frequently-asked-questions.md +++ b/docs/source_en/Instruction/Frequently-asked-questions.md @@ -592,7 +592,7 @@ Saving checkpoints per epoch is not yet supported. ```text RuntimeError: ColumnParallelLinear was called with gradient_accumulation_fusion set to True but the custom CUDA extension fused_weight_gradient_mlp_cuda module is not found. 
To use gradient_accumulation_fusion you must install APEX with --cpp_ext and --cuda_ext. For example: pip install --global-option="--cpp_ext" --global-option="--cuda_ext ." Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion. ``` -Set `--no_gradient_accumulation_fusion true`. +Set `--gradient_accumulation_fusion false`. ### Q163: For MoE LoRA training, if target_modules is set to all-linear, does this include the router modules? It depends on whether the gate is implemented as nn.Linear. If it's an nn.Parameter, it won't be trained. For details, see the command-line parameter [target_parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html#tuner-arguments). diff --git a/docs/source_en/Megatron-SWIFT/Ascend.md b/docs/source_en/Megatron-SWIFT/Ascend.md index 419c1ee5df..b2454cf229 100644 --- a/docs/source_en/Megatron-SWIFT/Ascend.md +++ b/docs/source_en/Megatron-SWIFT/Ascend.md @@ -186,7 +186,7 @@ def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_par ### Enable -Additionally, since msprobe does not support fusion computation, you need to add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, `--cross_entropy_loss_fusion False` to the launch script. +Additionally, since msprobe does not support fusion computation, you need to add `--bias_dropout_fusion false`, `--bias_swiglu_fusion false`, `--cross_entropy_loss_fusion false` to the launch script. #### Example ```shell @@ -200,7 +200,7 @@ megatron sft \ 'swift/self-cognition#500' \ --tensor_model_parallel_size 2 \ ... - --no_bias_dropout_fusion True \ - --no_bias_swiglu_fusion True \ - --cross_entropy_loss_fusion False + --bias_dropout_fusion false \ + --bias_swiglu_fusion false \ + --cross_entropy_loss_fusion false ``` diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 2ff8fcb488..7a0ea12033 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -25,11 +25,11 @@ - 🔥log_interval: Log interval (unit: iters), default is 5. - tensorboard_dir: Directory where TensorBoard logs are written. Default is None, meaning logs will be stored in the `f'{save}/runs'` directory. - masked_softmax_fusion: Defaults to True. Used to enable the fusion of scaling, masking, and softmax for query_key_value. -- no_bias_dropout_fusion: Default is False. Disables bias and dropout fusion. -- no_bias_swiglu_fusion: Default is False. Specify `--no_bias_dropout_fusion true` to disable bias and swiglu fusion. +- bias_dropout_fusion: Defaults to True. Used to enable the fusion of bias and dropout. +- bias_swiglu_fusion: Defaults to True. Used to enable the fusion of bias and swiglu. - no_rope_fusion: Default is False. Specify `--no_rope_fusion true` to disable rope fusion. - **When using position embedding such as mrope that do not support RoPE fusion, this parameter will be automatically set to True**. -- no_gradient_accumulation_fusion: Default is False. Specify `--no_gradient_accumulation_fusion true` to disable gradient accumulation fusion. +- gradient_accumulation_fusion: Defaults to True. Used to enable gradient accumulation fusion. - 🔥cross_entropy_loss_fusion: Enables cross-entropy loss calculation fusion. Default is False. - cross_entropy_fusion_impl: Implementation of cross-entropy loss fusion. Options include 'native' and 'te'. Defaults to 'native'. 
- calculate_per_token_loss: Scales the cross-entropy loss according to the number of non-padded tokens in the global batch. Default is True. diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index b23c5a4c83..197f3e469f 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -28,7 +28,7 @@ pip install pybind11 pip install --no-build-isolation transformer_engine[pytorch] # apex -# Note: Megatron-SWIFT can run in environments without apex by setting `--no_gradient_accumulation_fusion true`. +# Note: Megatron-SWIFT can run in environments without apex by setting `--gradient_accumulation_fusion false`. git clone https://github.com/NVIDIA/apex cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ diff --git a/examples/ascend/multi-node/megatron/node1.sh b/examples/ascend/multi-node/megatron/node1.sh index f0722b145c..3006b14271 100644 --- a/examples/ascend/multi-node/megatron/node1.sh +++ b/examples/ascend/multi-node/megatron/node1.sh @@ -24,7 +24,7 @@ megatron sft \ --recompute_granularity selective \ --recompute_modules core_attn \ --cross_entropy_loss_fusion true \ - --no_gradient_accumulation_fusion true \ + --gradient_accumulation_fusion false \ --lr 1e-4 \ --lr_warmup_fraction 0.05 \ --min_lr 1e-5 \ diff --git a/examples/ascend/multi-node/megatron/node2.sh b/examples/ascend/multi-node/megatron/node2.sh index e5cd8f9307..218581b102 100644 --- a/examples/ascend/multi-node/megatron/node2.sh +++ b/examples/ascend/multi-node/megatron/node2.sh @@ -24,7 +24,7 @@ megatron sft \ --recompute_granularity selective \ --recompute_modules core_attn \ --cross_entropy_loss_fusion true \ - --no_gradient_accumulation_fusion true \ + --gradient_accumulation_fusion false \ --lr 1e-4 \ --lr_warmup_fraction 0.05 \ --min_lr 1e-5 \ diff --git a/examples/ascend/train/qwen3/qwen3_lora_megatron.sh b/examples/ascend/train/qwen3/qwen3_lora_megatron.sh index bf8a8865bc..19f7b99d1f 100755 --- a/examples/ascend/train/qwen3/qwen3_lora_megatron.sh +++ b/examples/ascend/train/qwen3/qwen3_lora_megatron.sh @@ -32,7 +32,7 @@ megatron sft \ --no_save_optim true \ --no_save_rng true \ --dataset_num_proc 4 \ - --no_gradient_accumulation_fusion true \ + --gradient_accumulation_fusion false \ --masked_softmax_fusion false \ --model_author swift \ --model_name swift-robot diff --git a/examples/ascend/train/qwen3_next/qwen3_next_megatron.sh b/examples/ascend/train/qwen3_next/qwen3_next_megatron.sh index 8f5a0c076c..13e2fc766d 100644 --- a/examples/ascend/train/qwen3_next/qwen3_next_megatron.sh +++ b/examples/ascend/train/qwen3_next/qwen3_next_megatron.sh @@ -34,7 +34,7 @@ megatron sft \ --no_save_optim true \ --no_save_rng true \ --dataset_num_proc 4 \ - --no_gradient_accumulation_fusion true \ + --gradient_accumulation_fusion false \ --masked_softmax_fusion false \ --model_author swift \ --model_name swift-robot diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 29c4f243b1..f71c378203 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -420,11 +420,11 @@ class MegatronArguments(ExtraMegatronArguments): log_interval: int = 5 tensorboard_dir: Optional[str] = None masked_softmax_fusion: bool = True - no_bias_dropout_fusion: Optional[bool] = None - no_bias_swiglu_fusion: bool = False - 
no_bias_gelu_fusion: bool = False + bias_dropout_fusion: Optional[bool] = None + bias_swiglu_fusion: bool = True + bias_gelu_fusion: bool = True no_rope_fusion: Optional[bool] = None - no_gradient_accumulation_fusion: bool = False + gradient_accumulation_fusion: bool = True cross_entropy_loss_fusion: bool = False cross_entropy_fusion_impl: Literal['native', 'te'] = 'native' calculate_per_token_loss: Optional[bool] = None @@ -666,8 +666,8 @@ def _set_default(self): self.task_type = 'causal_lm' if self.calculate_per_token_loss is None: self.calculate_per_token_loss = self.task_type == 'causal_lm' - if self.no_bias_dropout_fusion is None: - self.no_bias_dropout_fusion = False + if self.bias_dropout_fusion is None: + self.bias_dropout_fusion = True # moe MegatronArguments._set_moe_default(self) # log @@ -757,12 +757,12 @@ def __post_init__(self): if self.save_strategy == 'epoch': self.save_interval = 1 self.eval_interval = 1 - if not self.no_gradient_accumulation_fusion: + if self.gradient_accumulation_fusion: try: import apex except ImportError: logger.warning('apex is not installed, so gradient accumulation fusion is disabled.') - self.no_gradient_accumulation_fusion = True + self.gradient_accumulation_fusion = False if isinstance(self.ref_adapters, str): self.ref_adapters = [self.ref_adapters] if self.eval_interval is None: diff --git a/swift/megatron/utils/config.py b/swift/megatron/utils/config.py index ac282d95ec..4cb902e478 100644 --- a/swift/megatron/utils/config.py +++ b/swift/megatron/utils/config.py @@ -139,7 +139,7 @@ def convert_hf_config(config) -> Dict[str, Any]: res['rotary_interleaved'] = True elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True - res['no_bias_dropout_fusion'] = True + res['bias_dropout_fusion'] = False res['softmax_type'] = 'learnable' res['swiglu'] = False res['quick_geglu'] = True diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 71c5ab8c81..685c03a4e3 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -124,9 +124,9 @@ def core_transformer_config_from_args(args, config_class=None): if args.swiglu: kw_args['activation_func'] = F.silu kw_args['gated_linear_unit'] = True - kw_args['bias_activation_fusion'] = not args.no_bias_swiglu_fusion + kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion else: - kw_args['bias_activation_fusion'] = not args.no_bias_gelu_fusion + kw_args['bias_activation_fusion'] = args.bias_gelu_fusion if args.quick_geglu: assert not args.swiglu kw_args['gated_linear_unit'] = True From 92ab184547349f65527101debd2d9e579fde3b2e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 11:25:51 +0800 Subject: [PATCH 08/43] update --- .../Megatron-SWIFT/Command-line-parameters.md | 8 ++++---- .../Megatron-SWIFT/Command-line-parameters.md | 8 ++++---- swift/megatron/arguments/megatron_args.py | 19 ++++++++++--------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 9629958c41..669cb960e4 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -26,8 +26,8 @@ - masked_softmax_fusion: 默认为True。用于开启query_key_value的scaling, masking, and softmax融合。 - bias_dropout_fusion: 默认为True。用于开启bias和dropout的融合。 - bias_swiglu_fusion: 默认为True。用于开启bias和swiglu融合。 -- no_rope_fusion: 默认为False。指定`--no_rope_fusion true`用于禁止rope融合。 - - 
**当使用mrope等不支持rope_fusion的位置编码时,该参数会自动设置为True**。 +- apply_rope_fusion: 默认为True。用于开启rope融合。 + - **当使用mrope等不支持rope_fusion的位置编码时,该参数会自动设置为False**。 - gradient_accumulation_fusion: 默认为True。用于开启梯度累加融合。 - 🔥cross_entropy_loss_fusion: 启动交叉熵损失计算融合。默认为False。 - cross_entropy_fusion_impl: 交叉熵损失融合的实现。可选为'native'和'te'。默认为'native'。 @@ -53,7 +53,7 @@ - seed: python、numpy、pytorch和cuda的随机种子,默认为42。 - 🔥num_workers: dataloader的workers数量,默认为4。 - 注意:若设置`--streaming true`,则设置为1。 -- no_data_sharding: 当`--train_dataloader_shuffle true`时对 train_dataloader 生效,默认为False。该参数控制数据集随机的范围。若设置为False,则先对数据集进行分片,然后对每个分片进行随机处理(略节约内存);若设置为True,则先对数据集进行随机,再进行分片(更好的随机效果)。使用该参数需"ms-swift>=3.12"。 +- data_sharding: 当`--train_dataloader_shuffle true`时对 train_dataloader 生效,默认为False。该参数控制数据集随机的范围。若设置为True,则先对数据集进行分片,然后对每个分片进行随机处理(略节约内存);若设置为False,则先对数据集进行随机,再进行分片(更好的随机效果)。 - seq_length: 默认为None,即设置为`max_length`。对数据集长度进行限制建议使用“基本参数”中的`--max_length`控制,无需设置此参数。 - use_cpu_initialization: 在cpu上初始化权重,默认为False。在进行HF和MCore权重转换时会被使用。通常不需要修改该值。 - 🔥megatron_extra_kwargs: 额外需要透传入megatron的其他参数,使用json传递。默认为None。 @@ -135,7 +135,7 @@ - tensorboard_log_interval: 记录到tensorboard的间隔(steps),默认为1。 - tensorboard_queue_size: 用于暂存事件和摘要的 TensorBoard 队列大小;当队列中待处理的事件和摘要数量达到该大小时,下一次调用 "add" 相关方法会触发将数据刷新写入磁盘。默认为50。 - log_timers_to_tensorboard: 记录timers到tensorboard。默认为True。 -- no_log_learning_rate_to_tensorboard: 不记录学习率到tensorboard。默认为False。 +- log_learning_rate_to_tensorboard: 记录学习率到tensorboard。默认为True。 - log_validation_ppl_to_tensorboard: 将验证困惑度写入tensorboard。默认为True。 - log_memory_to_tensorboard: 将内存日志写入tensorboard。默认为True。 - logging_level: 日志级别。默认为None。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 7a0ea12033..e0899ad864 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -27,8 +27,8 @@ - masked_softmax_fusion: Defaults to True. Used to enable the fusion of scaling, masking, and softmax for query_key_value. - bias_dropout_fusion: Defaults to True. Used to enable the fusion of bias and dropout. - bias_swiglu_fusion: Defaults to True. Used to enable the fusion of bias and swiglu. -- no_rope_fusion: Default is False. Specify `--no_rope_fusion true` to disable rope fusion. - - **When using position embedding such as mrope that do not support RoPE fusion, this parameter will be automatically set to True**. +- apply_rope_fusion: Defaults to True. Used to enable RoPE (Rotary Position Embedding) fusion. + - **When using position embedding such as mrope that do not support RoPE fusion, this parameter will be automatically set to False**. - gradient_accumulation_fusion: Defaults to True. Used to enable gradient accumulation fusion. - 🔥cross_entropy_loss_fusion: Enables cross-entropy loss calculation fusion. Default is False. - cross_entropy_fusion_impl: Implementation of cross-entropy loss fusion. Options include 'native' and 'te'. Defaults to 'native'. @@ -54,7 +54,7 @@ - seed: Random seed for python, numpy, pytorch, and cuda, default is 42. - 🔥num_workers: Number of workers for the dataloader, default is 4. - Note: If `--streaming true` is set, it will be set to 1. -- no_data_sharding: Takes effect on train_dataloader when `--train_dataloader_shuffle true` is set. Defaults to False. This parameter controls the scope of dataset shuffling. 
If set to False, the dataset is first sharded, then each shard is shuffled independently (slightly more memory efficient); if set to True, the dataset is shuffled globally first, then sharded (better randomization). Requires "ms-swift>=3.12". +- data_sharding: Takes effect on train_dataloader when `--train_dataloader_shuffle true` is set. Defaults to False. This parameter controls the scope of dataset shuffling. If set to True, the dataset is first sharded, then each shard is shuffled independently (slightly more memory efficient); if set to False, the dataset is shuffled globally first, then sharded (better randomization). - seq_length: Defaults to `None`, which means it will be set to `max_length`. To limit the sequence length of the dataset, it is recommended to use the `--max_length` argument under "Basic Parameters" instead; this parameter does not need to be set explicitly. - use_cpu_initialization: Initialize weights on the CPU. Defaults to `False`. This option is used during weight conversion between Hugging Face (HF) and MCore formats. The value typically does not need to be modified. - 🔥megatron_extra_kwargs: Additional arguments to be passed through directly to Megatron, provided as a JSON string. Defaults to `None`. @@ -141,7 +141,7 @@ For guidance on selecting parallelization strategies, please refer to the [Train - tensorboard_log_interval: Interval (steps) for logging to TensorBoard, default is 1. - tensorboard_queue_size: Size of the TensorBoard queue for buffering pending events and summaries. When the number of pending items reaches this value, the next call to an "add" method will trigger a flush to disk. The default is 50. - log_timers_to_tensorboard: Logs timers to TensorBoard. Default is True. -- no_log_learning_rate_to_tensorboard: Do not log learning rate to TensorBoard. Default is False. +- log_learning_rate_to_tensorboard: Log learning rate to TensorBoard. Default is True. - log_validation_ppl_to_tensorboard: Writes validation perplexity to TensorBoard. Default is True. - log_memory_to_tensorboard: Writes memory logs to TensorBoard. Default is True. - logging_level: Logging level. Default is None. 
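[Note] The `data_sharding` bullet above distinguishes shard-then-shuffle (`data_sharding=True`, slightly more memory friendly) from shuffle-then-shard (`data_sharding=False`, better global randomization). The sketch below is a minimal, self-contained illustration of that ordering using plain Python lists; it is not the ms-swift dataloader implementation, and the helper names are assumptions made for this example.

```python
import random


def split_shards(samples, num_shards):
    """Round-robin split of the sample list into per-rank shards (illustrative)."""
    return [samples[rank::num_shards] for rank in range(num_shards)]


def shard_then_shuffle(samples, num_shards, seed=42):
    # data_sharding=True: shard first, then shuffle each shard independently,
    # so randomness stays confined to one shard.
    shards = split_shards(list(samples), num_shards)
    for shard in shards:
        random.Random(seed).shuffle(shard)
    return shards


def shuffle_then_shard(samples, num_shards, seed=42):
    # data_sharding=False: shuffle the whole dataset first, then shard it,
    # so every rank can receive samples from anywhere in the dataset.
    samples = list(samples)
    random.Random(seed).shuffle(samples)
    return split_shards(samples, num_shards)


print(shard_then_shuffle(range(8), 2))  # per-shard shuffling only
print(shuffle_then_shard(range(8), 2))  # global shuffle before sharding
```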
diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index f71c378203..d0f693273a 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -423,7 +423,7 @@ class MegatronArguments(ExtraMegatronArguments): bias_dropout_fusion: Optional[bool] = None bias_swiglu_fusion: bool = True bias_gelu_fusion: bool = True - no_rope_fusion: Optional[bool] = None + apply_rope_fusion: Optional[bool] = None gradient_accumulation_fusion: bool = True cross_entropy_loss_fusion: bool = False cross_entropy_fusion_impl: Literal['native', 'te'] = 'native' @@ -596,7 +596,7 @@ class MegatronArguments(ExtraMegatronArguments): tensorboard_log_interval: int = 1 tensorboard_queue_size: int = 50 log_timers_to_tensorboard: bool = True - no_log_learning_rate_to_tensorboard: bool = False + log_learning_rate_to_tensorboard: bool = True log_validation_ppl_to_tensorboard: bool = True log_memory_to_tensorboard: bool = True logging_level: Optional[str] = None @@ -612,7 +612,7 @@ class MegatronArguments(ExtraMegatronArguments): seed: int = 42 seq_length: Optional[int] = None num_workers: int = 4 - no_data_sharding: bool = False + data_sharding: bool = False def _set_default(self): if self.mlp_padding_free and (self.sequence_parallel or self.context_parallel_size > 1): @@ -746,6 +746,7 @@ def __post_init__(self): if hasattr(self, 'ddp_timeout'): self.distributed_timeout_minutes = self.ddp_timeout // 60 self.group_query_attention = self.num_query_groups > 1 + self.fp8 = self.fp8_format # compat megatron-lm if self.rope_scaling is not None: self.rope_scaling = json_parse_to_dict(self.rope_scaling) if 'type' in self.rope_scaling and 'rope_type' not in self.rope_scaling: @@ -782,17 +783,17 @@ def __post_init__(self): self._init_moe() self._init_mixed_precision() - self._init_no_rope_fusion() + self._init_apply_rope_fusion() - def _init_no_rope_fusion(self): - if self.no_rope_fusion is not None: + def _init_apply_rope_fusion(self): + if self.apply_rope_fusion is not None: return if self.multi_latent_attention or self.rotary_interleaved: # Upgrading transformer_engine requires checking here. - self.no_rope_fusion = True + self.apply_rope_fusion = False else: - self.no_rope_fusion = False - logger.info(f'Setting args.no_rope_fusion: {self.no_rope_fusion}.') + self.apply_rope_fusion = True + logger.info(f'Setting args.apply_rope_fusion: {self.apply_rope_fusion}.') def _init_vpp_size(self): # TODO From a5fea817b412c9e793d7154cf0d4d4e3f2ff0549 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 11:42:28 +0800 Subject: [PATCH 09/43] update --- swift/megatron/convert.py | 4 ++-- swift/megatron/init.py | 9 +++++---- swift/megatron/pipelines/export/export.py | 4 ++-- swift/megatron/trainers/gkd_trainer.py | 2 +- swift/megatron/trainers/grpo_trainer.py | 2 +- swift/megatron/utils/convert_utils.py | 11 ++++++----- 6 files changed, 17 insertions(+), 15 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 6151249245..278f09308c 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -74,7 +74,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') # Place it at the end to avoid test_convert_precision affecting precision. 
if args.test_convert_precision: - test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) + test_convert_precision(megatron_args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype) def convert_mcore2hf(args: ExportArguments) -> None: @@ -131,7 +131,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.') if args.test_convert_precision: hf_model, template = prepare_model_template(args, model=args.output_dir) - test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) + test_convert_precision(megatron_args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype) elif args.to_mcore: if args.thread_count is None: checkpoint_size = sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9 diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 38bc96a606..2bdf6cd0d3 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -518,6 +518,7 @@ def sharded_state_dict( def _patch_TransformerLayer(): import megatron.core from megatron.core.transformer import TransformerLayer + from megatron.core import mpu _origin_forward = TransformerLayer.forward mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -531,7 +532,7 @@ def forward(self, *_args, **kwargs): if not mcore_013: return _origin_forward(self, *_args, **kwargs) hidden_states, context = self._forward_attention(*_args, **kwargs) - args = get_args() + args = self.config.args mlp_padding_free = args.mlp_padding_free and 'attention_mask' in kwargs mask = None if mlp_padding_free and hidden_states.shape[1] > 1: @@ -660,6 +661,7 @@ def _write_item(self, *args, **kwargs): def _patch_mrope(): from megatron.core.models.common.embeddings.rotary_pos_embedding import MultimodalRotaryEmbedding import megatron.core + from megatron.core import mpu from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd from megatron.core.models.common.embeddings import rope_utils @@ -696,7 +698,7 @@ def forward(self, position_ids, mrope_section: List[int], packed_seq: bool = Fal seq_expanded = seq[:, :, None, :].float() # shape (3, bs, seq_length, dim) freqs = (inv_freq_expanded @ seq_expanded).transpose(2, 3) - args = get_args() + args = self.config.args if args.mrope_interleaved: freqs = apply_interleaved_mrope(freqs, mrope_section) emb = torch.cat((freqs, freqs), dim=-1) @@ -744,8 +746,7 @@ def _apply_rotary_pos_emb_thd( if cp_group is not None: cp_size = cp_group.size() else: - args = get_args() - cp_size = args.context_parallel_size + cp_size = mpu.get_context_parallel_world_size() cu_seqlens_for_batched = cu_seqlens // cp_size use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item() if not use_batched_rope: diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index b237250ab6..9733b58301 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -82,7 +82,7 @@ def convert_mcore2hf(self) -> None: device_map = args.device_map or 'auto' hf_model, template = prepare_model_template( args, device_map=device_map, **kwargs) if is_last_rank() else (None, template) - test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) + test_convert_precision(args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype) dist.barrier() def convert_hf2mcore(self) 
-> None: @@ -135,7 +135,7 @@ def convert_hf2mcore(self) -> None: device_map = args.device_map or 'auto' hf_model, template = prepare_model_template( args, device_map=device_map) if is_last_rank() else (None, template) - test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) + test_convert_precision(args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype) dist.barrier() else: logger.warning('Skip test_convert_precision because `--adapter_load` is specified.') diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index 3d1c02efb8..94b7a2c663 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -441,7 +441,7 @@ def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optiona teacher_data.pop('labels', None) # Teacher forward with args override for correct hidden_size with self.load_teacher_model_context(), self._teacher_args_context(), torch.no_grad(): - teacher_logits = forward_step_helper(teacher_model, teacher_data) + teacher_logits = forward_step_helper(self.args, teacher_model, teacher_data) if teacher_logits is not None: teacher_logits = teacher_logits.detach() encoded_batch['teacher_logits'] = teacher_logits diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 2bd2fb219b..f9e7fbfd19 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1504,7 +1504,7 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): context = torch.no_grad() if no_grad else nullcontext() with context: - output_tensor = forward_step_helper(model, data) + output_tensor = forward_step_helper(self.args, model, data) # packed_seq_params only exists in padding_free mode packed_seq_params = data.get('packed_seq_params') diff --git a/swift/megatron/utils/convert_utils.py b/swift/megatron/utils/convert_utils.py index 3c4d701001..c4a51a3842 100644 --- a/swift/megatron/utils/convert_utils.py +++ b/swift/megatron/utils/convert_utils.py @@ -143,8 +143,9 @@ def get_examples(is_multimodal: bool) -> Dict[str, Any]: return data -def test_convert_precision(args, hf_model, mg_model, template): - torch_dtype = args.test_convert_dtype +def test_convert_precision(args, hf_model, mg_model, template, test_convert_dtype=None): + if test_convert_dtype is None: + test_convert_dtype = getattr(args, 'test_convert_dtype', torch.float32) template.set_mode('train') _test_params_sum(mg_model) @@ -166,7 +167,7 @@ def test_convert_precision(args, hf_model, mg_model, template): ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else [] hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) with torch.inference_mode(), _model_cpu_forward_context( - hf_modules, torch_dtype, share_embedding=share_embedding): + hf_modules, test_convert_dtype, share_embedding=share_embedding): hf_inputs.pop('text_position_ids', None) hf_logits = hf_model(**hf_inputs).logits hf_logits = hf_logits.to('cuda') @@ -195,8 +196,8 @@ def test_convert_precision(args, hf_model, mg_model, template): if n.endswith('router'): m.to(mg_dtype) with torch.inference_mode(), _model_cpu_forward_context( - mg_modules, torch_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device): - mg_logits = forward_step_helper(mg_model, mg_inputs, dtype=torch_dtype) + mg_modules, test_convert_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device): + mg_logits = 
forward_step_helper(args, mg_model, mg_inputs, dtype=test_convert_dtype) if args.tensor_model_parallel_size > 1 and args.task_type != 'seq_cls': from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region if mg_logits is not None: From 12c39bc322567734a29d01fc68f7fed807b10ff4 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 14:01:30 +0800 Subject: [PATCH 10/43] update --- swift/megatron/convert.py | 7 ++++--- swift/megatron/pipelines/export/export.py | 4 ++-- swift/megatron/utils/__init__.py | 2 +- swift/megatron/utils/megatron_lm_utils.py | 14 +++++++++++--- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 278f09308c..c7f04d18e5 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -14,7 +14,7 @@ from .arguments import MegatronArguments from .model import get_megatron_model_meta from .utils import (convert_hf_config, initialize_megatron, patch_load_base_checkpoint, patch_torch_dist_shard, - test_convert_precision) + test_convert_precision, save_mcore_checkpoint) logger = get_logger() @@ -70,7 +70,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: if not _test_convert_precision: args.save_args() logger.info('Saving the model...') - mg_save_checkpoint(1, [mg_model], None, None, 0) + save_mcore_checkpoint(args, [mg_model]) logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') # Place it at the end to avoid test_convert_precision affecting precision. if args.test_convert_precision: @@ -131,7 +131,8 @@ def convert_mcore2hf(args: ExportArguments) -> None: logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.') if args.test_convert_precision: hf_model, template = prepare_model_template(args, model=args.output_dir) - test_convert_precision(megatron_args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype) + test_convert_precision( + megatron_args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype) elif args.to_mcore: if args.thread_count is None: checkpoint_size = sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9 diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index 9733b58301..2ec25399f9 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -5,14 +5,14 @@ import torch.distributed as dist from megatron.core import mpu -# from megatron.training import initialize_megatron # from megatron.training.checkpointing import load_checkpoint # from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from transformers.utils import strtobool from swift.megatron.arguments import MegatronExportArguments from swift.megatron.convert import test_convert_precision -from swift.megatron.utils import adapter_state_dict_context, patch_load_base_checkpoint, prepare_mcore_model +from swift.megatron.utils import (adapter_state_dict_context, initialize_megatron, patch_load_base_checkpoint, + prepare_mcore_model) from swift.pipelines import SwiftPipeline, prepare_model_template from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index bd4cf9b252..3d5eb3da81 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -2,7 +2,7 @@ from .config import convert_hf_config from .convert_utils 
import test_convert_precision -from .megatron_lm_utils import core_transformer_config_from_args, initialize_megatron +from .megatron_lm_utils import core_transformer_config_from_args, initialize_megatron, save_mcore_checkpoint, load_mcore_checkpoint from .patcher import patch_load_base_checkpoint, patch_merge_fn, patch_torch_dist_shard from .utils import (MegatronTrainerState, adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 685c03a4e3..eb53241b06 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -7,10 +7,11 @@ import torch import torch.nn.functional as F from megatron.core import mpu, tensor_parallel +from megatron.core.utils import unwrap_model from megatron.core.fusions.fused_bias_geglu import quick_gelu from megatron.core.transformer import MLATransformerConfig, TransformerConfig -from swift.utils import get_logger, is_master, seed_everything +from swift.utils import get_logger, init_process_group, is_master, seed_everything, set_device logger = get_logger() @@ -36,7 +37,8 @@ def create_group(ranks=None, timeout=None, *_args, **kwargs): def _initialize_mpu(args): """Initialize torch.distributed and core model parallel.""" if not torch.distributed.is_initialized(): - raise ValueError('torch.distributed is not initialized') + set_device() + init_process_group(args.distributed_backend, args.ddp_timeout) args.rank = torch.distributed.get_rank() args.world_size = torch.distributed.get_world_size() @@ -61,7 +63,7 @@ def _initialize_mpu(args): def _set_random_seed( seed_: int, - data_parallel_random_init: bool = True, + data_parallel_random_init: bool = False, te_rng_tracker: bool = False, inference_rng_tracker: bool = False, use_cudagraphable_rng: bool = False, @@ -143,3 +145,9 @@ def core_transformer_config_from_args(args, config_class=None): config.args = args return config + +def save_mcore_checkpoint(args, model, iteration=1): + model = unwrap_model(model) + +def load_mcore_checkpoint(args): + pass From 899f172d7e808c4152a41fce37318eae0fb271dd Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 17:18:41 +0800 Subject: [PATCH 11/43] update --- swift/megatron/convert.py | 15 +- swift/megatron/trainers/base.py | 6 +- swift/megatron/trainers/utils.py | 2 +- swift/megatron/utils/__init__.py | 3 +- swift/megatron/utils/megatron_lm_utils.py | 193 +++++++++++++++++++++- 5 files changed, 198 insertions(+), 21 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index c7f04d18e5..d2c2fff936 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -13,8 +13,8 @@ from swift.utils import get_logger, get_n_params_grads, is_master from .arguments import MegatronArguments from .model import get_megatron_model_meta -from .utils import (convert_hf_config, initialize_megatron, patch_load_base_checkpoint, patch_torch_dist_shard, - test_convert_precision, save_mcore_checkpoint) +from .utils import (convert_hf_config, initialize_megatron, load_mcore_checkpoint, patch_load_base_checkpoint, + patch_torch_dist_shard, save_mcore_checkpoint, test_convert_precision) logger = get_logger() @@ -70,8 +70,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: if not _test_convert_precision: args.save_args() logger.info('Saving the model...') - save_mcore_checkpoint(args, [mg_model]) - 
logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') + save_mcore_checkpoint(megatron_args, [mg_model]) # Place it at the end to avoid test_convert_precision affecting precision. if args.test_convert_precision: test_convert_precision(megatron_args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype) @@ -109,12 +108,12 @@ def convert_mcore2hf(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider(megatron_args) if megatron_args.load is None: raise ValueError('Please specify `--mcore_model`.') - with patch_load_base_checkpoint(): - load_checkpoint([mg_model], None, None, strict=True) + # with patch_load_base_checkpoint(): + load_mcore_checkpoint(megatron_args, [mg_model], load_arg='load') if megatron_args.adapter_load is not None: peft_model = prepare_mcore_model(mg_model) - with adapter_state_dict_context(): - load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) + # with adapter_state_dict_context(): + load_mcore_checkpoint(megatron_args, [mg_model], load_arg='adapter_load') logger.info('Merge LoRA...') mg_model = peft_model.merge_and_unload() logger.info('Megatron model created successfully.') diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 05e7d57e10..fea3421f29 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -447,10 +447,10 @@ def _load_iteration(self): if ckpt_dir is None: return 0, 0 logger.info(f'checkpoint_dir: {ckpt_dir}') - iteration_path = os.path.join(ckpt_dir, 'latest_checkpointed_iteration.txt') - if not os.path.exists(iteration_path): + tracker_path = os.path.join(ckpt_dir, 'latest_checkpointed_iteration.txt') + if not os.path.exists(tracker_path): return 0, 0 - with open(iteration_path, 'r') as f: + with open(tracker_path, 'r') as f: iteration = int(f.read()) common_path = os.path.join(ckpt_dir, f'iter_{iteration:07d}', 'common.pt') diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index f5a327a45c..7b3cf68645 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -128,7 +128,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): cp_size = mpu.get_context_parallel_world_size() if cp_size > 1: args = get_args() - keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale'] + keys = ['labels', 'position_ids', 'loss_scale'] if not args.is_multimodal: # Multimodal models will handle CP in input_embeds. 
keys.append('input_ids') diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 3d5eb3da81..e8e0511cbc 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -2,7 +2,8 @@ from .config import convert_hf_config from .convert_utils import test_convert_precision -from .megatron_lm_utils import core_transformer_config_from_args, initialize_megatron, save_mcore_checkpoint, load_mcore_checkpoint +from .megatron_lm_utils import (core_transformer_config_from_args, initialize_megatron, load_mcore_checkpoint, + save_mcore_checkpoint) from .patcher import patch_load_base_checkpoint, patch_merge_fn, patch_torch_dist_shard from .utils import (MegatronTrainerState, adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index eb53241b06..8d565a04b0 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -1,17 +1,28 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Parts of the functions in this file are code borrowed from NVIDIA/Megatron-LM import dataclasses +import os +import random +from argparse import Namespace from contextlib import contextmanager from datetime import timedelta +import numpy as np import torch import torch.nn.functional as F -from megatron.core import mpu, tensor_parallel -from megatron.core.utils import unwrap_model +from megatron.core import dist_checkpointing, mpu, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedObject +from megatron.core.dist_checkpointing.serialization import (get_default_load_sharded_strategy, + get_default_save_sharded_strategy) +from megatron.core.dist_checkpointing.strategies.fully_parallel import (FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper) from megatron.core.fusions.fused_bias_geglu import quick_gelu +from megatron.core.msc_utils import open_file +from megatron.core.num_microbatches_calculator import update_num_microbatches from megatron.core.transformer import MLATransformerConfig, TransformerConfig +from megatron.core.utils import unwrap_model -from swift.utils import get_logger, init_process_group, is_master, seed_everything, set_device +from swift.utils import check_json_format, get_logger, init_process_group, is_master, seed_everything, set_device logger = get_logger() @@ -56,9 +67,9 @@ def _initialize_mpu(args): distributed_timeout_minutes=args.distributed_timeout_minutes, ) if is_master(): - logger.info(f'tp: {args.tensor_model_parallel_size}, pp: {args.pipeline_model_parallel_size}, ' - f'vpp: {args.virtual_pipeline_model_parallel_size}, cp: {args.context_parallel_size}, ' - f'ep: {args.expert_model_parallel_size}, etp: {args.expert_tensor_parallel_size}') + logger.info(f'TP: {args.tensor_model_parallel_size}, PP: {args.pipeline_model_parallel_size}, ' + f'VPP: {args.virtual_pipeline_model_parallel_size}, CP: {args.context_parallel_size}, ' + f'EP: {args.expert_model_parallel_size}, ETP: {args.expert_tensor_parallel_size}') def _set_random_seed( @@ -146,8 +157,174 @@ def core_transformer_config_from_args(args, config_class=None): return config + +def _get_rng_state(): + """Collect rng state across data parallel ranks.""" + rng_state = { + 'random_rng_state': random.getstate(), + 'np_rng_state': np.random.get_state(), + 'torch_rng_state': 
torch.get_rng_state(), + 'cuda_rng_state': torch.cuda.get_rng_state(), + 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states() + } + + # data_parallel_random_init False + rng_state_list = [rng_state] + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + rng_state_list = ShardedObject( + 'rng_state', + rng_state_list, (pp_size, tp_size), (pp_rank, tp_rank), + replica_id=mpu.get_data_parallel_rank(with_context_parallel=True)) + return rng_state_list + + +def _generate_state_dict(args, model, iteration=None, model_sd_kwargs=None): + model_sd_kwargs = model_sd_kwargs or {} + state_dict = {} + state_dict['args'] = Namespace(**check_json_format(args.__dict__)) + if iteration is not None: + state_dict['iteration'] = iteration + for i, m in enumerate(model): + key = 'model' + if len(model) > 1: + key = f'model{i}' + model_sd = model[i].sharded_state_dict(**model_sd_kwargs) + state_dict[key] = model_sd + return state_dict + + def save_mcore_checkpoint(args, model, iteration=1): model = unwrap_model(model) + rng_state = _get_rng_state() + checkpoint_dir = os.path.join(args.save, f'iter_{iteration:07d}') + sharded_sd_metadata = { + 'distrib_optim_sharding_type': 'dp_reshardable', + 'singleton_local_shards': False, + 'chained_optim_avoid_prefix': True + } + os.makedirs(checkpoint_dir, exist_ok=True) + save_strategy = get_default_save_sharded_strategy() + save_strategy = FullyParallelSaveStrategyWrapper( + save_strategy, + mpu.get_data_parallel_group(with_context_parallel=True), + ) + + state_dict = _generate_state_dict(args, model, iteration, model_sd_kwargs={'metadata': sharded_sd_metadata}) + async_save_request = dist_checkpointing.save( + state_dict, + checkpoint_dir, + save_strategy, + async_sharded_save=args.async_save, + validate_access_integrity=True, + content_metadata=sharded_sd_metadata) + + if not args.async_save: + assert async_save_request is None + # Wait so everyone is done (necessary) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + if is_master(): + tracker_path = os.path.join(args.save, 'latest_checkpointed_iteration.txt') + + def iter_finalize_fn(): + prev_iteration = 0 + save_retain_interval = getattr(args, 'save_retain_interval', None) # For backwards compatibility of tests. + if save_retain_interval is not None: + if os.path.exists(tracker_path): # TODO: Make this work with MSC remote paths? + with open_file(tracker_path, 'r') as f: + prev_iteration = int(f.read().strip()) + + with open_file(tracker_path, 'w') as f: + f.write(str(iteration)) + # TODO: delete_checkpoint + + if args.async_save: + assert async_save_request is not None + async_save_request.add_finalize_fn(iter_finalize_fn) + else: + iter_finalize_fn() + logger.info(f'Successfully saved Megatron model weights in `{args.save}`.') + + +def _load_iteration(tracker_path: str): + if not os.path.exists(tracker_path): + return 0 + with open(tracker_path, 'r') as f: + iteration = int(f.read()) + # Get the max iteration retrieved across the ranks. 
+ if torch.distributed.is_initialized(): + iters_cuda = torch.tensor([iteration], dtype=torch.long, device='cuda') + torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) + iteration = iters_cuda[0].item() + return iteration + + +def load_mcore_checkpoint(args, model, load_arg: str = 'load'): + if load_arg in {'adapter_load', 'ref_adapter_load'}: + is_peft_format = True + elif load_arg in {'load', 'ref_load'}: + is_peft_format = False + model = unwrap_model(model) + tracker_path = os.path.join(args.load, 'latest_checkpointed_iteration.txt') + iteration = _load_iteration(tracker_path) + checkpoint_dir = os.path.join(args.load, f'iter_{iteration:07d}') + state_dict = dist_checkpointing.load_common_state_dict(checkpoint_dir) + + ckpt_tp_pp = ( + state_dict['args'].tensor_model_parallel_size, + state_dict['args'].pipeline_model_parallel_size, + ) + run_tp_pp = ( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + ) + + # Determine if RNG state will be loaded + if (ckpt_tp_pp == run_tp_pp and not args.finetune and not args.no_load_rng + and not getattr(state_dict['args'], 'no_save_rng', False)): + gen_sd_rng_state = _get_rng_state() # we can load the rng state + else: + gen_sd_rng_state = None + if ckpt_tp_pp != run_tp_pp: + logger.info(f'(TP, PP) mismatch after resume ({run_tp_pp} vs {ckpt_tp_pp} from checkpoint): ' + 'RNG state will be ignored') + sharded_sd_metadata = dist_checkpointing.load_content_metadata(preloaded_state_dict=state_dict) + + sharded_state_dict = _generate_state_dict(args, model, model_sd_kwargs={'metadata': sharded_sd_metadata}) + load_kwargs = {'sharded_state_dict': sharded_state_dict} + load_strategy = get_default_load_sharded_strategy(checkpoint_dir) + load_strategy = FullyParallelLoadStrategyWrapper(load_strategy, + mpu.get_data_parallel_group(with_context_parallel=True)) + state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_dir, load_strategy) + + if state_dict is None: + return 0, 0 + + if args.finetune: + iteration = 0 + num_floating_point_operations_so_far = state_dict.get('num_floating_point_operations_so_far', 0) + if 'args' in state_dict and not args.finetune: + checkpoint_args = state_dict['args'] + args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) + args.skipped_train_samples = getattr(checkpoint_args, 'skipped_train_samples', 0) + update_num_microbatches(consumed_samples=args.consumed_train_samples, verbose=True) + args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) + + if len(model) == 1: + model[0].load_state_dict(state_dict['model']) + else: + for i, m in enumerate(model): + if f'model{i}' not in state_dict: + continue + m.load_state_dict(state_dict[f'model{i}']) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() -def load_mcore_checkpoint(args): - pass + logger.info(f'Successfully loaded Megatron model weights from: {args.load}') + return iteration, num_floating_point_operations_so_far From 6d918d855b3ba49eedc2b30b3b5c80887249215d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 17:49:25 +0800 Subject: [PATCH 12/43] update --- swift/megatron/convert.py | 5 ++--- swift/megatron/pipelines/export/export.py | 9 +++------ swift/megatron/utils/__init__.py | 2 +- swift/megatron/utils/megatron_lm_utils.py | 9 +++++++-- swift/megatron/utils/patcher.py | 20 -------------------- 5 files changed, 13 insertions(+), 32 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 
d2c2fff936..73e1c1cd8b 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -13,8 +13,8 @@ from swift.utils import get_logger, get_n_params_grads, is_master from .arguments import MegatronArguments from .model import get_megatron_model_meta -from .utils import (convert_hf_config, initialize_megatron, load_mcore_checkpoint, patch_load_base_checkpoint, - patch_torch_dist_shard, save_mcore_checkpoint, test_convert_precision) +from .utils import (convert_hf_config, initialize_megatron, load_mcore_checkpoint, patch_torch_dist_shard, + save_mcore_checkpoint, test_convert_precision) logger = get_logger() @@ -108,7 +108,6 @@ def convert_mcore2hf(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider(megatron_args) if megatron_args.load is None: raise ValueError('Please specify `--mcore_model`.') - # with patch_load_base_checkpoint(): load_mcore_checkpoint(megatron_args, [mg_model], load_arg='load') if megatron_args.adapter_load is not None: peft_model = prepare_mcore_model(mg_model) diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index 2ec25399f9..cb3308bbe9 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -11,8 +11,7 @@ from swift.megatron.arguments import MegatronExportArguments from swift.megatron.convert import test_convert_precision -from swift.megatron.utils import (adapter_state_dict_context, initialize_megatron, patch_load_base_checkpoint, - prepare_mcore_model) +from swift.megatron.utils import adapter_state_dict_context, initialize_megatron, prepare_mcore_model from swift.pipelines import SwiftPipeline, prepare_model_template from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank @@ -46,8 +45,7 @@ def convert_mcore2hf(self) -> None: mg_model = megatron_model_meta.model_provider(args, pre_process=pre_process, post_process=post_process) bridge = megatron_model_meta.bridge_cls(args) if args.load is not None: - with patch_load_base_checkpoint(): - load_checkpoint([mg_model], None, None, strict=True) + load_checkpoint([mg_model], None, None, strict=True) elif args.model is not None: bridge.load_weights(mg_model, args.model_info.model_dir) else: @@ -102,8 +100,7 @@ def convert_hf2mcore(self) -> None: if args.model is not None: bridge.load_weights(mg_model, args.model_info.model_dir) elif args.load is not None: - with patch_load_base_checkpoint(): - load_checkpoint([mg_model], None, None, strict=True) + load_checkpoint([mg_model], None, None, strict=True) else: raise ValueError('Please specify `--load` or `--model`.') dist.barrier() diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index e8e0511cbc..759362581c 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -4,6 +4,6 @@ from .convert_utils import test_convert_precision from .megatron_lm_utils import (core_transformer_config_from_args, initialize_megatron, load_mcore_checkpoint, save_mcore_checkpoint) -from .patcher import patch_load_base_checkpoint, patch_merge_fn, patch_torch_dist_shard +from .patcher import patch_merge_fn, patch_torch_dist_shard from .utils import (MegatronTrainerState, adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 8d565a04b0..bc018cfd23 100644 --- 
a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -23,6 +23,7 @@ from megatron.core.utils import unwrap_model from swift.utils import check_json_format, get_logger, init_process_group, is_master, seed_everything, set_device +from .patcher import patch_merge_fn logger = get_logger() @@ -184,8 +185,7 @@ def _get_rng_state(): def _generate_state_dict(args, model, iteration=None, model_sd_kwargs=None): model_sd_kwargs = model_sd_kwargs or {} - state_dict = {} - state_dict['args'] = Namespace(**check_json_format(args.__dict__)) + state_dict = {'args': Namespace(**check_json_format(args.__dict__))} if iteration is not None: state_dict['iteration'] = iteration for i, m in enumerate(model): @@ -302,6 +302,11 @@ def load_mcore_checkpoint(args, model, load_arg: str = 'load'): mpu.get_data_parallel_group(with_context_parallel=True)) state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_dir, load_strategy) + if state_dict.get('sharded_state_dict') is not None: + model_keys = [k for k in sharded_state_dict.keys() if k.startswith('model')] # compat vpp + for k in model_keys: + patch_merge_fn(sharded_state_dict[k]) + if state_dict is None: return 0, 0 diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py index b6ccc5b4ae..fe6193fe4c 100644 --- a/swift/megatron/utils/patcher.py +++ b/swift/megatron/utils/patcher.py @@ -43,23 +43,3 @@ def sh_ten_merge_fn(sub_state_dict): for v in state_dict_model.values(): if isinstance(v, ShardedTensorFactory) and 'apply_swiglu_sharded_factory' in v.merge_fn.__qualname__: v.merge_fn = sh_ten_merge_fn - - -@contextmanager -def patch_load_base_checkpoint(): - origin__load_base_checkpoint = checkpointing._load_base_checkpoint - - def _load_base_checkpoint(*_args, **kwargs): - sharded_state_dict = kwargs.get('sharded_state_dict') - if sharded_state_dict is None: - return origin__load_base_checkpoint(*_args, **kwargs) - model_keys = [k for k in sharded_state_dict.keys() if k.startswith('model')] # compat vpp - for k in model_keys: - patch_merge_fn(sharded_state_dict[k]) - return origin__load_base_checkpoint(*_args, **kwargs) - - checkpointing._load_base_checkpoint = _load_base_checkpoint - try: - yield - finally: - checkpointing._load_base_checkpoint = origin__load_base_checkpoint From 62f82b801e3b6d3d6a10ec7b85686b9b090ba394 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 19:21:14 +0800 Subject: [PATCH 13/43] update --- docs/source/Megatron-SWIFT/Mcore-Bridge.md | 9 +++--- docs/source_en/Megatron-SWIFT/Mcore-Bridge.md | 9 +++--- swift/megatron/__init__.py | 4 +-- swift/megatron/convert.py | 2 +- swift/megatron/init.py | 17 ---------- swift/megatron/pipelines/export/export.py | 15 ++++----- swift/megatron/trainers/base.py | 25 +++++++-------- swift/megatron/trainers/dpo_trainer.py | 12 ++----- swift/megatron/trainers/embedding_trainer.py | 11 ++----- swift/megatron/trainers/gkd_trainer.py | 32 +++++++------------ swift/megatron/trainers/grpo_trainer.py | 26 +++++++-------- swift/megatron/trainers/kto_trainer.py | 12 ++----- swift/megatron/trainers/reranker_trainer.py | 11 ++----- swift/megatron/trainers/reward_trainer.py | 11 ++----- swift/megatron/trainers/rlhf_mixin.py | 8 ++--- swift/megatron/trainers/trainer.py | 13 ++------ swift/megatron/trainers/utils.py | 10 +----- 17 files changed, 76 insertions(+), 151 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Mcore-Bridge.md b/docs/source/Megatron-SWIFT/Mcore-Bridge.md index cd417a9793..039e5a9b9e 100644 --- 
a/docs/source/Megatron-SWIFT/Mcore-Bridge.md +++ b/docs/source/Megatron-SWIFT/Mcore-Bridge.md @@ -285,9 +285,10 @@ megatron export \ ```python import torch -from swift.megatron import MegatronArguments, convert_hf_config, get_megatron_model_meta +from swift.megatron import ( + MegatronArguments, convert_hf_config, get_megatron_model_meta, initialize_megatron +) from swift.model import get_processor -from megatron.training.initialize import initialize_megatron model_id = 'Qwen/Qwen3-4B-Instruct-2507' processor = get_processor(model_id, download_model=True) @@ -327,10 +328,10 @@ LoRA权重的加载、导出和存储同理,运行`CUDA_VISIBLE_DEVICES=0,1,2, import torch from swift.megatron import ( - MegatronArguments, convert_hf_config, get_megatron_model_meta, prepare_mcore_model + MegatronArguments, convert_hf_config, get_megatron_model_meta, + prepare_mcore_model, initialize_megatron ) from swift.model import get_processor -from megatron.training.initialize import initialize_megatron model_id = 'Qwen/Qwen3-30B-A3B-Instruct-2507' processor = get_processor(model_id, download_model=True) diff --git a/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md b/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md index 6cda151ddd..86e84c9bf4 100644 --- a/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md +++ b/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md @@ -297,9 +297,10 @@ You need to create the following file (test.py), then run `CUDA_VISIBLE_DEVICES= ```python import torch -from swift.megatron import MegatronArguments, convert_hf_config, get_megatron_model_meta +from swift.megatron import ( + MegatronArguments, convert_hf_config, get_megatron_model_meta, initialize_megatron +) from swift.model import get_processor -from megatron.training.initialize import initialize_megatron model_id = 'Qwen/Qwen3-4B-Instruct-2507' _, processor = get_processor(model_id, download_model=True) @@ -341,10 +342,10 @@ Loading, exporting, and saving LoRA weights follows the same pattern. 
Run `CUDA_ import torch from swift.megatron import ( - MegatronArguments, convert_hf_config, get_megatron_model_meta, prepare_mcore_model + MegatronArguments, convert_hf_config, get_megatron_model_meta, + prepare_mcore_model, initialize_megatron ) from swift.model import get_processor -from megatron.training.initialize import initialize_megatron model_id = 'Qwen/Qwen3-30B-A3B-Instruct-2507' _, processor = get_processor(model_id, download_model=True) diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 2a456c0d99..667afd5faf 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from .pipelines import megatron_export_main, megatron_sft_main, megatron_pretrain_main, megatron_rlhf_main from .convert import convert_hf2mcore, convert_mcore2hf - from .utils import prepare_mcore_model, adapter_state_dict_context, convert_hf_config + from .utils import prepare_mcore_model, adapter_state_dict_context, convert_hf_config, initialize_megatron from .arguments import (MegatronSftArguments, MegatronPretrainArguments, MegatronRLHFArguments, MegatronExportArguments, MegatronArguments) from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model @@ -29,7 +29,7 @@ _import_structure = { 'pipelines': ['megatron_sft_main', 'megatron_pretrain_main', 'megatron_rlhf_main', 'megatron_export_main'], 'convert': ['convert_hf2mcore', 'convert_mcore2hf'], - 'utils': ['prepare_mcore_model', 'adapter_state_dict_context', 'convert_hf_config'], + 'utils': ['prepare_mcore_model', 'adapter_state_dict_context', 'convert_hf_config', 'initialize_megatron'], 'arguments': [ 'MegatronSftArguments', 'MegatronPretrainArguments', 'MegatronRLHFArguments', 'MegatronExportArguments', 'MegatronArguments' diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 73e1c1cd8b..70b39d3767 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -139,5 +139,5 @@ def convert_mcore2hf(args: ExportArguments) -> None: args.save_args() logger.info('Saving the model...') - mg_save_checkpoint(1, [mg_model], None, None, 0) + save_mcore_checkpoint(megatron_args, [mg_model]) logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 2bdf6cd0d3..5004c3fef2 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -630,7 +630,6 @@ def __repr__(self): def _patch_build_train_valid_test_datasets(): - from megatron.training import training def build_train_valid_test_datasets(build_train_valid_test_datasets_provider, *args, **kwargs): train_valid_test_num_samples = training.get_train_valid_test_num_samples() @@ -800,7 +799,6 @@ def _new_load_inline(*args, **kwargs): def _patch_megatron_swanlab(): - from megatron.training import global_vars, wandb_utils, get_args def _set_wandb_writer(*_args, **kwargs): args = get_args() @@ -836,20 +834,6 @@ def on_save_checkpoint_success(*_args, **kwargs): wandb_utils.on_save_checkpoint_success = on_save_checkpoint_success -def _patch_modelopt(): - from megatron.training import checkpointing - if not hasattr(checkpointing, 'save_sharded_modelopt_state'): - return - save_sharded_modelopt_state = checkpointing.save_sharded_modelopt_state - - def new_save_sharded_modelopt_state(model, *args, **kwargs): - if not model: - return - save_sharded_modelopt_state(model, *args, **kwargs) - - checkpointing.save_sharded_modelopt_state = new_save_sharded_modelopt_state - - def 
init_megatron_env(): os.environ.pop('VLLM_USE_MODELSCOPE', None) logging_level = logging.root.level @@ -867,7 +851,6 @@ def init_megatron_env(): _patch__write_item() _patch_mtp() # _patch_megatron_swanlab() - # _patch_modelopt() logging.root.setLevel(logging_level) # revert logger level from swift.megatron import tuners # patch lora try: diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index cb3308bbe9..4f703380d7 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -5,13 +5,12 @@ import torch.distributed as dist from megatron.core import mpu -# from megatron.training.checkpointing import load_checkpoint -# from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from transformers.utils import strtobool from swift.megatron.arguments import MegatronExportArguments from swift.megatron.convert import test_convert_precision -from swift.megatron.utils import adapter_state_dict_context, initialize_megatron, prepare_mcore_model +from swift.megatron.utils import (adapter_state_dict_context, initialize_megatron, load_mcore_checkpoint, + prepare_mcore_model, save_mcore_checkpoint) from swift.pipelines import SwiftPipeline, prepare_model_template from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank @@ -45,7 +44,7 @@ def convert_mcore2hf(self) -> None: mg_model = megatron_model_meta.model_provider(args, pre_process=pre_process, post_process=post_process) bridge = megatron_model_meta.bridge_cls(args) if args.load is not None: - load_checkpoint([mg_model], None, None, strict=True) + load_mcore_checkpoint([mg_model], None, None, strict=True) elif args.model is not None: bridge.load_weights(mg_model, args.model_info.model_dir) else: @@ -53,7 +52,7 @@ def convert_mcore2hf(self) -> None: if args.adapter_load is not None: peft_model = prepare_mcore_model(mg_model) with adapter_state_dict_context(): - load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) + load_mcore_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) if args.merge_lora: logger.info('Merge LoRA...') mg_model = peft_model.merge_and_unload() @@ -100,7 +99,7 @@ def convert_hf2mcore(self) -> None: if args.model is not None: bridge.load_weights(mg_model, args.model_info.model_dir) elif args.load is not None: - load_checkpoint([mg_model], None, None, strict=True) + load_mcore_checkpoint([mg_model], None, None, strict=True) else: raise ValueError('Please specify `--load` or `--model`.') dist.barrier() @@ -111,7 +110,7 @@ def convert_hf2mcore(self) -> None: bridge.load_weights(mg_model, args.adapters[0], is_peft_format=True) elif args.adapter_load is not None: with adapter_state_dict_context(): - load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) + load_mcore_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) if args.merge_lora: logger.info('Merge LoRA...') mg_model = peft_model.merge_and_unload() @@ -122,7 +121,7 @@ def convert_hf2mcore(self) -> None: logger.info('Saving the model...') save_peft_format = args.tuner_type == 'lora' and not args.merge_lora with adapter_state_dict_context(is_peft_format=save_peft_format): - mg_save_checkpoint(1, [mg_model], None, None, 0) + save_mcore_checkpoint(args, [mg_model]) logger.info_if(f'Successfully saved Megatron model weights in `{args.save}`.', cond=is_last_rank()) # hf_model does not support loading args.adapter_load, so test_convert_precision cannot 
be performed support_convert_precision = args.adapter_load is None diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index fea3421f29..35b789fdfa 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -24,22 +24,22 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper -from megatron.core.utils import StragglerDetector -# from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, +from megatron.core.utils import StragglerDetector, unwrap_model +# from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, # get_wandb_writer, initialize, is_last_rank, one_logger_utils, pretrain, print_rank_0, # print_rank_last, training) -# from megatron.training.checkpointing import check_checkpoint_args, load_checkpoint, set_checkpoint_version +# from megatron.training.checkpointing import check_checkpoint_args, set_checkpoint_version # from megatron.training.dist_signal_handler import DistributedSignalHandler # from megatron.training.theoretical_memory_usage import report_theoretical_memory # from megatron.training.training import num_floating_point_operations -# from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model +# from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory from modelscope import check_local_model_is_latest from packaging import version from tqdm.auto import tqdm from swift.megatron.tuners import LoraParallelLinear -from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, patch_merge_fn, - prepare_mcore_model) +from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, load_mcore_checkpoint, + patch_merge_fn, prepare_mcore_model) from swift.metrics import MeanMetric from swift.template import Template from swift.trainers import SwiftMixin, dynamic_gradient_checkpointing @@ -47,10 +47,10 @@ from .utils import (MegatronPretrainingRandomSampler, get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, get_swift_datasets_provider) -try: - from megatron.training.datasets.data_samplers import MegatronPretrainingSampler -except ImportError: - from megatron.legacy.data.data_samplers import MegatronPretrainingSampler +# try: +# from megatron.training.datasets.data_samplers import MegatronPretrainingSampler +# except ImportError: +# from megatron.legacy.data.data_samplers import MegatronPretrainingSampler try: from megatron.core.optimizer import param_group_identifier_keys @@ -65,7 +65,6 @@ class BaseMegatronTrainer(ABC): def __init__(self, args, template: Template): self.args = args self.template = template - self.stimer = StragglerDetector() self.unwrapped_models = [] self.wrapped_models = [] self.peft_models = [] @@ -514,10 +513,10 @@ def new_model_provider_func(*_args, **kwargs): copy_original_module_weight(m) if args.ref_adapter_load is not None: with self._patch_load_state_dict(self._load_adapter_base_checkpoint): - load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='ref_adapter_load', strict=False) + load_mcore_checkpoint(model, optimizer, opt_param_scheduler, load_arg='ref_adapter_load', strict=False) if args.adapter_load is not None: with adapter_state_dict_context(): - 
args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( + args.iteration, args.num_floating_point_operations_so_far = load_mcore_checkpoint( model, optimizer, opt_param_scheduler, load_arg='adapter_load', strict=False) if args.is_multimodal: for m in self.unwrapped_models: diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index d0821d3008..9ec1f85eed 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -4,7 +4,6 @@ import torch from megatron.core import mpu -from megatron.training import get_args, get_timers from torch.distributed.nn import all_reduce from swift.rlhf_trainers import DPOTrainer @@ -38,7 +37,7 @@ def __init__(self, args, template): def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params): ref_output_tensor = output_tensor[:output_tensor.shape[0] // 2].detach() output_tensor = output_tensor[output_tensor.shape[0] // 2:] - args = get_args() + args = self.args num_samples = labels.shape[0] // 2 if packed_seq_params is None else packed_seq_params.num_samples logps = self.get_logps(output_tensor, labels, packed_seq_params, num_samples * 2) @@ -80,15 +79,11 @@ def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed return loss, metric def forward_step(self, data_iterator, model): - timers = get_timers() # Get the batch. unwrapped_model = model.module.module input_tensor = unwrapped_model.get_input_tensor() vp_stage = unwrapped_model.vp_stage - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = self.get_batch(data_iterator, vp_stage) - timers('batch-generator').stop() + data = self.get_batch(data_iterator, vp_stage) data.pop('loss_scale', None) # ref_model with torch.no_grad(), self.null_ref_context() as ref_models: @@ -99,7 +94,6 @@ def forward_step(self, data_iterator, model): if input_tensor is not None: unwrapped_model.set_input_tensor(input_tensor[input_tensor.shape[0] // 2:]) - with self.stimer: - output_tensor = model(**data) + output_tensor = model(**data) return torch.concat([ref_output_tensor, output_tensor], dim=0), partial( self.loss_func, labels=data.get('labels'), packed_seq_params=data.get('packed_seq_params')) diff --git a/swift/megatron/trainers/embedding_trainer.py b/swift/megatron/trainers/embedding_trainer.py index f7de1af4ec..a081d2ef3b 100644 --- a/swift/megatron/trainers/embedding_trainer.py +++ b/swift/megatron/trainers/embedding_trainer.py @@ -2,7 +2,6 @@ from functools import partial import torch.nn -from megatron.training import get_args, get_timers from swift.loss import loss_map from swift.metrics import eval_metrics_map @@ -35,17 +34,11 @@ def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed return loss, metric def forward_step(self, data_iterator, model): - timers = get_timers() - # Get the batch. 
vp_stage = model.module.module.vp_stage - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = self.get_batch(data_iterator, vp_stage) - timers('batch-generator').stop() + data = self.get_batch(data_iterator, vp_stage) labels = data.pop('labels', None) - with self.stimer: - output_tensor = model(**data) + output_tensor = model(**data) packed_seq_params = data.get('packed_seq_params') loss_func = partial(self.loss_func, labels=labels, packed_seq_params=packed_seq_params) return output_tensor, loss_func diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index 94b7a2c663..e4ce8ba62e 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -9,8 +9,7 @@ import torch.nn.functional as F from megatron.core import mpu from megatron.core.rerun_state_machine import RerunDataIterator -from megatron.training import get_args, get_model, get_timers -from megatron.training.utils import unwrap_model +from megatron.core.utils import unwrap_model from transformers import AutoConfig from swift.megatron.arguments import MegatronArguments @@ -96,7 +95,7 @@ def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **k return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) def _load_teacher_model(self, teacher_model_path: str, model_type: str): - megatron_args = get_args() + megatron_args = self.args vp_size = getattr(megatron_args, 'virtual_pipeline_model_parallel_size') assert vp_size is None or vp_size == 1, 'GKD currently does not support VPP.' teacher_model_info, _ = get_model_info_meta( @@ -233,7 +232,7 @@ def _teacher_args_context(self): This is necessary for teacher model forward to use correct hidden_size, num_layers, etc. when the teacher has a different architecture than the student. 
""" - megatron_args = get_args() + megatron_args = self.args # Save original values and override with teacher config original_values = {} @@ -292,7 +291,7 @@ def _template_context(self, template: Template, max_length: Optional[int] = None def _encode_batch(self, batch: List[Dict]) -> Dict[str, torch.Tensor]: """Encode a batch of raw data into model inputs.""" template = self.template - args = get_args() + args = self.args max_length = template.max_length + self.max_completion_length with self._template_context(template, max_length=max_length): encoded_list = [template.encode(data, return_length=True) for data in batch] @@ -357,7 +356,7 @@ def _init_resample_data_iterator(self): """ from megatron.training.training import build_train_valid_test_data_iterators from megatron.training.initialize import _set_random_seed - args = get_args() + args = self.args resample_seed = getattr(args, 'seed', 42) + 1 try: @@ -533,7 +532,7 @@ def generalized_jsd_loss( beta: float = 0.5, chunk_size: int = 512, ) -> torch.Tensor: - args = get_args() + args = self.args mask = labels != -100 local_num_valid = mask.sum() num_valid = local_num_valid.float() @@ -641,7 +640,7 @@ def loss_func(self, # Add SFT loss if enabled (skip for student-generated responses) if self.sft_alpha > 0 and data_source != DataSource.STUDENT: - args = get_args() + args = self.args logits_sbv = student_logits.transpose(0, 1).contiguous() per_token_loss = self.unwrapped_models[0].compute_language_model_loss(labels, logits_sbv) @@ -669,28 +668,21 @@ def loss_func(self, return loss, metric def forward_step(self, data_iterator, model): - - timers = get_timers() - unwrapped_model = model.module.module input_tensor = unwrapped_model.get_input_tensor() vp_stage = unwrapped_model.vp_stage - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = next(data_iterator) - data_source = data.pop('data_source', DataSource.DATASET) - teacher_logits = data.pop('teacher_logits', None) - data = self._prepare_batch(data, vp_stage) - timers('batch-generator').stop() + data = next(data_iterator) + data_source = data.pop('data_source', DataSource.DATASET) + teacher_logits = data.pop('teacher_logits', None) + data = self._prepare_batch(data, vp_stage) data.pop('loss_scale', None) labels = data.pop('labels', None) if input_tensor is not None: unwrapped_model.set_input_tensor(input_tensor) - with self.stimer: - student_output = model(**data) + student_output = model(**data) return student_output, partial( self.loss_func, labels=labels, teacher_logits=teacher_logits, data_source=data_source) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index f9e7fbfd19..9e161d85a9 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -20,7 +20,6 @@ from dacite import from_dict from megatron.core import mpu from megatron.core.rerun_state_machine import RerunDataIterator -from megatron.training import get_args, get_wandb_writer, training from swift.dataset import RowPreprocessor from swift.infer_engine.protocol import RequestConfig, RolloutInferRequest, RolloutOutput @@ -39,6 +38,8 @@ from .utils import gather, gather_object, get_swift_datasets_provider, profiling_context, profiling_decorator from .vocab_parallel_utils import compute_logps_and_entropy_from_logits +# from megatron.training import get_wandb_writer, training + if is_wandb_available(): import wandb @@ -243,7 +244,7 @@ def _init_resample_data_iterator(self): from megatron.training.initialize 
import _set_random_seed from megatron.training import training training.cyclic_iter = self._origin_cyclic_iter - args = get_args() + args = self.args() train_valid_test_dataset_provider = self._train_valid_test_dataset_provider # Use different seed for resample iterator (offset by 1 to avoid overlap) @@ -308,7 +309,7 @@ def _batch_encode(self, infer_requests: List[Dict], template: Template, strict: return batched_inputs, error_list def _get_encoded_batch(self, encoded_list, rollout_batch, template): - args = get_args() + args = self.args encoded_batch = to_device(template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device) labels = encoded_batch['labels'] @@ -1080,7 +1081,7 @@ def patch_megatron_data_collator(self, data_collator): origin_build_pretraining_data_loader = training.build_pretraining_data_loader def build_pretraining_data_loader(*_args, **kwargs): - args = get_args() + args = self.args org_micro_batch_size = args.micro_batch_size # args.micro_batch_size = org_micro_batch_size // self.num_generations res = origin_build_pretraining_data_loader(*_args, **kwargs) @@ -1097,7 +1098,7 @@ def build_pretraining_data_loader(*_args, **kwargs): @profiling_decorator def forward_step(self, data_iterator, model): - args = get_args() + args = self.args data = next(data_iterator) advantages = data.pop('advantages') truncated_mask = data.pop('truncated_mask') @@ -1122,8 +1123,7 @@ def forward_step(self, data_iterator, model): if self.compute_entropy: # Forward without labels to get logits, then compute logps and entropy inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'} - with self.stimer: - output_tensor = model(**inputs_for_logits) + output_tensor = model(**inputs_for_logits) # Compute per_token_logps and per_token_entropy from logits on PP last stage if is_pp_last_stage and output_tensor is not None: @@ -1160,8 +1160,7 @@ def forward_step(self, data_iterator, model): data['per_token_entropy'] = per_token_entropy else: # Standard forward with labels, returns per-token loss (more efficient) - with self.stimer: - output_tensor = model(**inputs) + output_tensor = model(**inputs) # Convert output_tensor (per-token loss) to per_token_logps on PP last stage if is_pp_last_stage and output_tensor is not None: @@ -1188,7 +1187,7 @@ def forward_step(self, data_iterator, model): @profiling_decorator def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): - args = get_args() + args = self.args # Get pre-padded data in batch format [batch_size, max_seq_len] advantages = data['advantages'] # [batch_size] completion_mask = data['completion_mask'] # [batch_size, max_seq_len] @@ -1460,7 +1459,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): } self.jsonl_writer.append(table) wandb_writer = get_wandb_writer() - args = get_args() + args = self.args if wandb_writer: if args.report_to == 'wandb': df = pd.DataFrame(table) @@ -1496,8 +1495,7 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): data dict containing 'logps' """ # used to calculate model forward (logps) in GRPO - with self.stimer(bdata=True): - data = self.get_batch(data_iterator) + data = self.get_batch(data_iterator) data.pop('loss_scale', None) input_ids = data.get('input_ids') labels = data.get('labels') @@ -2070,7 +2068,7 @@ def _collect_config_info(self) -> Dict[str, str]: return config def get_trainer_state(self): - args = get_args() + args = self.args self.state.update( global_step=getattr(args, 'curr_iteration', 0) or 0, 
max_steps=getattr(args, 'train_iters', 0) or 0) return self.state diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index d6c19baab1..ecf1bdbf4b 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -5,7 +5,6 @@ import torch from megatron.core import mpu -from megatron.training import get_args, get_timers from trl import KTOTrainer from swift.utils import get_current_device, get_logger @@ -113,16 +112,12 @@ def _get_input_tensor(input_tensor, is_KL: bool, is_ref: bool, length: int, dim: return res def forward_step(self, data_iterator, model): - timers = get_timers() # Get the batch. unwrapped_model = model.module.module input_tensor = unwrapped_model.get_input_tensor() vp_stage = unwrapped_model.vp_stage - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - # not support loss_scale - data, kl_data = self.get_batch(data_iterator, vp_stage) - timers('batch-generator').stop() + # not support loss_scale + data, kl_data = self.get_batch(data_iterator, vp_stage) label = data.pop('label') data.pop('loss_scale', None) kl_data.pop('loss_scale', None) @@ -149,8 +144,7 @@ def forward_step(self, data_iterator, model): if input_tensor is not None: unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, length, 0)) - with self.stimer: - output_tensor = model(**data) + output_tensor = model(**data) if self.mcore_013: is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) else: diff --git a/swift/megatron/trainers/reranker_trainer.py b/swift/megatron/trainers/reranker_trainer.py index 115d3c70ad..7763914025 100644 --- a/swift/megatron/trainers/reranker_trainer.py +++ b/swift/megatron/trainers/reranker_trainer.py @@ -3,7 +3,6 @@ from functools import partial import torch.nn -from megatron.training import get_args, get_timers from swift.loss import loss_map from swift.metrics import eval_metrics_map @@ -63,17 +62,11 @@ def setup_model_and_optimizer(self, *_args, **kwargs): return res def forward_step(self, data_iterator, model): - timers = get_timers() - # Get the batch. vp_stage = model.module.module.vp_stage - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = self.get_batch(data_iterator, vp_stage) - timers('batch-generator').stop() + data = self.get_batch(data_iterator, vp_stage) labels = data.pop('labels', None) - with self.stimer: - output_tensor = model(**data) + output_tensor = model(**data) packed_seq_params = data.get('packed_seq_params') loss_func = partial(self.loss_func, labels=labels, packed_seq_params=packed_seq_params) return output_tensor, loss_func diff --git a/swift/megatron/trainers/reward_trainer.py b/swift/megatron/trainers/reward_trainer.py index 8d6e11fb3c..66194723a1 100644 --- a/swift/megatron/trainers/reward_trainer.py +++ b/swift/megatron/trainers/reward_trainer.py @@ -2,7 +2,6 @@ from functools import partial import torch -from megatron.training import get_args, get_timers from torch import nn from swift.utils import get_logger @@ -45,15 +44,9 @@ def loss_func(self, output_tensor, *, data): return loss, metric def forward_step(self, data_iterator, model): - timers = get_timers() - # Get the batch. 
vp_stage = model.module.module.vp_stage - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = self.get_batch(data_iterator, vp_stage) - timers('batch-generator').stop() + data = self.get_batch(data_iterator, vp_stage) data.pop('loss_scale', None) - with self.stimer: - output_tensor = model(**data) + output_tensor = model(**data) return output_tensor, partial(self.loss_func, data=data) diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index edc5207839..49e996fbf5 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -3,12 +3,12 @@ import torch from megatron.core import mpu -from megatron.training import get_args, get_model -from megatron.training.checkpointing import load_checkpoint -from megatron.training.utils import unwrap_model +# from megatron.training import get_args, get_model +from megatron.core.utils import unwrap_model from torch.distributed.nn import all_reduce from transformers.utils import ContextManagers +from swift.megatron.utils import load_mcore_checkpoint from swift.utils import get_logger from .base import BaseMegatronTrainer @@ -30,7 +30,7 @@ def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **k self.bridge.load_weights(m, args.ref_model) m.requires_grad_(False).eval() if args.ref_load: - load_checkpoint(ref_models, None, None, load_arg='ref_load') + load_mcore_checkpoint(ref_models, None, None, load_arg='ref_load') self.ref_models = ref_models return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 1779ca1508..6808f76cc0 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -6,7 +6,6 @@ import torch.nn from megatron.core import mpu from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.training import get_args, get_timers from torch.distributed.nn import all_reduce from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -54,7 +53,7 @@ def loss_func(self, loss_scale: Optional[torch.Tensor] = None, channels: Optional[List[str]] = None, packed_seq_params=None): - args = get_args() + args = self.args losses = output_tensor.float() loss_mask = labels != -100 @@ -134,21 +133,15 @@ def loss_func(self, ) def forward_step(self, data_iterator, model): - timers = get_timers() - # Get the batch. 
vp_stage = model.module.module.vp_stage - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = self.get_batch(data_iterator, vp_stage) - timers('batch-generator').stop() + data = self.get_batch(data_iterator, vp_stage) loss_scale = data.pop('loss_scale', None) channels = data.pop('channel', None) labels = data.get('labels') if self.args.task_type == 'seq_cls': data.pop('labels', None) - with self.stimer: - output_tensor = model(**data) + output_tensor = model(**data) packed_seq_params = data.get('packed_seq_params') if self.args.task_type == 'seq_cls': loss_func = partial( diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 7b3cf68645..1b23d7a13b 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -13,7 +13,7 @@ from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.optimizer import ChainedOptimizer from megatron.core.packed_seq_params import PackedSeqParams -from megatron.training import get_args, get_wandb_writer +# from megatron.training import get_wandb_writer from packaging import version from transformers.utils import is_torch_npu_available @@ -21,11 +21,6 @@ from swift.utils import get_packed_seq_params as _get_packed_seq_params from swift.utils import to_device -try: - from megatron.training.datasets.data_samplers import RandomSeedDataset -except ImportError: - from megatron.legacy.data.data_samplers import RandomSeedDataset - mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') logger = get_logger() @@ -442,9 +437,6 @@ def __iter__(self): current_epoch_samples = self.consumed_samples % active_total_samples assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 - if isinstance(self.dataset, RandomSeedDataset): - self.dataset.set_epoch(self.epoch) - if self.shuffle: # data sharding and random sampling if self.data_sharding: From 1a40342d6311f0af2a5b24dee849082bbdafdd66 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 19:41:09 +0800 Subject: [PATCH 14/43] update --- swift/megatron/pipelines/train/sft.py | 19 ++++--------- swift/megatron/trainers/base.py | 36 +++++++++---------------- swift/pipelines/train/sft.py | 39 ++++++++++++++------------- swift/trainers/mixin.py | 22 ++------------- swift/trainers/utils.py | 18 +++++++++++++ 5 files changed, 59 insertions(+), 75 deletions(-) diff --git a/swift/megatron/pipelines/train/sft.py b/swift/megatron/pipelines/train/sft.py index 54c4883d88..0f0dbcb8e2 100644 --- a/swift/megatron/pipelines/train/sft.py +++ b/swift/megatron/pipelines/train/sft.py @@ -9,7 +9,6 @@ from swift.megatron.arguments import MegatronSftArguments from swift.megatron.trainers import MegatronEmbeddingTrainer, MegatronRerankerTrainer, MegatronTrainer -from swift.megatron.utils import get_padding_to from swift.pipelines import SwiftSft from swift.utils import get_logger, is_last_rank, plot_images from .utils import build_streaming_dataloader @@ -60,25 +59,17 @@ def __init__(self, args: Optional[Union[List[str], MegatronSftArguments]] = None self.template.use_megatron = True self.trainer = self.prepare_trainer() - def _get_data_collator(self): - data_collator = self.template.data_collator - padding_to = get_padding_to(self.args) - logger.info(f'padding_to: {padding_to}') - data_collator = partial(data_collator, padding_to=padding_to) - return data_collator - def run(self): args = self.args train_dataset, val_dataset = self._prepare_dataset() - data_collator = 
self._get_data_collator() - if args.streaming: - train_dataset = build_streaming_dataloader(args, train_dataset, data_collator) - if val_dataset is not None: - val_dataset = build_streaming_dataloader(args, val_dataset, data_collator) + # if args.streaming: + # train_dataset = build_streaming_dataloader(args, train_dataset, data_collator) + # if val_dataset is not None: + # val_dataset = build_streaming_dataloader(args, val_dataset, data_collator) try: - self.trainer.train(train_dataset, val_dataset, data_collator) + self.trainer.train(train_dataset, val_dataset) finally: # Visualization if is_last_rank(): diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 35b789fdfa..67a689d978 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -38,11 +38,12 @@ from tqdm.auto import tqdm from swift.megatron.tuners import LoraParallelLinear -from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, load_mcore_checkpoint, - patch_merge_fn, prepare_mcore_model) +from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, get_padding_to, + load_mcore_checkpoint, patch_merge_fn, prepare_mcore_model) from swift.metrics import MeanMetric from swift.template import Template from swift.trainers import SwiftMixin, dynamic_gradient_checkpointing +from swift.trainers.utils import patch_modelscope_hub_timeout from swift.utils import JsonlWriter, deep_getattr, format_time, get_last_valid_indices, get_logger, ms_logger_context from .utils import (MegatronPretrainingRandomSampler, get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, get_swift_datasets_provider) @@ -73,10 +74,10 @@ def __init__(self, args, template: Template): logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') self.jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last') # for evaluate - self._patch_megatron() + # self._patch_megatron() - if args.check_model and hasattr(args, 'model_info') and hasattr(args.model_info, 'model_dir'): - with ms_logger_context(logging.CRITICAL), self._patch_timeout(): + if args.check_model and hasattr(args, 'model_dir'): + with ms_logger_context(logging.CRITICAL), patch_modelscope_hub_timeout(): config_info = self._collect_config_info() config_info.update({ 'invoked_by': 'local_trainer', @@ -95,6 +96,13 @@ def _get_mean_metric(): } self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + def _get_data_collator(self): + data_collator = self.template.data_collator + padding_to = get_padding_to(self.args) + logger.info(f'padding_to: {padding_to}') + data_collator = partial(data_collator, padding_to=padding_to) + return data_collator + @property def bridge(self): if self._bridge is None: @@ -1330,24 +1338,6 @@ def get_batch(self, data_iterator, vp_stage=None): """Generate a batch.""" return self._prepare_batch(next(data_iterator), vp_stage) - @contextmanager - def _patch_timeout(self): - from modelscope.hub.api import HubApi - __init__ = HubApi.__init__ - - def __new_init__(self, *args, **kwargs): - timeout = kwargs.get('timeout') - if timeout is not None and timeout > 5: - kwargs['timeout'] = 5 - __init__(self, *args, **kwargs) - - HubApi.__init__ = __new_init__ - - try: - yield - finally: - HubApi.__init__ = __init__ - def _collect_config_info(self) -> Dict[str, str]: """ Collects trainer-specific configuration details. 
diff --git a/swift/pipelines/train/sft.py b/swift/pipelines/train/sft.py index 2a79d6e2ef..f2004deb91 100644 --- a/swift/pipelines/train/sft.py +++ b/swift/pipelines/train/sft.py @@ -239,28 +239,31 @@ def _save_trainer_state(self, trainer): append_to_jsonl(jsonl_path, self.train_msg, strict=False) return self.train_msg + def _get_resume_checkpoint(self, trainer): + if self.args.resume_from_checkpoint: + return self.args.resume_from_checkpoint + resume_checkpoint = None + # If flash checkpoint is enabled, try to resume from the last complete checkpoint. + # If the previous training finished, resume_checkpoint stays None. + if self.args.use_flash_ckpt: + # resume_checkpoint = /checkpoint- + resume_checkpoint = trainer.get_resume_checkpoint() + + # Elastic runs require a universal checkpoint; fall back when missing or incomplete. + callbacks = set(getattr(self.args, 'callbacks', [])) + elastic_enabled = 'deepspeed_elastic' in callbacks + if elastic_enabled and (resume_checkpoint is None + or not os.path.exists(os.path.join(resume_checkpoint, 'latest_universal'))): + # get_resume_checkpoint_until_find_ucp returns /checkpoint- with latest_universal, + # or None; when None, no universal checkpoint exists and training starts from scratch. + resume_checkpoint = trainer.get_resume_checkpoint_until_find_ucp() + return resume_checkpoint + def train(self, trainer): logging_path = os.path.join(trainer.args.output_dir, 'logging.jsonl') logger.info(f'The logging file will be saved in: {logging_path}') + resume_checkpoint = self._get_resume_checkpoint(trainer) try: - - resume_checkpoint = None - callbacks = set(getattr(self.args, 'callbacks', [])) - elastic_enabled = 'deepspeed_elastic' in callbacks - # If flash checkpoint is enabled, try to resume from the last complete checkpoint. - # If the previous training finished, resume_checkpoint stays None. - if self.args.use_flash_ckpt: - # resume_checkpoint = /checkpoint- - resume_checkpoint = trainer.get_resume_checkpoint() - # Elastic runs require a universal checkpoint; fall back when missing or incomplete. - if elastic_enabled and (resume_checkpoint is None - or not os.path.exists(os.path.join(resume_checkpoint, 'latest_universal'))): - # get_resume_checkpoint_until_find_ucp returns /checkpoint- with latest_universal, - # or None; when None, no universal checkpoint exists and training starts from scratch. - resume_checkpoint = trainer.get_resume_checkpoint_until_find_ucp() - # Explicit user override always wins. - if self.args.resume_from_checkpoint: - resume_checkpoint = self.args.resume_from_checkpoint trainer.train(resume_checkpoint) finally: res = self._save_trainer_state(trainer) diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 15ea282f40..b607d05679 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -54,7 +54,7 @@ from . 
import patcher from .arguments import TrainingArguments from .utils import (can_return_loss, dynamic_gradient_checkpointing, find_labels, get_function, get_resume_dir, - is_instance_of_ms_model, replace_index_file) + is_instance_of_ms_model, patch_modelscope_hub_timeout, replace_index_file) try: from trl import AutoModelForCausalLMWithValueHead @@ -85,7 +85,7 @@ def __init__(self, self.task_type = self.template.task_type self.problem_type = getattr(model.config, 'problem_type', None) if args.check_model and hasattr(model, 'model_dir'): - with ms_logger_context(logging.CRITICAL), self._patch_timeout(): + with ms_logger_context(logging.CRITICAL), patch_modelscope_hub_timeout(): config_info = self._collect_config_info() config_info.update({ 'invoked_by': 'local_trainer', @@ -150,24 +150,6 @@ def _add_callbacks(self): for callback in self.args.callbacks: self.add_callback(callbacks_map[callback](self.args, self)) - @contextmanager - def _patch_timeout(self): - from modelscope.hub.api import HubApi - __init__ = HubApi.__init__ - - def __new_init__(self, *args, **kwargs): - timeout = kwargs.get('timeout') - if timeout is not None and timeout > 5: - kwargs['timeout'] = 5 - __init__(self, *args, **kwargs) - - HubApi.__init__ = __new_init__ - - try: - yield - finally: - HubApi.__init__ = __init__ - def _collect_config_info(self) -> Dict[str, str]: """ Collects trainer-specific configuration details. diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py index de611290a4..bab68d338a 100644 --- a/swift/trainers/utils.py +++ b/swift/trainers/utils.py @@ -314,3 +314,21 @@ def replace_index_file(output_dir: str): from contextlib import suppress with suppress(FileNotFoundError): os.remove(os.path.join(output_dir, WEIGHTS_INDEX_NAME)) + + +def patch_modelscope_hub_timeout(): + from modelscope.hub.api import HubApi + __init__ = HubApi.__init__ + + def __new_init__(self, *args, **kwargs): + timeout = kwargs.get('timeout') + if timeout is not None and timeout > 5: + kwargs['timeout'] = 5 + __init__(self, *args, **kwargs) + + HubApi.__init__ = __new_init__ + + try: + yield + finally: + HubApi.__init__ = __init__ From 349f6e5302c6f0233b40148bd84ef3f750f0d606 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 3 Feb 2026 19:57:04 +0800 Subject: [PATCH 15/43] fix --- swift/megatron/arguments/megatron_args.py | 2 ++ swift/megatron/convert.py | 2 -- swift/megatron/pipelines/export/export.py | 2 -- swift/megatron/trainers/base.py | 2 +- swift/trainers/utils.py | 1 + 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index d0f693273a..5bf1388237 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -12,6 +12,7 @@ from transformers.utils.versions import require_version from swift.arguments import ModelArguments +from swift.megatron.utils import initialize_megatron from swift.model import get_model_info_meta from swift.utils import get_dist_setting, get_logger, json_parse_to_dict @@ -784,6 +785,7 @@ def __post_init__(self): self._init_mixed_precision() self._init_apply_rope_fusion() + initialize_megatron(self) def _init_apply_rope_fusion(self): if self.apply_rope_fusion is not None: diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 70b39d3767..ffe8099fcf 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -59,7 +59,6 @@ def convert_hf2mcore(args: ExportArguments) -> None: **current_convert_kwargs, 
save=args.output_dir, torch_dtype=args.torch_dtype) - initialize_megatron(megatron_args) mg_model = megatron_model_meta.model_provider(megatron_args) logger.info('Megatron model created successfully.') @@ -103,7 +102,6 @@ def convert_mcore2hf(args: ExportArguments) -> None: **current_convert_kwargs, save=args.output_dir if args.to_mcore else None, torch_dtype=args.torch_dtype) - initialize_megatron(megatron_args) mg_model = megatron_model_meta.model_provider(megatron_args) if megatron_args.load is None: diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index 4f703380d7..46febb9a9d 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -37,7 +37,6 @@ def convert_mcore2hf(self) -> None: hf_config = self.processor.model_info.config args.init_model_args(self.tokenizer, hf_config) megatron_model_meta = args.megatron_model_meta - initialize_megatron(args) pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() @@ -89,7 +88,6 @@ def convert_hf2mcore(self) -> None: self.processor = template.processor args.init_model_args(self.tokenizer, self.processor.model_info.config) megatron_model_meta = args.megatron_model_meta - initialize_megatron(args) pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 67a689d978..e50ffb5708 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1223,7 +1223,7 @@ def _init_multimodal_full(self): if args.trainable_parameters: logger.info(f'additional trainable_parameters: {args.trainable_parameters}') - def train(self, train_dataset, val_dataset, data_collator): + def train(self, train_dataset, val_dataset): args = self.args datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) datasets_provider.is_distributed = True diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py index bab68d338a..8ec107e6f8 100644 --- a/swift/trainers/utils.py +++ b/swift/trainers/utils.py @@ -316,6 +316,7 @@ def replace_index_file(output_dir: str): os.remove(os.path.join(output_dir, WEIGHTS_INDEX_NAME)) +@contextmanager def patch_modelscope_hub_timeout(): from modelscope.hub.api import HubApi __init__ = HubApi.__init__ From 81e64278cd27299307b0b545d39f9f647f685113 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 4 Feb 2026 10:54:53 +0800 Subject: [PATCH 16/43] update --- swift/megatron/arguments/megatron_args.py | 184 +--------------------- swift/megatron/arguments/model_args.py | 117 ++++++++++++++ swift/megatron/model/model_provider.py | 23 +-- swift/megatron/model/register.py | 4 - swift/megatron/utils/config.py | 2 +- swift/megatron/utils/megatron_lm_utils.py | 1 - 6 files changed, 129 insertions(+), 202 deletions(-) create mode 100644 swift/megatron/arguments/model_args.py diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 5bf1388237..0f67221c62 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -18,12 +18,12 @@ mcore_015 = version.parse(megatron.core.__version__) >= version.parse('0.15.0rc0') logger = get_logger() -MAX_NPU_EXPERTS_PER_EP = 128 @dataclass class RLHFMegatronArgumentsMixin: rlhf_type: Literal['dpo', 'kto', 'grpo', 'gkd', 'rm'] = None + loss_type: Optional[str] = None # rlhf / plugins ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None @@ 
-167,9 +167,6 @@ class RLHFMegatronArgumentsMixin: num_iterations: int = 1 - # dataset - dataset_shuffle: Optional[bool] = True - def _init_kto(self): if self.calculate_KL is None: # Not all losses require a KL calculation @@ -321,10 +318,7 @@ def __post_init__(self): @dataclass class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): - loss_type: Optional[str] = None # rlhf / plugins - check_model: bool = True - padded_vocab_size: Optional[int] = None initialize_embedding: bool = False rope_scaling: Optional[Union[dict, str]] = None torch_dtype: Optional[Union[torch.dtype, str]] = None @@ -352,8 +346,6 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): dataloader_prefetch_factor: int = 2 group_by_length: bool = False - hf_model_type: Optional[str] = None - llm_model_type: Optional[str] = None max_epochs: Optional[int] = None enable_dft_loss: bool = False enable_channel_loss: bool = False @@ -362,10 +354,6 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): problem_type: Literal['regression', 'single_label_classification', 'multi_label_classification'] = None save_strategy: Literal['steps', 'epoch'] = 'steps' - original_max_position_embeddings: Optional[int] = None - partial_rotary_factor: Optional[float] = None - use_shared_expert_gate: Optional[bool] = None - report_to: Optional[Literal['wandb', 'swanlab']] = None # visual @@ -373,17 +361,6 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): vit_lr: Optional[float] = None aligner_lr: Optional[float] = None gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None - # qwen3_next - linear_num_value_heads: Optional[int] = None - linear_num_key_heads: Optional[int] = None - linear_key_head_dim: Optional[int] = None - linear_value_head_dim: Optional[int] = None - linear_conv_kernel_dim: Optional[int] = None - layer_types: Optional[List[str]] = None - apply_wd_to_qk_layernorm: bool = False - apply_layernorm_1p: bool = False - # qwen3_vl, qwen3_omni - mrope_interleaved: Optional[bool] = None @staticmethod def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]: @@ -439,7 +416,6 @@ class MegatronArguments(ExtraMegatronArguments): main_params_dtype: Literal['fp32', 'fp16'] = 'fp32' exp_avg_dtype: Literal['fp32', 'fp16', 'bf16', 'fp8'] = 'fp32' exp_avg_sq_dtype: Literal['fp32', 'fp16', 'bf16', 'fp8'] = 'fp32' - dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic' manual_gc: bool = False manual_gc_interval: int = 0 @@ -463,17 +439,15 @@ class MegatronArguments(ExtraMegatronArguments): # checkpoint save: Optional[str] = None save_interval: int = 500 - save_retain_interval: Optional[int] = None + # save_retain_interval: Optional[int] = None no_save_optim: bool = False no_save_rng: bool = False load: Optional[str] = None no_load_optim: bool = False no_load_rng: bool = False finetune: bool = False - ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist' + # ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist' perform_initialization: bool = False - auto_detect_ckpt_format: bool = True - exit_on_missing_checkpoint: bool = True async_save: bool = False use_persistent_ckpt_worker: bool = False ckpt_fully_parallel_load: bool = False @@ -501,57 +475,6 @@ class MegatronArguments(ExtraMegatronArguments): microbatch_group_size_per_virtual_pipeline_stage: Optional[int] = None pipeline_model_parallel_layout: Optional[str] = None - # model - num_layers: Optional[int] = None - hidden_size: Optional[int] 
= None - ffn_hidden_size: Optional[int] = None - num_attention_heads: Optional[int] = None - group_query_attention: Optional[bool] = None - num_query_groups: Optional[int] = None - softmax_type: Optional[Literal['vanilla', 'off-by-one', 'learnable']] = None - window_size: Optional[str] = None - window_attn_skip_freq: Optional[str] = None - max_position_embeddings: Optional[int] = None - position_embedding_type: Optional[Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none']] = None - mrope_section: Optional[List[int]] = None - rotary_base: Optional[int] = None - rotary_percent: float = 1. - rotary_interleaved: Optional[bool] = None - normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm' - norm_epsilon: Optional[float] = None - swiglu: Optional[bool] = None - quick_geglu: Optional[bool] = None - activation_func_clamp_value: Optional[float] = None - glu_linear_offset: Optional[float] = None - untie_embeddings_and_output_weights: Optional[bool] = None - add_bias_linear: Optional[bool] = None - add_qkv_bias: Optional[bool] = None - attention_dropout: Optional[float] = None - hidden_dropout: float = 0. - kv_channels: Optional[int] = None - qk_layernorm: Optional[bool] = None - qk_l2_norm: Optional[bool] = None - no_rope_freq: Optional[int] = None - moe_apply_probs_on_input: Optional[bool] = None - transformer_impl: Literal['local', 'transformer_engine'] = 'transformer_engine' - - # moe - num_experts: Optional[int] = None - moe_layer_freq: Optional[str] = None - moe_ffn_hidden_size: Optional[int] = None - moe_shared_expert_intermediate_size: Optional[int] = None - - moe_router_topk: Optional[int] = None - moe_router_num_groups: Optional[int] = None - moe_router_group_topk: Optional[int] = None - moe_router_pre_softmax: Optional[bool] = None - moe_router_dtype: Literal['none', 'fp32', 'fp64'] = 'fp32' - moe_router_score_function: Literal['sigmoid', 'softmax'] = None - moe_router_bias_update_rate: Optional[float] = None - moe_router_enable_expert_bias: Optional[bool] = None - moe_router_topk_scaling_factor: Optional[float] = None - moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'global_aux_loss', 'sinkhorn', 'none'] = None - expert_model_parallel_size: int = 1 expert_tensor_parallel_size: int = 1 moe_token_dispatcher_type: Literal['allgather', 'alltoall', 'flex', 'alltoall_seq'] = 'alltoall' @@ -565,14 +488,7 @@ class MegatronArguments(ExtraMegatronArguments): moe_expert_capacity_factor: Optional[float] = None moe_pad_expert_input_to_capacity: bool = False moe_token_drop_policy: Literal['probs', 'position'] = 'probs' - - # mla - multi_latent_attention: Optional[bool] = None - q_lora_rank: Optional[int] = None - kv_lora_rank: Optional[int] = None - qk_head_dim: Optional[int] = None - qk_pos_emb_head_dim: Optional[int] = None - v_head_dim: Optional[int] = None + apply_wd_to_qk_layernorm: bool = False # mtp mtp_num_layers: Optional[int] = None @@ -611,7 +527,6 @@ class MegatronArguments(ExtraMegatronArguments): # other seed: int = 42 - seq_length: Optional[int] = None num_workers: int = 4 data_sharding: bool = False @@ -625,75 +540,17 @@ def _set_default(self): self.lr = 1e-5 else: self.lr = 1e-4 - if self.num_query_groups is None: - self.num_query_groups = 1 - if self.softmax_type is None and mcore_015: - self.softmax_type = 'vanilla' - if self.norm_epsilon is None: - self.norm_epsilon = 1e-5 - if self.rotary_base is None: - self.rotary_base = 10000 - else: - self.rotary_base = int(self.rotary_base) - if self.rotary_interleaved is None: - self.rotary_interleaved = 
False - if self.attention_dropout is None: - self.attention_dropout = 0. - if self.untie_embeddings_and_output_weights is None: - self.untie_embeddings_and_output_weights = True - if self.swiglu is None: - self.swiglu = True - if self.quick_geglu is None: - self.quick_geglu = False - if self.glu_linear_offset is None and mcore_015: - self.glu_linear_offset = 0. - if self.add_qkv_bias is None: - self.add_qkv_bias = True - if self.add_bias_linear is None: - self.add_bias_linear = False - if self.qk_layernorm is None: - self.qk_layernorm = False - if self.multi_latent_attention is None: - self.multi_latent_attention = False - if self.kv_lora_rank is None: - self.kv_lora_rank = 32 - if self.qk_head_dim is None: - self.qk_head_dim = 128 - if self.qk_pos_emb_head_dim is None: - self.qk_pos_emb_head_dim = 64 - if self.v_head_dim is None: - self.v_head_dim = 128 if self.task_type is None: self.task_type = 'causal_lm' if self.calculate_per_token_loss is None: self.calculate_per_token_loss = self.task_type == 'causal_lm' if self.bias_dropout_fusion is None: self.bias_dropout_fusion = True - # moe - MegatronArguments._set_moe_default(self) + # log if self.wandb_exp_name is None: self.wandb_exp_name = self.save - @staticmethod - def _set_moe_default(self): - if self.use_shared_expert_gate is None: - self.use_shared_expert_gate = False - if self.moe_router_score_function is None: - self.moe_router_score_function = 'softmax' - if self.moe_router_topk is None: - self.moe_router_topk = 2 - if self.moe_router_pre_softmax is None: - self.moe_router_pre_softmax = False - if self.moe_router_load_balancing_type is None: - self.moe_router_load_balancing_type = 'aux_loss' - if self.moe_router_enable_expert_bias is None: - self.moe_router_enable_expert_bias = False - if self.moe_layer_freq is None: - self.moe_layer_freq = 1 - if self.mrope_interleaved is None: - self.mrope_interleaved = False - def _init_mixed_precision(self): ModelArguments._init_mixed_precision(self) if self.apply_query_key_layer_scaling is None: @@ -701,37 +558,16 @@ def _init_mixed_precision(self): if self.apply_query_key_layer_scaling: os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1' - def _init_moe(self): - if self.moe_router_dtype.lower() == 'none': - self.moe_router_dtype = None - if self.moe_shared_expert_intermediate_size == 0: - self.moe_shared_expert_intermediate_size = None - if self.num_experts is not None: - if self.moe_ffn_hidden_size is None: - self.moe_ffn_hidden_size = self.ffn_hidden_size - if is_torch_npu_available() and self.num_experts > MAX_NPU_EXPERTS_PER_EP: - required_ep = (self.num_experts + MAX_NPU_EXPERTS_PER_EP - 1) // MAX_NPU_EXPERTS_PER_EP - if self.expert_model_parallel_size < required_ep: - logger.warning(f'{">"*20} WARNING {"<"*20}\n' - f'MindSpeed on NPU supports up to {MAX_NPU_EXPERTS_PER_EP} experts per EP group. ' - f'num_experts={self.num_experts}, ' - f'expert_model_parallel_size={self.expert_model_parallel_size}. 
' - f'Please set expert_model_parallel_size (EP) to {required_ep} ' - f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.') - def __post_init__(self): require_version('numpy<2.0', 'Please install numpy<2.0 by running: `pip install "numpy<2.0"`.') if self.tuner_type == 'lora': - if self.num_experts is not None: - require_version('peft>=0.15') - else: - require_version('peft>=0.12') + require_version('peft>=0.15') RLHFMegatronArgumentsMixin.__post_init__(self) MegatronTunerMixin.__post_init__(self) os.environ.setdefault('CUDA_DEVICE_MAX_CONNECTIONS', '1') if self.recompute_granularity == 'none': self.recompute_granularity = None - if self.apply_wd_to_qk_layernorm and self.hf_model_type != 'qwen3_next': + if self.apply_wd_to_qk_layernorm and self.model_type != 'qwen3_next': raise ValueError('apply_wd_to_qk_layernorm is only supported for qwen3_next') self._set_default() self._init_vpp_size() @@ -746,7 +582,6 @@ def __post_init__(self): 'decoder_first_pipeline_num_layers or decoder_last_pipeline_num_layers.') if hasattr(self, 'ddp_timeout'): self.distributed_timeout_minutes = self.ddp_timeout // 60 - self.group_query_attention = self.num_query_groups > 1 self.fp8 = self.fp8_format # compat megatron-lm if self.rope_scaling is not None: self.rope_scaling = json_parse_to_dict(self.rope_scaling) @@ -769,10 +604,6 @@ def __post_init__(self): self.ref_adapters = [self.ref_adapters] if self.eval_interval is None: self.eval_interval = self.save_interval - if self.seq_length is None: - self.seq_length = self.max_position_embeddings - if self.position_embedding_type is None: - self.position_embedding_type = 'rope' if self.merge_lora is None: self.merge_lora = self.save_safetensors if self.adapters or self.adapter_load or self.ref_adapter_load: @@ -781,7 +612,6 @@ def __post_init__(self): logger.info('Setting args.tuner_type: lora') if self.adapters: self._load_adapter_config() - self._init_moe() self._init_mixed_precision() self._init_apply_rope_fusion() diff --git a/swift/megatron/arguments/model_args.py b/swift/megatron/arguments/model_args.py new file mode 100644 index 0000000000..19cc9c7564 --- /dev/null +++ b/swift/megatron/arguments/model_args.py @@ -0,0 +1,117 @@ +from typing import Optional, Literal, List +from dataclasses import dataclass + + +from swift.utils import get_logger +from transformers.utils import is_torch_npu_available + + +MAX_NPU_EXPERTS_PER_EP = 128 + + +logger = get_logger() + + +@dataclass +class MegatronModelArguments: + hf_model_type: Optional[str] = None + llm_model_type: Optional[str] = None + padded_vocab_size: Optional[int] = None + # model + num_layers: Optional[int] = None + hidden_size: Optional[int] = None + ffn_hidden_size: Optional[int] = None + num_attention_heads: Optional[int] = None + group_query_attention: bool = False + num_query_groups: Optional[int] = None + softmax_type: Literal['vanilla', 'off-by-one', 'learnable'] = 'vanilla' + window_size: Optional[str] = None + window_attn_skip_freq: Optional[str] = None + max_position_embeddings: Optional[int] = None + + position_embedding_type: Optional[Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none']] = None + rotary_base: int = 10000 + rotary_percent: float = 1. 
+ rotary_interleaved: bool = False + original_max_position_embeddings: Optional[int] = None + partial_rotary_factor: Optional[float] = None + mrope_section: Optional[List[int]] = None + # qwen3_vl, qwen3_omni + mrope_interleaved: bool = False + + normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm' + layernorm_epsilon: float = 1e-5 + swiglu: bool = True + quick_geglu: bool = False + activation_func_clamp_value: Optional[float] = None + glu_linear_offset: float = 0. + untie_embeddings_and_output_weights: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = True + attention_dropout: float = 0. + hidden_dropout: float = 0. + kv_channels: Optional[int] = None + qk_layernorm: bool = False + qk_l2_norm: Optional[bool] = None + no_rope_freq: Optional[int] = None + moe_apply_probs_on_input: Optional[bool] = None + + # moe + num_experts: Optional[int] = None + moe_layer_freq: str = 1 + moe_ffn_hidden_size: Optional[int] = None + moe_shared_expert_intermediate_size: Optional[int] = None + + moe_router_topk: int = 2 + moe_router_num_groups: Optional[int] = None + moe_router_group_topk: Optional[int] = None + moe_router_pre_softmax: bool = False + moe_router_dtype: Literal['none', 'fp32', 'fp64'] = 'fp32' + moe_router_score_function: Literal['sigmoid', 'softmax'] = 'softmax' + moe_router_bias_update_rate: Optional[float] = None + moe_router_enable_expert_bias: bool = False + moe_router_topk_scaling_factor: Optional[float] = None + moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'global_aux_loss', 'sinkhorn', 'none'] = 'aux_loss' + use_shared_expert_gate: bool = False + + # mla + multi_latent_attention: bool = False + q_lora_rank: Optional[int] = None + kv_lora_rank: int = 32 + qk_head_dim: int = 128 + qk_pos_emb_head_dim: int= 64 + v_head_dim: int = 128 + + # qwen3_next + linear_num_value_heads: Optional[int] = None + linear_num_key_heads: Optional[int] = None + linear_key_head_dim: Optional[int] = None + linear_value_head_dim: Optional[int] = None + linear_conv_kernel_dim: Optional[int] = None + layer_types: Optional[List[str]] = None + # apply_layernorm_1p: bool = False # TODO + + + def __post_init__(self): + if self.num_query_groups is not None and self.num_query_groups > 1: + self.group_query_attention = True + self._init_moe() + + def _init_moe(self): + if self.moe_router_dtype.lower() == 'none': + self.moe_router_dtype = None + if self.moe_shared_expert_intermediate_size == 0: + self.moe_shared_expert_intermediate_size = None + if self.num_experts is not None: + if self.moe_ffn_hidden_size is None: + self.moe_ffn_hidden_size = self.ffn_hidden_size + # TODO: remove + if is_torch_npu_available() and self.num_experts > MAX_NPU_EXPERTS_PER_EP: + required_ep = (self.num_experts + MAX_NPU_EXPERTS_PER_EP - 1) // MAX_NPU_EXPERTS_PER_EP + if self.expert_model_parallel_size < required_ep: + logger.warning(f'{">"*20} WARNING {"<"*20}\n' + f'MindSpeed on NPU supports up to {MAX_NPU_EXPERTS_PER_EP} experts per EP group. ' + f'num_experts={self.num_experts}, ' + f'expert_model_parallel_size={self.expert_model_parallel_size}. 
' + f'Please set expert_model_parallel_size (EP) to {required_ep} ' + f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.') diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index b80fdfcd3b..ce5882711d 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -7,11 +7,9 @@ from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec) -from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import get_gpt_heterogeneous_layer_spec -from megatron.core.transformer.spec_utils import import_module from packaging import version -from swift.megatron.utils import core_transformer_config_from_args +from swift.megatron.utils import core_transformer_config_from_args, convert_hf_config from swift.utils import get_logger logger = get_logger() @@ -33,25 +31,12 @@ def _get_transformer_layer_spec(args): ) -# Code borrowed from NVIDIA/Megatron-LM -def model_provider(args, pre_process=True, post_process=True, vp_stage: Optional[int] = None) -> 'GPTModel': - """Builds the model. - - If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. - - Args: - pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. - post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. - - - Returns: - Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model - """ +def get_mcore_model(args, model_args, pre_process=True, post_process=True, vp_stage: Optional[int] = None) -> 'GPTModel': from .register import get_megatron_model_meta - megatron_model_meta = get_megatron_model_meta(args.hf_model_type) + megatron_model_meta = get_megatron_model_meta(args.model_type) logger.info('building GPT model ...') - config = core_transformer_config_from_args(args) + config = core_transformer_config_from_args(args, model_args) config.variable_seq_lengths = True if megatron_model_meta.get_transformer_layer_spec is not None: transformer_layer_spec = megatron_model_meta.get_transformer_layer_spec(config, vp_stage=vp_stage) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index e606cbed5b..b2b081d9c8 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,16 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
-from argparse import ArgumentParser from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, List, Optional, Type import torch.nn as nn - from swift.model import MODEL_MAPPING from .constant import MLLMMegatronModelType from .gpt_bridge import GPTBridge from .gpt_model import GPTModel from .mm_gpt_model import MultimodalGPTModel -from .model_provider import model_provider as model_provider_func if TYPE_CHECKING: from swift.megatron.arguments import MegatronArguments @@ -27,7 +24,6 @@ class MegatronModelMeta: bridge_cls: Type[GPTBridge] = GPTBridge model_cls: Optional[Type[nn.Module]] = None get_transformer_layer_spec: Optional[Callable] = None - model_provider: Callable[['MegatronArguments'], nn.Module] = model_provider_func visual_cls: Optional[Type[nn.Module]] = None get_mtp_block_spec: Optional[Callable] = None diff --git a/swift/megatron/utils/config.py b/swift/megatron/utils/config.py index 4cb902e478..fa0ce5f3d6 100644 --- a/swift/megatron/utils/config.py +++ b/swift/megatron/utils/config.py @@ -12,7 +12,7 @@ 'num_attention_heads': ['num_attention_heads'], 'num_query_groups': ['num_key_value_heads'], 'max_position_embeddings': ['max_position_embeddings'], - 'norm_epsilon': ['rms_norm_eps'], + 'layernorm_epsilon': ['rms_norm_eps'], 'rotary_base': ['rope_theta'], 'padded_vocab_size': ['vocab_size'], 'attention_dropout': ['attention_dropout'], diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index bc018cfd23..6705155c2e 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -126,7 +126,6 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['persist_layer_norm'] = True # TODO: apply_layernorm_1p kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p - kw_args['layernorm_epsilon'] = args.norm_epsilon kw_args['deallocate_pipeline_outputs'] = True kw_args['pipeline_dtype'] = args.torch_dtype kw_args['batch_p2p_comm'] = True From 7233e9aa5d17bef7322bd888530f7522b3e905f6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 4 Feb 2026 15:13:23 +0800 Subject: [PATCH 17/43] update --- swift/megatron/model/gpt_bridge.py | 2 +- swift/megatron/model/gpt_model.py | 4 +- swift/megatron/model/gpts/glm4.py | 6 +- swift/megatron/model/gpts/minimax_m2.py | 5 +- swift/megatron/model/gpts/olmoe.py | 6 +- swift/megatron/model/gpts/qwen3_next.py | 8 +- swift/megatron/model/mm_gpt_model.py | 8 +- .../model_args.py => model/model_config.py} | 72 +++++---- swift/megatron/model/model_provider.py | 88 ----------- swift/megatron/model/register.py | 143 ++++++++++++++++-- swift/megatron/utils/__init__.py | 3 +- swift/megatron/utils/megatron_lm_utils.py | 50 ------ 12 files changed, 194 insertions(+), 201 deletions(-) rename swift/megatron/{arguments/model_args.py => model/model_config.py} (63%) delete mode 100644 swift/megatron/model/model_provider.py diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 8f1ab86b78..144056ae0a 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -518,7 +518,7 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int hf_state_dict = {} hf_attn = self.hf_layers[layer_idx].self_attn args = self.args - num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) + num_query_groups = (args.num_query_groups if args.num_query_groups is not None else args.num_attention_heads) 
hidden_size_block = args.hidden_size // self.fp8_block_size if to_mcore: if isinstance(mg_attn.linear_qkv, LoraParallelLinear): diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 9000279f50..be01e1ea73 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -21,11 +21,11 @@ reduce_from_tensor_model_parallel_region) from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params from packaging import version from swift.utils import get_logger +from .model_config import MegatronModelConfig from .rope import dynamic_rope_update, get_rope_inv_freq logger = get_logger() @@ -56,7 +56,7 @@ class GPTModel(McoreGPTModel): def __init__( self, - config: TransformerConfig, + config: MegatronModelConfig, transformer_layer_spec: ModuleSpec, vocab_size: int, max_sequence_length: int, diff --git a/swift/megatron/model/gpts/glm4.py b/swift/megatron/model/gpts/glm4.py index 81eef6b3de..c7ac043d2f 100644 --- a/swift/megatron/model/gpts/glm4.py +++ b/swift/megatron/model/gpts/glm4.py @@ -9,13 +9,13 @@ from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.mlp import MLP, apply_swiglu_sharded_factory from megatron.core.transformer.spec_utils import build_module -from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import sharded_state_dict_default from packaging import version from swift.model import ModelType from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge +from ..model_config import MegatronModelConfig from ..register import MegatronModelMeta, register_megatron_model mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -25,7 +25,7 @@ class Glm4SelfAttention(SelfAttention): def __init__( self, - config: TransformerConfig, + config: MegatronModelConfig, *args, **kwargs, ): @@ -48,7 +48,7 @@ class Glm4MLP(MLP): def __init__( self, - config: TransformerConfig, + config: MegatronModelConfig, *args, **kwargs, ): diff --git a/swift/megatron/model/gpts/minimax_m2.py b/swift/megatron/model/gpts/minimax_m2.py index 1ec21df584..2bec049f5e 100644 --- a/swift/megatron/model/gpts/minimax_m2.py +++ b/swift/megatron/model/gpts/minimax_m2.py @@ -7,12 +7,12 @@ from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.spec_utils import build_module -from megatron.core.transformer.transformer_config import TransformerConfig from packaging import version from swift.model import ModelType from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge +from ..model_config import MegatronModelConfig from ..register import MegatronModelMeta, register_megatron_model mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -22,7 +22,7 @@ class MinimaxM2SelfAttention(SelfAttention): def __init__( self, - config: TransformerConfig, + config: MegatronModelConfig, submodules: SelfAttentionSubmodules, *args, **kwargs, @@ -48,7 +48,6 @@ def __init__( ) def get_query_key_value_tensors(self, *_args, **kwargs): - args = get_args() query, key, value = super().get_query_key_value_tensors(*_args, **kwargs) 
query = query.reshape(*query.shape[:-2], -1) key = key.reshape(*key.shape[:-2], -1) diff --git a/swift/megatron/model/gpts/olmoe.py b/swift/megatron/model/gpts/olmoe.py index 720544d976..612c9453ef 100644 --- a/swift/megatron/model/gpts/olmoe.py +++ b/swift/megatron/model/gpts/olmoe.py @@ -10,7 +10,6 @@ from megatron.core.transformer.attention import SelfAttentionSubmodules from megatron.core.transformer.spec_utils import build_module from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build -from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import get_transformer_layer_offset from packaging import version @@ -18,6 +17,7 @@ from swift.model import ModelType from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge +from ..model_config import MegatronModelConfig from ..register import MegatronModelMeta, register_megatron_model mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -25,7 +25,7 @@ class OLMoESelfAttention(SelfAttentionBase): - def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodules, *args, **kwargs): + def __init__(self, config: MegatronModelConfig, submodules: SelfAttentionSubmodules, *args, **kwargs): super().__init__(config, submodules, *args, **kwargs) self.q_layernorm = build_module( submodules.q_layernorm, @@ -74,7 +74,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None, *arg def get_olmoe_decoder_block_spec( - config: TransformerConfig, + config: MegatronModelConfig, vp_stage: Optional[int] = None, ) -> TransformerBlockSubmodules: """GPT block spec.""" diff --git a/swift/megatron/model/gpts/qwen3_next.py b/swift/megatron/model/gpts/qwen3_next.py index 5bc665b49f..bea36de69e 100644 --- a/swift/megatron/model/gpts/qwen3_next.py +++ b/swift/megatron/model/gpts/qwen3_next.py @@ -15,7 +15,6 @@ from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.spec_utils import build_module from megatron.core.transformer.transformer_block import TransformerBlockSubmodules -from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import deprecate_inference_params, is_fa_min_version from packaging import version @@ -24,6 +23,7 @@ from swift.utils import get_logger from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge +from ..model_config import MegatronModelConfig from ..register import MegatronModelMeta, register_megatron_model mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -67,7 +67,7 @@ class Qwen3NextRMSNorm(torch.nn.Module): Interface matches TENorm for compatibility with Megatron-Core build_module. 
""" - def __init__(self, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): + def __init__(self, config: MegatronModelConfig, hidden_size: int, eps: float = 1e-5): super().__init__() self.config = config self.eps = eps @@ -87,7 +87,7 @@ def forward(self, hidden_states): class Qwen3NextSelfAttention(SelfAttention): - def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodules, *args, **kwargs): + def __init__(self, config: MegatronModelConfig, submodules: SelfAttentionSubmodules, *args, **kwargs): super(SelfAttention, self).__init__(config, submodules, *args, attention_type='self', **kwargs) kwargs = {} if mcore_015: @@ -429,7 +429,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): class Qwen3NextGatedDeltaNet(_HuggingFaceModule, _Qwen3NextGatedDeltaNet): - def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs): + def __init__(self, config: MegatronModelConfig, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs): assert config.context_parallel_size == 1, 'Qwen3Next currently does not support context parallel.' assert _Qwen3NextGatedDeltaNet is not object, 'please update the `transformers` version.' _Qwen3NextGatedDeltaNet.__init__(self, config, layer_number) diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index ec0de8033b..1db8de8734 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -3,15 +3,15 @@ import megatron.core import torch -from megatron.core import InferenceParams +from megatron.core import InferenceParams, mpu from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel import VocabParallelEmbedding, reduce_scatter_to_sequence_parallel_region from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig from packaging import version from .gpt_model import GPTModel +from .model_config import MegatronModelConfig mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -19,7 +19,7 @@ class MultimodalGPTModel(MegatronModule): def __init__(self, - config: TransformerConfig, + config: MegatronModelConfig, transformer_layer_spec: ModuleSpec, vocab_size: int, max_sequence_length: int, @@ -35,7 +35,6 @@ def __init__(self, post_process, *args, **kwargs) self.vp_stage = self.language_model.vp_stage self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights - args = get_args() self.megatron_model_meta = get_megatron_model_meta(args.hf_model_type) self.visual = None if args.mtp_num_layers: @@ -49,7 +48,6 @@ def _patch_word_embeddings(self, kwargs): def forward(_self, input_): from ..trainers.utils import split_cp_inputs - args = get_args() reduce_scatter_embeddings = _self.reduce_scatter_embeddings _self.reduce_scatter_embeddings = False input_ = torch.masked_fill(input_, input_ < 0, 0) diff --git a/swift/megatron/arguments/model_args.py b/swift/megatron/model/model_config.py similarity index 63% rename from swift/megatron/arguments/model_args.py rename to swift/megatron/model/model_config.py index 19cc9c7564..5e0ff6fc59 100644 --- a/swift/megatron/arguments/model_args.py +++ b/swift/megatron/model/model_config.py @@ -1,19 +1,18 @@ -from typing import Optional, Literal, List -from dataclasses import dataclass +from dataclasses import 
dataclass, fields +from typing import List, Literal, Optional +import torch.nn.functional as F +from megatron.core.fusions.fused_bias_geglu import quick_gelu +from megatron.core.transformer import TransformerConfig +from swift.megatron.utils import convert_hf_config from swift.utils import get_logger -from transformers.utils import is_torch_npu_available - - -MAX_NPU_EXPERTS_PER_EP = 128 - logger = get_logger() @dataclass -class MegatronModelArguments: +class MegatronModelConfig(TransformerConfig): hf_model_type: Optional[str] = None llm_model_type: Optional[str] = None padded_vocab_size: Optional[int] = None @@ -22,7 +21,6 @@ class MegatronModelArguments: hidden_size: Optional[int] = None ffn_hidden_size: Optional[int] = None num_attention_heads: Optional[int] = None - group_query_attention: bool = False num_query_groups: Optional[int] = None softmax_type: Literal['vanilla', 'off-by-one', 'learnable'] = 'vanilla' window_size: Optional[str] = None @@ -71,7 +69,8 @@ class MegatronModelArguments: moe_router_bias_update_rate: Optional[float] = None moe_router_enable_expert_bias: bool = False moe_router_topk_scaling_factor: Optional[float] = None - moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'global_aux_loss', 'sinkhorn', 'none'] = 'aux_loss' + moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'global_aux_loss', 'sinkhorn', + 'none'] = 'aux_loss' use_shared_expert_gate: bool = False # mla @@ -79,7 +78,7 @@ class MegatronModelArguments: q_lora_rank: Optional[int] = None kv_lora_rank: int = 32 qk_head_dim: int = 128 - qk_pos_emb_head_dim: int= 64 + qk_pos_emb_head_dim: int = 64 v_head_dim: int = 128 # qwen3_next @@ -89,15 +88,10 @@ class MegatronModelArguments: linear_value_head_dim: Optional[int] = None linear_conv_kernel_dim: Optional[int] = None layer_types: Optional[List[str]] = None - # apply_layernorm_1p: bool = False # TODO + # apply_layernorm_1p: bool = False # TODO def __post_init__(self): - if self.num_query_groups is not None and self.num_query_groups > 1: - self.group_query_attention = True - self._init_moe() - - def _init_moe(self): if self.moe_router_dtype.lower() == 'none': self.moe_router_dtype = None if self.moe_shared_expert_intermediate_size == 0: @@ -105,13 +99,37 @@ def _init_moe(self): if self.num_experts is not None: if self.moe_ffn_hidden_size is None: self.moe_ffn_hidden_size = self.ffn_hidden_size - # TODO: remove - if is_torch_npu_available() and self.num_experts > MAX_NPU_EXPERTS_PER_EP: - required_ep = (self.num_experts + MAX_NPU_EXPERTS_PER_EP - 1) // MAX_NPU_EXPERTS_PER_EP - if self.expert_model_parallel_size < required_ep: - logger.warning(f'{">"*20} WARNING {"<"*20}\n' - f'MindSpeed on NPU supports up to {MAX_NPU_EXPERTS_PER_EP} experts per EP group. ' - f'num_experts={self.num_experts}, ' - f'expert_model_parallel_size={self.expert_model_parallel_size}. 
' - f'Please set expert_model_parallel_size (EP) to {required_ep} ' - f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.') + super().__post_init__() + self.variable_seq_lengths = True + + +def create_mcore_model_config(args, hf_config): + # Translate args to core transformer configuration + kw_args = convert_hf_config(hf_config) + kw_args['persist_layer_norm'] = True + # TODO: apply_layernorm_1p + kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p + kw_args['deallocate_pipeline_outputs'] = True + kw_args['pipeline_dtype'] = args.torch_dtype + kw_args['batch_p2p_comm'] = True + kw_args['num_moe_experts'] = args.num_experts + kw_args['rotary_interleaved'] = args.rotary_interleaved + kw_args['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers + kw_args['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers + kw_args['fp8_param'] = args.fp8_param_gather + if args.swiglu: + kw_args['activation_func'] = F.silu + kw_args['gated_linear_unit'] = True + kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion + else: + kw_args['bias_activation_fusion'] = args.bias_gelu_fusion + if args.quick_geglu: + assert not args.swiglu + kw_args['gated_linear_unit'] = True + kw_args['activation_func'] = quick_gelu + kw_args['cp_comm_type'] = 'p2p' + kw_args['inference_sampling_seed'] = args.seed + kw_args['variable_seq_lengths'] = True + config = MegatronModelConfig(**kw_args) + config.args = config + return config diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py deleted file mode 100644 index ce5882711d..0000000000 --- a/swift/megatron/model/model_provider.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import math -from typing import TYPE_CHECKING, Optional, Union - -import megatron.core -import torch -from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, - get_gpt_mtp_block_spec) -from packaging import version - -from swift.megatron.utils import core_transformer_config_from_args, convert_hf_config -from swift.utils import get_logger - -logger = get_logger() - -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') - -if TYPE_CHECKING: - from .gpt_model import GPTModel - - -def _get_transformer_layer_spec(args): - kwargs = {'qk_l2_norm': args.qk_l2_norm} if mcore_013 else {} - return get_gpt_layer_with_transformer_engine_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - **kwargs, - ) - - -def get_mcore_model(args, model_args, pre_process=True, post_process=True, vp_stage: Optional[int] = None) -> 'GPTModel': - from .register import get_megatron_model_meta - megatron_model_meta = get_megatron_model_meta(args.model_type) - - logger.info('building GPT model ...') - config = core_transformer_config_from_args(args, model_args) - config.variable_seq_lengths = True - if megatron_model_meta.get_transformer_layer_spec is not None: - transformer_layer_spec = megatron_model_meta.get_transformer_layer_spec(config, vp_stage=vp_stage) - else: - if args.num_experts: - kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} if mcore_013 else {} - # Define the decoder block spec - transformer_layer_spec = get_gpt_decoder_block_spec( - config, use_transformer_engine=True, normalization=args.normalization, **kwargs) - else: - # Define the decoder layer spec - 
transformer_layer_spec = _get_transformer_layer_spec(args) - mtp_block_spec = None - if args.mtp_num_layers is not None: - if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0: - # Get the decoder layer spec explicitly if no decoder layer in the last stage, - # Only happens with block spec (TransformerBlockSubmodules) when using MoE. - transformer_layer_spec_for_mtp = _get_transformer_layer_spec(args) - else: - transformer_layer_spec_for_mtp = transformer_layer_spec - kwargs = {'vp_stage': vp_stage} if mcore_013 else {} - if megatron_model_meta.get_mtp_block_spec is not None: - get_mtp_block_spec = megatron_model_meta.get_mtp_block_spec - else: - get_mtp_block_spec = get_gpt_mtp_block_spec - mtp_block_spec = get_mtp_block_spec( - config, transformer_layer_spec_for_mtp, use_transformer_engine=True, **kwargs) - - if args.use_shared_expert_gate and args.num_experts and args.moe_shared_expert_intermediate_size: - for layer_spec in transformer_layer_spec.layer_specs: - if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): - layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} - model = megatron_model_meta.model_cls( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=math.ceil(args.padded_vocab_size / args.tensor_model_parallel_size) - * args.tensor_model_parallel_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_base=args.rotary_base, - hf_rope_scaling=args.rope_scaling, - mtp_block_spec=mtp_block_spec, - vp_stage=vp_stage, - ) - - return model diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index b2b081d9c8..6fe251e51e 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,18 +1,28 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, List, Optional, Type +from typing import TYPE_CHECKING, Callable, List, Optional, Type, Union -import torch.nn as nn +import megatron.core +from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, + get_gpt_layer_with_transformer_engine_spec, + get_gpt_mtp_block_spec) +from packaging import version +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + +from swift.megatron.utils import convert_hf_config from swift.model import MODEL_MAPPING +from swift.utils import get_logger from .constant import MLLMMegatronModelType -from .gpt_bridge import GPTBridge -from .gpt_model import GPTModel -from .mm_gpt_model import MultimodalGPTModel +from .model_config import create_mcore_model_config if TYPE_CHECKING: - from swift.megatron.arguments import MegatronArguments + from .gpt_model import GPTModel + from .mm_gpt_model import MultimodalGPTModel MEGATRON_MODEL_MAPPING = {} +logger = get_logger() +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @dataclass @@ -20,18 +30,20 @@ class MegatronModelMeta: megatron_model_type: str model_types: List[str] + loader: Optional[Type['MegatronModelLoader']] = None is_multimodal: bool = False - bridge_cls: Type[GPTBridge] = GPTBridge - model_cls: Optional[Type[nn.Module]] = None - get_transformer_layer_spec: Optional[Callable] = None - visual_cls: Optional[Type[nn.Module]] = None - get_mtp_block_spec: Optional[Callable] = None + + # bridge_cls: Type[GPTBridge] = GPTBridge + # model_cls: Optional[Type[nn.Module]] = None + # get_transformer_layer_spec: Optional[Callable] = None + # visual_cls: Optional[Type[nn.Module]] = None + # get_mtp_block_spec: Optional[Callable] = None def __post_init__(self): if self.megatron_model_type in MLLMMegatronModelType.__dict__: self.is_multimodal = True - if self.model_cls is None: - self.model_cls = MultimodalGPTModel if self.is_multimodal else GPTModel + if self.loader is None: + self.loader = MegatronModelLoader def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): @@ -57,3 +69,108 @@ def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]: if model_type not in _MODEL_META_MAPPING: return return MEGATRON_MODEL_MAPPING[_MODEL_META_MAPPING[model_type]] + + +class MegatronModelLoader: + + def __init__(self, args, hf_config): + self.args = args + self.hf_config = hf_config + self.config = create_mcore_model_config(args, hf_config) + self._check_npu() + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + args = self.args + if self.config.num_experts: + kwargs = {'qk_l2_norm': self.config.qk_l2_norm, 'vp_stage': vp_stage} if mcore_013 else {} + # Define the decoder block spec + transformer_layer_spec = get_gpt_decoder_block_spec( + self.config, use_transformer_engine=True, normalization=self.config.normalization, **kwargs) + else: + transformer_layer_spec = self._get_transformer_layer_spec() + + if args.use_shared_expert_gate and args.num_experts and args.moe_shared_expert_intermediate_size: + for layer_spec in transformer_layer_spec.layer_specs: + if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): + layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} + return transformer_layer_spec + + def _get_transformer_layer_spec(self): + config = self.config + kwargs = {'qk_l2_norm': config.qk_l2_norm} if mcore_013 else {} + transformer_layer_spec = 
get_gpt_layer_with_transformer_engine_spec( + config.num_experts, + self.args.moe_grouped_gemm, + config.qk_layernorm, + config.multi_latent_attention, + **kwargs, + ) + return transformer_layer_spec + + def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = None): + if self.args.mtp_num_layers is None: + return + if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0: + # Get the decoder layer spec explicitly if no decoder layer in the last stage, + # Only happens with block spec (TransformerBlockSubmodules) when using MoE. + transformer_layer_spec_for_mtp = self._get_transformer_layer_spec() + else: + transformer_layer_spec_for_mtp = transformer_layer_spec + kwargs = {'vp_stage': vp_stage} if mcore_013 else {} + + return get_gpt_mtp_block_spec( + self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, **kwargs) + + def create_model_and_load( + self, + pre_process=True, + post_process=True, + vp_stage: Optional[int] = None, + ) -> Union[GPTModel, MultimodalGPTModel]: + transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) + mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) + return self._create_model( + transformer_layer_spec, + mtp_block_spec, + pre_process=pre_process, + post_process=post_process, + vp_stage=vp_stage) + + def _check_npu(self): + MAX_NPU_EXPERTS_PER_EP = 128 + num_experts = self.config.num_experts + if is_torch_npu_available() and num_experts > MAX_NPU_EXPERTS_PER_EP: + required_ep = (num_experts + MAX_NPU_EXPERTS_PER_EP - 1) // MAX_NPU_EXPERTS_PER_EP + if self.args.expert_model_parallel_size < required_ep: + logger.warning(f'{">" * 20} WARNING {"<" * 20}\n' + f'MindSpeed on NPU supports up to {MAX_NPU_EXPERTS_PER_EP} experts per EP group. ' + f'num_experts={num_experts}, ' + f'expert_model_parallel_size={self.args.expert_model_parallel_size}. 
' + f'Please set expert_model_parallel_size (EP) to {required_ep} ' + f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.') + + def _create_model(self, + transformer_layer_spec, + mtp_block_spec, + pre_process=True, + post_process=True, + vp_stage: Optional[int] = None): + if self.args.is_multimodal: + model_cls = MultimodalGPTModel + else: + model_cls = GPTModel + return model_cls( + config=self.config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=math.ceil(args.padded_vocab_size / args.tensor_model_parallel_size) + * args.tensor_model_parallel_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_base=args.rotary_base, + hf_rope_scaling=args.rope_scaling, + mtp_block_spec=mtp_block_spec, + vp_stage=vp_stage, + ) diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 759362581c..9c02183206 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -2,8 +2,7 @@ from .config import convert_hf_config from .convert_utils import test_convert_precision -from .megatron_lm_utils import (core_transformer_config_from_args, initialize_megatron, load_mcore_checkpoint, - save_mcore_checkpoint) +from .megatron_lm_utils import initialize_megatron, load_mcore_checkpoint, save_mcore_checkpoint from .patcher import patch_merge_fn, patch_torch_dist_shard from .utils import (MegatronTrainerState, adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 6705155c2e..42797b8945 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -9,17 +9,14 @@ import numpy as np import torch -import torch.nn.functional as F from megatron.core import dist_checkpointing, mpu, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedObject from megatron.core.dist_checkpointing.serialization import (get_default_load_sharded_strategy, get_default_save_sharded_strategy) from megatron.core.dist_checkpointing.strategies.fully_parallel import (FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper) -from megatron.core.fusions.fused_bias_geglu import quick_gelu from megatron.core.msc_utils import open_file from megatron.core.num_microbatches_calculator import update_num_microbatches -from megatron.core.transformer import MLATransformerConfig, TransformerConfig from megatron.core.utils import unwrap_model from swift.utils import check_json_format, get_logger, init_process_group, is_master, seed_everything, set_device @@ -111,53 +108,6 @@ def initialize_megatron(args): # TODO: tp_comm_overlap, _compile_dependencies -def core_transformer_config_from_args(args, config_class=None): - # Config class. 
- config_class = config_class or TransformerConfig - - if args.multi_latent_attention: - config_class = MLATransformerConfig - - # Translate args to core transformer configuration - kw_args = {} - for f in dataclasses.fields(config_class): - if hasattr(args, f.name): - kw_args[f.name] = getattr(args, f.name) - kw_args['persist_layer_norm'] = True - # TODO: apply_layernorm_1p - kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p - kw_args['deallocate_pipeline_outputs'] = True - kw_args['pipeline_dtype'] = args.torch_dtype - kw_args['batch_p2p_comm'] = True - kw_args['num_moe_experts'] = args.num_experts - kw_args['rotary_interleaved'] = args.rotary_interleaved - kw_args['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers - kw_args['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers - kw_args['fp8_param'] = args.fp8_param_gather - if args.swiglu: - kw_args['activation_func'] = F.silu - kw_args['gated_linear_unit'] = True - kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion - else: - kw_args['bias_activation_fusion'] = args.bias_gelu_fusion - if args.quick_geglu: - assert not args.swiglu - kw_args['gated_linear_unit'] = True - kw_args['activation_func'] = quick_gelu - if args.group_query_attention: - kw_args['num_query_groups'] = args.num_query_groups - else: - kw_args['num_query_groups'] = None - - kw_args['cp_comm_type'] = 'p2p' - kw_args['inference_sampling_seed'] = args.seed - - config = config_class(**kw_args) - config.args = args - - return config - - def _get_rng_state(): """Collect rng state across data parallel ranks.""" rng_state = { From 7b8e28c2fc71cadf73832f18998fbd8a80b43752 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 4 Feb 2026 15:49:45 +0800 Subject: [PATCH 18/43] update --- swift/arguments/rlhf_args.py | 6 +++--- swift/megatron/arguments/megatron_args.py | 22 ++++++--------------- swift/megatron/model/gpt_model.py | 24 ++++++++++------------- swift/megatron/model/model_config.py | 16 +++++++++------ swift/megatron/model/register.py | 12 ++---------- 5 files changed, 31 insertions(+), 49 deletions(-) diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index 8bbe67469b..60055ab665 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -319,9 +319,9 @@ def _init_grpo(self): logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}') if self.truncation_strategy is None: self.truncation_strategy = 'left' - assert self.truncation_strategy in ['left', 'delete' - ], ("GRPO requires `truncation_strategy 'left' or 'delete'`, " - f"Current value: `truncation_strategy='{self.truncation_strategy}'`.") + if self.truncation_strategy not in {'left', 'delete'}: + raise ValueError("GRPO requires `truncation_strategy 'left' or 'delete'`, " + f"Current value: `truncation_strategy='{self.truncation_strategy}'`.") if self.beta is None: self.beta = 0.04 # https://arxiv.org/abs/2402.03300 if self.async_generate: diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 0f67221c62..31c3662767 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -259,10 +259,9 @@ def _check_batch_params(): logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}') if self.truncation_strategy is None: self.truncation_strategy = 'left' - assert self.truncation_strategy in ['left', 'delete' - ], ("GRPO requires `truncation_strategy 'left' or 
'delete'`, " - f"Current value: `truncation_strategy='{self.truncation_strategy}'`." - ) # noqa + if self.truncation_strategy not in {'left', 'delete'}: + raise ValueError("GRPO requires `truncation_strategy 'left' or 'delete'`, " + f"Current value: `truncation_strategy='{self.truncation_strategy}'`.") if self.beta is None: self.beta = 0.04 # https://arxiv.org/abs/2402.03300 if self.async_generate: @@ -320,7 +319,6 @@ def __post_init__(self): class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): check_model: bool = True initialize_embedding: bool = False - rope_scaling: Optional[Union[dict, str]] = None torch_dtype: Optional[Union[torch.dtype, str]] = None padding_free: bool = True mlp_padding_free: bool = False @@ -349,9 +347,6 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): max_epochs: Optional[int] = None enable_dft_loss: bool = False enable_channel_loss: bool = False - task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = None - num_labels: Optional[int] = None - problem_type: Literal['regression', 'single_label_classification', 'multi_label_classification'] = None save_strategy: Literal['steps', 'epoch'] = 'steps' report_to: Optional[Literal['wandb', 'swanlab']] = None @@ -372,7 +367,8 @@ def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]: with open(args_path, 'r', encoding='utf-8') as f: old_args = json.load(f) keys = list(f.name for f in fields(MegatronTunerMixin)) - keys += ['load', 'padded_vocab_size', 'task_type', 'num_labels'] + # TODO: remove load/save + keys += ['load', 'padded_vocab_size', 'task_type', 'num_labels'] # TODO: padded_vocab_size for key in keys: old_value = old_args.get(key) if old_value is not None: @@ -398,7 +394,7 @@ class MegatronArguments(ExtraMegatronArguments): log_interval: int = 5 tensorboard_dir: Optional[str] = None masked_softmax_fusion: bool = True - bias_dropout_fusion: Optional[bool] = None + bias_dropout_fusion: bool = True # TODO: gpt-oss bias_swiglu_fusion: bool = True bias_gelu_fusion: bool = True apply_rope_fusion: Optional[bool] = None @@ -544,8 +540,6 @@ def _set_default(self): self.task_type = 'causal_lm' if self.calculate_per_token_loss is None: self.calculate_per_token_loss = self.task_type == 'causal_lm' - if self.bias_dropout_fusion is None: - self.bias_dropout_fusion = True # log if self.wandb_exp_name is None: @@ -583,10 +577,6 @@ def __post_init__(self): if hasattr(self, 'ddp_timeout'): self.distributed_timeout_minutes = self.ddp_timeout // 60 self.fp8 = self.fp8_format # compat megatron-lm - if self.rope_scaling is not None: - self.rope_scaling = json_parse_to_dict(self.rope_scaling) - if 'type' in self.rope_scaling and 'rope_type' not in self.rope_scaling: - self.rope_scaling['rope_type'] = self.rope_scaling['type'] if self.task_type not in {'causal_lm', 'generative_reranker'}: self.untie_embeddings_and_output_weights = True if self.gradient_checkpointing_kwargs is not None: diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index be01e1ea73..ab48fa958e 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -3,7 +3,7 @@ import os from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Literal, Optional, Tuple +from typing import Optional, Tuple import megatron.core import torch @@ -58,18 +58,14 @@ def __init__( self, config: MegatronModelConfig, transformer_layer_spec: ModuleSpec, - vocab_size: int, - max_sequence_length: int, 
pre_process: bool = True, post_process: bool = True, - share_embeddings_and_output_weights: bool = False, - position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'none'] = 'learned_absolute', - rotary_base: int = 10000, - hf_rope_scaling: Dict[str, Any] = None, mtp_block_spec: Optional[ModuleSpec] = None, vp_stage: Optional[int] = None, ): - vocab_size = math.ceil(vocab_size / config.tensor_model_parallel_size) * config.tensor_model_parallel_size + vocab_size = math.ceil( + config.padded_vocab_size / config.tensor_model_parallel_size) * config.tensor_model_parallel_size + hf_rope_scaling = config.rope_scaling if config.multi_latent_attention and config.rope_type == 'yarn': config.rope_type = 'rope' # use transformers implementation if hf_rope_scaling and hf_rope_scaling['rope_type'] == 'yarn': @@ -88,12 +84,12 @@ def __init__( config, transformer_layer_spec, vocab_size, - max_sequence_length, + config.max_position_embeddings, pre_process=pre_process, post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - position_embedding_type=position_embedding_type, - rotary_base=rotary_base, + share_embeddings_and_output_weights=not config.untie_embeddings_and_output_weights, + position_embedding_type=config.position_embedding_type, + rotary_base=config.rotary_base, mtp_block_spec=mtp_block_spec, **kwargs, ) @@ -102,7 +98,7 @@ def __init__( kv_channels=config.qk_pos_emb_head_dim, rotary_percent=1, rotary_interleaved=config.rotary_interleaved, - rotary_base=rotary_base, + rotary_base=config.rotary_base, use_cpu_initialization=config.use_cpu_initialization, ) # save memory @@ -128,7 +124,7 @@ def __init__( elif args.task_type == 'embedding' and self.post_process: self.output_layer = None - if (self.attention_scaling != 1 or position_embedding_type == 'mrope') and config.apply_rope_fusion: + if (self.attention_scaling != 1 or config.position_embedding_type == 'mrope') and config.apply_rope_fusion: config.apply_rope_fusion = False if self.attention_scaling != 1: warning_string = 'attention_scaling' diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index 5e0ff6fc59..c657fad4f1 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -1,12 +1,12 @@ -from dataclasses import dataclass, fields -from typing import List, Literal, Optional +from dataclasses import dataclass +from typing import List, Literal, Optional, Union import torch.nn.functional as F from megatron.core.fusions.fused_bias_geglu import quick_gelu from megatron.core.transformer import TransformerConfig from swift.megatron.utils import convert_hf_config -from swift.utils import get_logger +from swift.utils import get_logger, json_parse_to_dict logger = get_logger() @@ -16,6 +16,8 @@ class MegatronModelConfig(TransformerConfig): hf_model_type: Optional[str] = None llm_model_type: Optional[str] = None padded_vocab_size: Optional[int] = None + rope_scaling: Optional[Union[dict, str]] = None + # model num_layers: Optional[int] = None hidden_size: Optional[int] = None @@ -27,7 +29,7 @@ class MegatronModelConfig(TransformerConfig): window_attn_skip_freq: Optional[str] = None max_position_embeddings: Optional[int] = None - position_embedding_type: Optional[Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none']] = None + position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'none'] = 'rope', rotary_base: int = 10000 rotary_percent: float = 1. 
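(Editorial aside on the hunk above.) GPTModel.__init__ now derives the padded vocabulary from config.padded_vocab_size instead of receiving a pre-padded vocab_size from register.py; the padding simply rounds the vocabulary up to a multiple of the tensor-parallel degree. A minimal, self-contained sketch of that arithmetic — the function name is illustrative, not part of the patch:

import math

def pad_vocab_size(padded_vocab_size: int, tensor_model_parallel_size: int) -> int:
    # Round up so every TP rank owns an equally sized slice of the embedding
    # and output-projection matrices.
    return math.ceil(padded_vocab_size / tensor_model_parallel_size) * tensor_model_parallel_size

# 151936 is already a multiple of 8, so it is unchanged; 151937 is padded up.
assert pad_vocab_size(151936, 8) == 151936
assert pad_vocab_size(151937, 8) == 151944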
rotary_interleaved: bool = False @@ -94,11 +96,13 @@ class MegatronModelConfig(TransformerConfig): def __post_init__(self): if self.moe_router_dtype.lower() == 'none': self.moe_router_dtype = None - if self.moe_shared_expert_intermediate_size == 0: - self.moe_shared_expert_intermediate_size = None if self.num_experts is not None: if self.moe_ffn_hidden_size is None: self.moe_ffn_hidden_size = self.ffn_hidden_size + if self.rope_scaling is not None: + self.rope_scaling = json_parse_to_dict(self.rope_scaling) + if 'type' in self.rope_scaling and 'rope_type' not in self.rope_scaling: + self.rope_scaling['rope_type'] = self.rope_scaling['type'] super().__post_init__() self.variable_seq_lengths = True diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 6fe251e51e..cbd716dcc0 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,16 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, List, Optional, Type, Union +from typing import TYPE_CHECKING, List, Optional, Type, Union import megatron.core from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec) from packaging import version -from transformers.utils import is_torch_cuda_available, is_torch_npu_available +from transformers.utils import is_torch_npu_available -from swift.megatron.utils import convert_hf_config from swift.model import MODEL_MAPPING from swift.utils import get_logger from .constant import MLLMMegatronModelType @@ -162,15 +161,8 @@ def _create_model(self, return model_cls( config=self.config, transformer_layer_spec=transformer_layer_spec, - vocab_size=math.ceil(args.padded_vocab_size / args.tensor_model_parallel_size) - * args.tensor_model_parallel_size, - max_sequence_length=args.max_position_embeddings, pre_process=pre_process, post_process=post_process, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_base=args.rotary_base, - hf_rope_scaling=args.rope_scaling, mtp_block_spec=mtp_block_spec, vp_stage=vp_stage, ) From 29f144089b8f2b5f869aaea768a97fab1bd08198 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 4 Feb 2026 17:24:52 +0800 Subject: [PATCH 19/43] update --- swift/megatron/model/gpts/glm4.py | 46 ++++---- swift/megatron/model/gpts/minimax_m2.py | 50 ++++----- swift/megatron/model/gpts/olmoe.py | 21 ++-- swift/megatron/model/gpts/qwen3_emb.py | 21 ++-- swift/megatron/model/gpts/qwen3_next.py | 135 +++++++++++------------ swift/megatron/model/mm_gpts/glm.py | 21 ++-- swift/megatron/model/mm_gpts/internvl.py | 26 +++-- swift/megatron/model/mm_gpts/kimi_vl.py | 16 ++- swift/megatron/model/mm_gpts/llama4.py | 53 +++++---- swift/megatron/model/mm_gpts/qwen.py | 65 +++++++---- swift/megatron/model/mm_gpts/qwen3_vl.py | 63 ++++++----- swift/megatron/model/mm_gpts/utils.py | 2 +- swift/megatron/model/register.py | 52 ++++----- 13 files changed, 309 insertions(+), 262 deletions(-) diff --git a/swift/megatron/model/gpts/glm4.py b/swift/megatron/model/gpts/glm4.py index c7ac043d2f..676d071ba7 100644 --- a/swift/megatron/model/gpts/glm4.py +++ b/swift/megatron/model/gpts/glm4.py @@ -4,7 +4,6 @@ import megatron.core from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.extensions.transformer_engine import TENorm -from 
megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.transformer import transformer_layer from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.mlp import MLP, apply_swiglu_sharded_factory @@ -16,7 +15,7 @@ from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge from ..model_config import MegatronModelConfig -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -94,28 +93,21 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo return hf_state_dict -def get_glm4_transformer_layer_spec(config, vp_stage=None): - kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {} - layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=config.num_moe_experts, - moe_grouped_gemm=config.moe_grouped_gemm, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=config.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - **kwargs, - ) - layer_spec.submodules.self_attention.module = Glm4SelfAttention - layer_spec.submodules.mlp.module = Glm4MLP - transformer_layer.MLP = Glm4MLP # patch - return layer_spec - - -register_megatron_model( - MegatronModelMeta( - MegatronModelType.glm4, - [ - ModelType.glm4, - ], - get_transformer_layer_spec=get_glm4_transformer_layer_spec, - bridge_cls=Glm4Bridge, - )) +class Glm4Loader(MegatronModelLoader): + bridge_cls = Glm4Bridge + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + layer_spec = self._get_transformer_layer_spec() + layer_spec.submodules.self_attention.module = Glm4SelfAttention + layer_spec.submodules.mlp.module = Glm4MLP + transformer_layer.MLP = Glm4MLP # patch + return layer_spec + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.glm4, + [ + ModelType.glm4, + ], + Glm4Loader, +)) diff --git a/swift/megatron/model/gpts/minimax_m2.py b/swift/megatron/model/gpts/minimax_m2.py index 2bec049f5e..fc0b3f99c8 100644 --- a/swift/megatron/model/gpts/minimax_m2.py +++ b/swift/megatron/model/gpts/minimax_m2.py @@ -1,7 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
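(Editorial aside on the glm4 hunk above.) The recurring conversion in this patch turns each model's free get_*_transformer_layer_spec helper into a get_transformer_layer_spec override on a MegatronModelLoader subclass, and registration now passes the loader class instead of separate callables. A minimal sketch of the plug-in shape, assuming the import paths mirror the file layout under swift/megatron/model; MyBridge, MyLoader and 'my_model' are illustrative names, not part of the patch:

from typing import Optional

from swift.megatron.model.gpt_bridge import GPTBridge
from swift.megatron.model.register import (MegatronModelLoader, MegatronModelMeta,
                                            register_megatron_model)


class MyBridge(GPTBridge):
    pass


class MyLoader(MegatronModelLoader):
    bridge_cls = MyBridge

    def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
        # Start from the stock Transformer-Engine layer spec built by the base class,
        # then swap submodules (attention, MLP, layernorms) as the model requires.
        return self._get_transformer_layer_spec()


register_megatron_model(MegatronModelMeta(
    'my_model',   # megatron_model_type key
    [],           # HF model types served by this loader
    MyLoader,
))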
+from typing import Optional + import megatron.core -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core import mpu from megatron.core.tensor_parallel.mappings import (gather_from_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region) from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules @@ -13,7 +15,7 @@ from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge from ..model_config import MegatronModelConfig -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -48,15 +50,16 @@ def __init__( ) def get_query_key_value_tensors(self, *_args, **kwargs): + enable_tp = mpu.get_expert_data_parallel_world_size() > 1 query, key, value = super().get_query_key_value_tensors(*_args, **kwargs) query = query.reshape(*query.shape[:-2], -1) key = key.reshape(*key.shape[:-2], -1) - if args.tensor_model_parallel_size > 1: + if enable_tp: query = gather_from_tensor_model_parallel_region(query) key = gather_from_tensor_model_parallel_region(key) query = self.q_norm(query) key = self.k_norm(key) - if args.tensor_model_parallel_size > 1: + if enable_tp: query = scatter_to_tensor_model_parallel_region(query) key = scatter_to_tensor_model_parallel_region(key) query = query.view(*query.shape[:2], -1, self.hidden_size_per_attention_head) @@ -98,26 +101,19 @@ def _set_moe_state( return hf_state_dict -def get_minimax_m2_transformer_layer_spec(config, vp_stage=None): - kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {} - layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=config.num_moe_experts, - moe_grouped_gemm=config.moe_grouped_gemm, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=config.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - **kwargs, - ) - layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention - return layer_spec - - -register_megatron_model( - MegatronModelMeta( - MegatronModelType.minimax_m2, - [ - ModelType.minimax_m2, - ], - get_transformer_layer_spec=get_minimax_m2_transformer_layer_spec, - bridge_cls=MinimaxM2Bridge, - )) +class MinimaxM2Loader(MegatronModelLoader): + bridge_cls = MinimaxM2Bridge + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + layer_spec = self._get_transformer_layer_spec() + layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention + return layer_spec + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.minimax_m2, + [ + ModelType.minimax_m2, + ], + MinimaxM2Loader, +)) diff --git a/swift/megatron/model/gpts/olmoe.py b/swift/megatron/model/gpts/olmoe.py index 612c9453ef..d60812b6b0 100644 --- a/swift/megatron/model/gpts/olmoe.py +++ b/swift/megatron/model/gpts/olmoe.py @@ -18,7 +18,7 @@ from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge from ..model_config import MegatronModelConfig -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -217,10 +217,15 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int return hf_state_dict -register_megatron_model( - MegatronModelMeta( - 
MegatronModelType.olmoe, - [ModelType.olmoe], - get_transformer_layer_spec=get_olmoe_decoder_block_spec, - bridge_cls=OLMoEBridge, - )) +class OlMoELoader(MegatronModelLoader): + bridge_cls = OLMoEBridge + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + return get_olmoe_decoder_block_spec(self.config, vp_stage) + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.olmoe, + [ModelType.olmoe], + OlMoELoader, +)) diff --git a/swift/megatron/model/gpts/qwen3_emb.py b/swift/megatron/model/gpts/qwen3_emb.py index 3c83c6cf80..074155bace 100644 --- a/swift/megatron/model/gpts/qwen3_emb.py +++ b/swift/megatron/model/gpts/qwen3_emb.py @@ -2,7 +2,7 @@ from swift.model import ModelType from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model class Qwen3EmbBridge(GPTBridge): @@ -16,11 +16,14 @@ def _convert_hf_state_dict(self, hf_state_dict, to_mcore): return res -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen3_emb, - [ - ModelType.qwen3_emb, - ], - bridge_cls=Qwen3EmbBridge, - )) +class Qwen3EmbLoader(MegatronModelLoader): + bridge_cls = Qwen3EmbBridge + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.qwen3_emb, + [ + ModelType.qwen3_emb, + ], + Qwen3EmbLoader, +)) diff --git a/swift/megatron/model/gpts/qwen3_next.py b/swift/megatron/model/gpts/qwen3_next.py index bea36de69e..835af76e03 100644 --- a/swift/megatron/model/gpts/qwen3_next.py +++ b/swift/megatron/model/gpts/qwen3_next.py @@ -24,7 +24,7 @@ from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge from ..model_config import MegatronModelConfig -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') mcore_015 = version.parse(megatron.core.__version__) >= version.parse('0.15.0rc0') @@ -474,65 +474,6 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return res, None -def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): - config.hetereogenous_dist_checkpoint = True - # compat Qwen3NextGatedDeltaNet - args = get_args() - config.hidden_act = 'silu' - config.rms_norm_eps = config.layernorm_epsilon - config.dtype = args.torch_dtype - config.linear_num_value_heads = args.linear_num_value_heads - config.linear_num_key_heads = args.linear_num_key_heads - config.linear_key_head_dim = args.linear_key_head_dim - config.linear_value_head_dim = args.linear_value_head_dim - config.linear_conv_kernel_dim = args.linear_conv_kernel_dim - - # Use Zero-Centered RMSNorm to match HuggingFace exactly (no +1/-1 conversion needed) - layer_norm_impl = Qwen3NextRMSNorm - kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {} - moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=config.num_moe_experts, - moe_grouped_gemm=config.moe_grouped_gemm, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=config.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - **kwargs, - ) - layer_specs = [] - for layer_type in args.layer_types: - layer_spec = deepcopy(moe_layer_spec) - if layer_type == 'linear_attention': - layer_spec.submodules.self_attention.module = Qwen3NextGatedDeltaNet - elif layer_type == 'full_attention': - 
layer_spec.submodules.self_attention.submodules.linear_qkv = TEColumnParallelLinear - layer_spec.submodules.self_attention.module = Qwen3NextSelfAttention - # Replace ALL layernorms with Qwen3NextRMSNorm (Zero-Centered) - layer_spec.submodules.input_layernorm = layer_norm_impl - if hasattr(layer_spec.submodules, 'pre_mlp_layernorm'): - layer_spec.submodules.pre_mlp_layernorm = layer_norm_impl - # Replace qk_layernorm if present - if hasattr(layer_spec.submodules.self_attention.submodules, 'q_layernorm'): - layer_spec.submodules.self_attention.submodules.q_layernorm = layer_norm_impl - if hasattr(layer_spec.submodules.self_attention.submodules, 'k_layernorm'): - layer_spec.submodules.self_attention.submodules.k_layernorm = layer_norm_impl - layer_specs.append(layer_spec) - - local_layer_specs = get_local_layer_specs(config, layer_specs, vp_stage=vp_stage) - block_spec = TransformerBlockSubmodules(layer_specs=local_layer_specs, layer_norm=layer_norm_impl) - - return block_spec - - -def get_qwen3_next_mtp_block_spec(*args, **kwargs): - mtp_block_spec = get_gpt_mtp_block_spec(*args, **kwargs) - if mtp_block_spec is not None: - for layer_spec in mtp_block_spec.layer_specs: - layer_spec.submodules.enorm = Qwen3NextRMSNorm - layer_spec.submodules.hnorm = Qwen3NextRMSNorm - layer_spec.submodules.layer_norm = Qwen3NextRMSNorm - return mtp_block_spec - - class Qwen3NextBridge(GPTBridge): hf_mtp_prefix = 'mtp.layers' @@ -559,13 +500,67 @@ def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state origin_hf_state_dict.update(self._add_prefix(hf_state_dict, 'mtp.')) -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen3_next, - [ - ModelType.qwen3_next, - ], - get_transformer_layer_spec=get_qwen3_next_transformer_layer_spec, - get_mtp_block_spec=get_qwen3_next_mtp_block_spec, - bridge_cls=Qwen3NextBridge, - )) +class Qwen3NextLoader(MegatronModelLoader): + bridge_cls = Qwen3NextBridge + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + config = self.config + args = self.args + config.hetereogenous_dist_checkpoint = True + # compat Qwen3NextGatedDeltaNet + config.hidden_act = 'silu' + config.rms_norm_eps = config.layernorm_epsilon + config.dtype = args.torch_dtype + + # Use Zero-Centered RMSNorm to match HuggingFace exactly (no +1/-1 conversion needed) + layer_norm_impl = Qwen3NextRMSNorm + kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {} + moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + **kwargs, + ) + layer_specs = [] + for layer_type in args.layer_types: + layer_spec = deepcopy(moe_layer_spec) + if layer_type == 'linear_attention': + layer_spec.submodules.self_attention.module = Qwen3NextGatedDeltaNet + elif layer_type == 'full_attention': + layer_spec.submodules.self_attention.submodules.linear_qkv = TEColumnParallelLinear + layer_spec.submodules.self_attention.module = Qwen3NextSelfAttention + # Replace ALL layernorms with Qwen3NextRMSNorm (Zero-Centered) + layer_spec.submodules.input_layernorm = layer_norm_impl + if hasattr(layer_spec.submodules, 'pre_mlp_layernorm'): + layer_spec.submodules.pre_mlp_layernorm = layer_norm_impl + # Replace qk_layernorm if present + if hasattr(layer_spec.submodules.self_attention.submodules, 'q_layernorm'): + 
layer_spec.submodules.self_attention.submodules.q_layernorm = layer_norm_impl + if hasattr(layer_spec.submodules.self_attention.submodules, 'k_layernorm'): + layer_spec.submodules.self_attention.submodules.k_layernorm = layer_norm_impl + layer_specs.append(layer_spec) + + local_layer_specs = get_local_layer_specs(config, layer_specs, vp_stage=vp_stage) + block_spec = TransformerBlockSubmodules(layer_specs=local_layer_specs, layer_norm=layer_norm_impl) + + return block_spec + + def get_mtp_block_spec(self, *args, **kwargs): + # TODO: layernorm_zero_centered_gamma + mtp_block_spec = super().get_mtp_block_spec(*args, **kwargs) + for layer_spec in mtp_block_spec.layer_specs: + layer_spec.submodules.enorm = Qwen3NextRMSNorm + layer_spec.submodules.hnorm = Qwen3NextRMSNorm + layer_spec.submodules.layer_norm = Qwen3NextRMSNorm + return mtp_block_spec + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.qwen3_next, + [ + ModelType.qwen3_next, + ], + Qwen3NextLoader, +)) diff --git a/swift/megatron/model/mm_gpts/glm.py b/swift/megatron/model/mm_gpts/glm.py index 19898ac71d..65a95ea2bc 100644 --- a/swift/megatron/model/mm_gpts/glm.py +++ b/swift/megatron/model/mm_gpts/glm.py @@ -3,7 +3,7 @@ from swift.template import Template from ..constant import MegatronModelType from ..gpt_bridge import MultimodalGPTBridge -from ..gpts.glm4 import Glm4Bridge, get_glm4_transformer_layer_spec +from ..gpts.glm4 import Glm4Bridge, Glm4Loader from ..register import MegatronModelMeta, register_megatron_model from .utils import HuggingFaceModule @@ -32,11 +32,14 @@ class Glm4vBridge(Glm4Bridge, MultimodalGPTBridge): pass -register_megatron_model( - MegatronModelMeta( - MegatronModelType.glm4v, [ - ModelType.glm4v, - ], - get_transformer_layer_spec=get_glm4_transformer_layer_spec, - bridge_cls=Glm4vBridge, - visual_cls=Glm4vVit)) +class Glm4vLoader(Glm4Loader): + bridge_cls = Glm4vBridge + visual_cls = Glm4vVit + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.glm4v, + [ + ModelType.glm4v, + ], +)) diff --git a/swift/megatron/model/mm_gpts/internvl.py b/swift/megatron/model/mm_gpts/internvl.py index 6b6a2696cb..25442e190c 100644 --- a/swift/megatron/model/mm_gpts/internvl.py +++ b/swift/megatron/model/mm_gpts/internvl.py @@ -4,7 +4,7 @@ from swift.model import ModelType from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge, MultimodalGPTBridge -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model from .utils import HuggingFaceModule @@ -59,15 +59,21 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds +class InternvlLoader(MegatronModelLoader): + bridge_cls = Internvl3Bridge + visual_cls = Internvl3Vit + + register_megatron_model( MegatronModelMeta( - MegatronModelType.internvl3, [ + MegatronModelType.internvl3, + [ ModelType.internvl3, ModelType.internvl3_5, ModelType.internvl3_5_gpt, ], - bridge_cls=Internvl3Bridge, - visual_cls=Internvl3Vit)) + InternvlLoader, + )) class InternvlHfBridge(MultimodalGPTBridge): @@ -127,11 +133,17 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds +class InternvlHfLoader(MegatronModelLoader): + bridge_cls = InternvlHfBridge + visual_cls = InternvlHfVit + + register_megatron_model( MegatronModelMeta( - MegatronModelType.internvl_hf, [ + MegatronModelType.internvl_hf, + [ ModelType.internvl_hf, ModelType.internvl_gpt_hf, ], - bridge_cls=InternvlHfBridge, - visual_cls=InternvlHfVit)) + 
InternvlHfLoader, + )) diff --git a/swift/megatron/model/mm_gpts/kimi_vl.py b/swift/megatron/model/mm_gpts/kimi_vl.py index 0eb9ddb1c3..888eb0888a 100644 --- a/swift/megatron/model/mm_gpts/kimi_vl.py +++ b/swift/megatron/model/mm_gpts/kimi_vl.py @@ -6,7 +6,7 @@ from swift.model import ModelType from ..constant import MegatronModelType from ..gpt_bridge import MultimodalGPTBridge -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model from .utils import HuggingFaceModule @@ -47,7 +47,15 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -register_megatron_model( - MegatronModelMeta(MegatronModelType.kimi_vl, [ +class KimiLoader(MegatronModelLoader): + bridge_cls = KimiVLBridge + visual_cls = KimiVLVit + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.kimi_vl, + [ ModelType.kimi_vl, - ], bridge_cls=KimiVLBridge, visual_cls=KimiVLVit)) + ], + KimiLoader, +)) diff --git a/swift/megatron/model/mm_gpts/llama4.py b/swift/megatron/model/mm_gpts/llama4.py index 2950d9b800..8ff23a65a7 100644 --- a/swift/megatron/model/mm_gpts/llama4.py +++ b/swift/megatron/model/mm_gpts/llama4.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from copy import deepcopy +from typing import Optional import megatron.core import torch @@ -11,7 +12,7 @@ from swift.model import ModelType from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model from .utils import HuggingFaceModule mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -49,24 +50,6 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -def get_llama4_transformer_layer_spec(config, vp_stage=None): - args = get_args() - use_te = args.transformer_impl == 'transformer_engine' - kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} if mcore_013 else {} - # Define the decoder block spec - transformer_layer_spec = get_gpt_decoder_block_spec( - config, use_transformer_engine=use_te, normalization=args.normalization, **kwargs) - for i, layer_spec in enumerate(transformer_layer_spec.layer_specs): - global_i = i + get_transformer_layer_offset(config, vp_stage) - no_rope = config.no_rope_freq[global_i] - layer_spec = deepcopy(layer_spec) - if no_rope: - layer_spec.submodules.self_attention.submodules.q_layernorm = IdentityOp - layer_spec.submodules.self_attention.submodules.k_layernorm = IdentityOp - transformer_layer_spec.layer_specs[i] = layer_spec - return transformer_layer_spec - - class Llama4Bridge(GPTBridge): hf_layers_prefix = 'language_model.model.layers' hf_embed_key = 'language_model.model.embed_tokens.weight' @@ -75,11 +58,27 @@ class Llama4Bridge(GPTBridge): hf_score_key = 'language_model.score.weight' -register_megatron_model( - MegatronModelMeta( - MegatronModelType.llama4, [ - ModelType.llama4, - ], - bridge_cls=Llama4Bridge, - get_transformer_layer_spec=get_llama4_transformer_layer_spec, - visual_cls=Llama4Vit)) +class Llama4Loader(MegatronModelLoader): + bridge_cls = Llama4Bridge + visual_cls = Llama4Vit + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + layer_specs = super().get_transformer_layer_spec(vp_stage) + for i, layer_spec in enumerate(layer_specs.layer_specs): + global_i = i + 
get_transformer_layer_offset(self.config, vp_stage) + no_rope = self.config.no_rope_freq[global_i] + layer_spec = deepcopy(layer_spec) + if no_rope: + layer_spec.submodules.self_attention.submodules.q_layernorm = IdentityOp + layer_spec.submodules.self_attention.submodules.k_layernorm = IdentityOp + layer_specs.layer_specs[i] = layer_spec + return layer_specs + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.llama4, + [ + ModelType.llama4, + ], + Llama4Loader, +)) diff --git a/swift/megatron/model/mm_gpts/qwen.py b/swift/megatron/model/mm_gpts/qwen.py index c59e28cfbe..dd519e8c9e 100644 --- a/swift/megatron/model/mm_gpts/qwen.py +++ b/swift/megatron/model/mm_gpts/qwen.py @@ -7,7 +7,7 @@ from swift.utils import get_env_args from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge, MultimodalGPTBridge -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model from .utils import HuggingFaceModule @@ -46,22 +46,36 @@ class Qwen2_5VLBridge(MultimodalGPTBridge): } -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen2_5_vl, [ - ModelType.qwen2_5_vl, - ], bridge_cls=Qwen2_5VLBridge, visual_cls=Qwen2_5VL_Vit)) +class Qwen2_5VLLoader(MegatronModelLoader): + bridge_cls = Qwen2_5VLBridge + visual_cls = Qwen2_5VL_Vit + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.qwen2_5_vl, + [ + ModelType.qwen2_5_vl, + ], + Qwen2_5VLLoader, +)) class Qwen2VL_Vit(Qwen2_5VL_Vit): version = 'v2' -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen2_vl, [ - ModelType.qwen2_vl, - ], bridge_cls=Qwen2_5VLBridge, visual_cls=Qwen2VL_Vit)) +class Qwen2VLLoader(Qwen2_5VLLoader): + bridge_cls = Qwen2_5VLBridge + visual_cls = Qwen2VL_Vit + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.qwen2_vl, + [ + ModelType.qwen2_vl, + ], + Qwen2VLLoader, +)) class Qwen2_5OmniBridge(GPTBridge): @@ -108,13 +122,19 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds +class Qwen2_5OmniLoader(MegatronModelLoader): + bridge_cls = Qwen2_5OmniBridge, + visual_cls = Qwen2_5Omni_Vit + + register_megatron_model( MegatronModelMeta( - MegatronModelType.qwen2_5_omni, [ + MegatronModelType.qwen2_5_omni, + [ ModelType.qwen2_5_omni, ], - bridge_cls=Qwen2_5OmniBridge, - visual_cls=Qwen2_5Omni_Vit)) + Qwen2_5OmniLoader, + )) class Ovis2_5Bridge(GPTBridge): @@ -164,8 +184,15 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -register_megatron_model( - MegatronModelMeta( - MegatronModelType.ovis2_5, [ - ModelType.ovis2_5, - ], bridge_cls=Ovis2_5Bridge, visual_cls=Ovis2_5Vit)) +class Ovis2_5Loader(MegatronModelLoader): + bridge_cls = Ovis2_5Bridge + visual_cls = Ovis2_5Vit + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.ovis2_5, + [ + ModelType.ovis2_5, + ], + Ovis2_5Loader, +)) diff --git a/swift/megatron/model/mm_gpts/qwen3_vl.py b/swift/megatron/model/mm_gpts/qwen3_vl.py index 0607bdd11e..1f18dc59f2 100644 --- a/swift/megatron/model/mm_gpts/qwen3_vl.py +++ b/swift/megatron/model/mm_gpts/qwen3_vl.py @@ -17,7 +17,7 @@ from ..constant import MegatronModelType from ..gpt_bridge import GPTBridge, MultimodalGPTBridge from ..mm_gpt_model import MultimodalGPTModel -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model from .utils import HuggingFaceModule te_checkpoint = None @@ -453,19 
+453,6 @@ def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torc return hidden_states -class Qwen3VLGPTModel(MultimodalGPTModel): - - def _patch_transformer_block(self): - if hasattr(gpt_model, 'OriginTransformerBlock'): - return - gpt_model.OriginTransformerBlock = gpt_model.TransformerBlock - gpt_model.TransformerBlock = Qwen3VLTransformerBlock - - def __init__(self, *args, **kwargs): - self._patch_transformer_block() - super().__init__(*args, **kwargs) - - class Qwen3OmniBridge(GPTBridge): hf_layers_prefix = 'thinker.model.layers' hf_embed_key = 'thinker.model.embed_tokens.weight' @@ -474,16 +461,6 @@ class Qwen3OmniBridge(GPTBridge): hf_score_key = 'thinker.score.weight' -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen3_omni, [ - ModelType.qwen3_omni_moe, - ], - model_cls=Qwen3VLGPTModel, - bridge_cls=Qwen3OmniBridge, - visual_cls=Qwen3Omni_Vit)) - - class Qwen3VL_Vit(HuggingFaceModule): module_mapping = {'model.visual': 'visual'} _vision_tower = ['visual'] @@ -498,14 +475,44 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return Qwen3Omni_Vit._get_inputs_embeds(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) +class Qwen3VLLoader(MegatronModelLoader): + bridge_cls = MultimodalGPTBridge + visual_cls = Qwen3VL_Vit + + def _patch_transformer_block(self): + if hasattr(gpt_model, 'OriginTransformerBlock'): + return + gpt_model.OriginTransformerBlock = gpt_model.TransformerBlock + gpt_model.TransformerBlock = Qwen3VLTransformerBlock + + def __init__(self, args, hf_config): + super().__init__(args, hf_config) + self._patch_transformer_block() + + register_megatron_model( MegatronModelMeta( - MegatronModelType.qwen3_vl, [ + MegatronModelType.qwen3_vl, + [ ModelType.qwen3_vl, ModelType.qwen3_vl_moe, ModelType.qwen3_vl_emb, ModelType.qwen3_vl_reranker, ], - model_cls=Qwen3VLGPTModel, - bridge_cls=MultimodalGPTBridge, - visual_cls=Qwen3VL_Vit)) + Qwen3VLLoader, + )) + + +class Qwen3OmniLoader(Qwen3VLLoader): + bridge_cls = Qwen3OmniBridge + visual_cls = Qwen3Omni_Vit + + +register_megatron_model( + MegatronModelMeta( + MegatronModelType.qwen3_omni, + [ + ModelType.qwen3_omni_moe, + ], + Qwen3OmniLoader, + )) diff --git a/swift/megatron/model/mm_gpts/utils.py b/swift/megatron/model/mm_gpts/utils.py index 1e6151ad4c..ba118b39ce 100644 --- a/swift/megatron/model/mm_gpts/utils.py +++ b/swift/megatron/model/mm_gpts/utils.py @@ -47,7 +47,7 @@ class HuggingFaceModule(_HuggingFaceModule, ABC): def __init__(self, config, ignore_init_model_cls=None): super().__init__(config) - args = get_args() + args = self.config.args attn_impl = getattr(args, 'attn_impl', None) or 'flash_attn' kwargs = {'attn_impl': attn_impl} if args.attention_backend.name == 'flash' else {} ignore_init_model_cls = ignore_init_model_cls or [] diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index cbd716dcc0..413ad944db 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -13,6 +13,7 @@ from swift.model import MODEL_MAPPING from swift.utils import get_logger from .constant import MLLMMegatronModelType +from .gpt_bridge import GPTBridge from .model_config import create_mcore_model_config if TYPE_CHECKING: @@ -21,7 +22,6 @@ MEGATRON_MODEL_MAPPING = {} logger = get_logger() -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @dataclass @@ -32,12 +32,6 @@ class MegatronModelMeta: loader: Optional[Type['MegatronModelLoader']] = None is_multimodal: bool = False - # 
bridge_cls: Type[GPTBridge] = GPTBridge - # model_cls: Optional[Type[nn.Module]] = None - # get_transformer_layer_spec: Optional[Callable] = None - # visual_cls: Optional[Type[nn.Module]] = None - # get_mtp_block_spec: Optional[Callable] = None - def __post_init__(self): if self.megatron_model_type in MLLMMegatronModelType.__dict__: self.is_multimodal = True @@ -71,32 +65,33 @@ def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]: class MegatronModelLoader: + model_cls = None + visual_cls = None + bridge_cls = GPTBridge def __init__(self, args, hf_config): self.args = args self.hf_config = hf_config self.config = create_mcore_model_config(args, hf_config) + self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + if self.model_cls is None: + self.model_cls = MultimodalGPTModel if self.args.is_multimodal else GPTModel self._check_npu() def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): args = self.args if self.config.num_experts: - kwargs = {'qk_l2_norm': self.config.qk_l2_norm, 'vp_stage': vp_stage} if mcore_013 else {} + kwargs = {'qk_l2_norm': self.config.qk_l2_norm, 'vp_stage': vp_stage} if self.mcore_013 else {} # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( self.config, use_transformer_engine=True, normalization=self.config.normalization, **kwargs) else: transformer_layer_spec = self._get_transformer_layer_spec() - - if args.use_shared_expert_gate and args.num_experts and args.moe_shared_expert_intermediate_size: - for layer_spec in transformer_layer_spec.layer_specs: - if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): - layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} return transformer_layer_spec def _get_transformer_layer_spec(self): config = self.config - kwargs = {'qk_l2_norm': config.qk_l2_norm} if mcore_013 else {} + kwargs = {'qk_l2_norm': config.qk_l2_norm} if self.mcore_013 else {} transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( config.num_experts, self.args.moe_grouped_gemm, @@ -107,33 +102,42 @@ def _get_transformer_layer_spec(self): return transformer_layer_spec def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = None): - if self.args.mtp_num_layers is None: - return if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0: # Get the decoder layer spec explicitly if no decoder layer in the last stage, # Only happens with block spec (TransformerBlockSubmodules) when using MoE. 
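(Editorial aside on the loader above.) Newer spec arguments are gated on the installed megatron.core version via the mcore_013 flag, which the loader now computes per instance. A tiny runnable sketch of that check, using packaging.version as the patch does; the hard-coded version string is only for illustration, the real code reads megatron.core.__version__:

from packaging import version

installed = '0.13.0'  # assumed for illustration
mcore_013 = version.parse(installed) >= version.parse('0.13.0rc0')
print(mcore_013)  # True: the 0.13.0 final release sorts after the 0.13.0rc0 pre-release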
+ # TODO: remove transformer_layer_spec_for_mtp = self._get_transformer_layer_spec() else: transformer_layer_spec_for_mtp = transformer_layer_spec - kwargs = {'vp_stage': vp_stage} if mcore_013 else {} - + kwargs = {'vp_stage': vp_stage} if self.mcore_013 else {} return get_gpt_mtp_block_spec( self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, **kwargs) + def _set_shared_expert_gate(self, transformer_layer_spec): + if (self.config.use_shared_expert_gate and self.config.num_experts + and self.config.moe_shared_expert_intermediate_size): + for layer_spec in transformer_layer_spec.layer_specs: + if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): + layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} + def create_model_and_load( self, pre_process=True, post_process=True, vp_stage: Optional[int] = None, - ) -> Union[GPTModel, MultimodalGPTModel]: + ) -> Union['GPTModel', 'MultimodalGPTModel']: transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) - mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) - return self._create_model( + self._set_shared_expert_gate(transformer_layer_spec) + mtp_block_spec = None + if self.args.mtp_num_layers is not None: + mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) + model = self._create_model( transformer_layer_spec, mtp_block_spec, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + return model def _check_npu(self): MAX_NPU_EXPERTS_PER_EP = 128 @@ -154,11 +158,7 @@ def _create_model(self, pre_process=True, post_process=True, vp_stage: Optional[int] = None): - if self.args.is_multimodal: - model_cls = MultimodalGPTModel - else: - model_cls = GPTModel - return model_cls( + return self.model_cls( config=self.config, transformer_layer_spec=transformer_layer_spec, pre_process=pre_process, From 5014e4c2f0c8707966c9e3ccacc58ef23fcbaa35 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 4 Feb 2026 18:01:40 +0800 Subject: [PATCH 20/43] fix --- swift/megatron/model/mm_gpts/glm.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/swift/megatron/model/mm_gpts/glm.py b/swift/megatron/model/mm_gpts/glm.py index 65a95ea2bc..55aacf082a 100644 --- a/swift/megatron/model/mm_gpts/glm.py +++ b/swift/megatron/model/mm_gpts/glm.py @@ -4,7 +4,7 @@ from ..constant import MegatronModelType from ..gpt_bridge import MultimodalGPTBridge from ..gpts.glm4 import Glm4Bridge, Glm4Loader -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelLoader, MegatronModelMeta, register_megatron_model from .utils import HuggingFaceModule @@ -21,11 +21,18 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) -register_megatron_model( - MegatronModelMeta( - MegatronModelType.glm4v_moe, [ - ModelType.glm4v_moe, - ], bridge_cls=MultimodalGPTBridge, visual_cls=Glm4vVit)) +class Glm4vMoeLoader(MegatronModelLoader): + bridge_cls = MultimodalGPTBridge + visual_cls = Glm4vVit + + +register_megatron_model(MegatronModelMeta( + MegatronModelType.glm4v_moe, + [ + ModelType.glm4v_moe, + ], + Glm4vMoeLoader, +)) class Glm4vBridge(Glm4Bridge, MultimodalGPTBridge): @@ -42,4 +49,5 @@ class Glm4vLoader(Glm4Loader): [ ModelType.glm4v, ], + Glm4vLoader, )) From 46a23cc4a88d01cdc0054203e519c20beb11a7ce Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: 
Wed, 4 Feb 2026 18:41:06 +0800 Subject: [PATCH 21/43] update --- swift/megatron/arguments/megatron_base_args.py | 8 -------- swift/megatron/convert.py | 9 --------- 2 files changed, 17 deletions(-) diff --git a/swift/megatron/arguments/megatron_base_args.py b/swift/megatron/arguments/megatron_base_args.py index 98e3bb7886..44cebac027 100644 --- a/swift/megatron/arguments/megatron_base_args.py +++ b/swift/megatron/arguments/megatron_base_args.py @@ -27,19 +27,11 @@ def __post_init__(self): self.num_workers = 1 logger.info('Using streaming dataset, setting args.num_workers to 1.') - @staticmethod - def _check_megatron_kwargs(kwargs): - # Make sure that the keys in kwargs have default values of None in MegatronArguments. - default_mapping = {field.name: field.default for field in fields(MegatronArguments)} - for k in kwargs.keys(): - assert default_mapping[k] is None - def init_model_args(self, tokenizer, config): if self.task_type == 'seq_cls': self.problem_type = self.problem_type or getattr(config, 'problem_type', None) logger.info(f'args.problem_type: {self.problem_type}') kwargs = convert_hf_config(config) - self._check_megatron_kwargs(kwargs) if tokenizer is not None and self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer): kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128 self.initialize_embedding = True diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index ffe8099fcf..9c4cbda7b4 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -29,13 +29,6 @@ } -def _check_megatron_kwargs(kwargs): - # Make sure that the keys in kwargs have default values of None in MegatronArguments. - default_mapping = {field.name: field.default for field in fields(MegatronArguments)} - for k in kwargs.keys(): - assert default_mapping[k] is None - - def convert_hf2mcore(args: ExportArguments) -> None: hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) processor = template.processor @@ -48,7 +41,6 @@ def convert_hf2mcore(args: ExportArguments) -> None: assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' 
kwargs = convert_hf_config(processor.model_info.config) logger.info(f'megatron_config: {kwargs}') - _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() if args.model_info.is_moe_model: current_convert_kwargs['moe_grouped_gemm'] = True @@ -85,7 +77,6 @@ def convert_mcore2hf(args: ExportArguments) -> None: hf_config = processor.model_info.config kwargs = convert_hf_config(hf_config) logger.info(f'megatron_config: {kwargs}') - _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() if args.model_info.is_moe_model: current_convert_kwargs['moe_grouped_gemm'] = True From 2cf83b9971b67ad5b5a7f3ad2872795241ff8a3c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 10:41:09 +0800 Subject: [PATCH 22/43] update --- swift/megatron/model/gpt_bridge.py | 4 ++++ swift/megatron/model/model_config.py | 2 +- swift/megatron/model/register.py | 23 ++++++++++++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 144056ae0a..b24d8de91e 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -9,6 +9,7 @@ import torch.nn.functional as F import transformers from megatron.core import mpu +from megatron.core.utils import unwrap_model from packaging import version from peft.utils import ModulesToSaveWrapper from tqdm import tqdm @@ -1445,6 +1446,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False, adapter_name: str = 'default'): self._is_peft_format = is_peft_format self._adapter_name = adapter_name + mg_model = unwrap_model(mg_model) hf_model_dir = safe_snapshot_download(hf_model_dir, use_hf=self.args.use_hf, hub_token=self.args.hub_token) with torch.no_grad(), SafetensorLazyLoader(hf_model_dir, is_peft_format=is_peft_format) as loader: state_dict = loader.get_state_dict() @@ -1464,6 +1466,7 @@ def export_weights(self, self._peft_target_modules = set() self._peft_modules_to_save = set() hf_prefix = 'base_model.model.' if is_peft_format else '' + mg_models = [unwrap_model(mg_model) for mg_model in mg_models] with torch.no_grad(): yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc) @@ -1477,6 +1480,7 @@ def save_weights(self, torch.cuda.empty_cache() saver = StreamingSafetensorSaver( save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=is_peft_format) + mg_models = [unwrap_model(mg_model) for mg_model in mg_models] for k, v in self.export_weights( mg_models, target_device='cpu', only_last_rank=True, is_peft_format=is_peft_format, tqdm_desc='Saving: '): diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index c657fad4f1..f0197ce378 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -29,7 +29,7 @@ class MegatronModelConfig(TransformerConfig): window_attn_skip_freq: Optional[str] = None max_position_embeddings: Optional[int] = None - position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'none'] = 'rope', + position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'none'] = 'rope' rotary_base: int = 10000 rotary_percent: float = 1. 
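(Editorial aside on the gpt_bridge.py hunk above.) load_weights/export_weights/save_weights now call unwrap_model before touching parameters. A rough illustration of what unwrapping achieves, assuming the usual Megatron wrapper chain where DDP or Float16Module expose the inner module via .module; this is a simplified stand-in, not the implementation of megatron.core.utils.unwrap_model:

def unwrap(model):
    # Peel off wrapper layers until the bare GPTModel / MultimodalGPTModel is reached,
    # so the bridge sees the same parameter names it maps to HF checkpoint keys.
    while hasattr(model, 'module'):
        model = model.module
    return model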
rotary_interleaved: bool = False diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 413ad944db..bd441dc4b2 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, List, Optional, Type, Union import megatron.core +from megatron.core import mpu from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec) @@ -77,12 +78,12 @@ def __init__(self, args, hf_config): if self.model_cls is None: self.model_cls = MultimodalGPTModel if self.args.is_multimodal else GPTModel self._check_npu() + self.bridge = self.bridge_cls(args) def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): args = self.args if self.config.num_experts: kwargs = {'qk_l2_norm': self.config.qk_l2_norm, 'vp_stage': vp_stage} if self.mcore_013 else {} - # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( self.config, use_transformer_engine=True, normalization=self.config.normalization, **kwargs) else: @@ -137,6 +138,7 @@ def create_model_and_load( pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + self.bridge.load_weights(model, self.args.model_dir) return model def _check_npu(self): @@ -166,3 +168,22 @@ def _create_model(self, mtp_block_spec=mtp_block_spec, vp_stage=vp_stage, ) + + +def get_mcore_model( + args, + hf_config, +): + loader = args.megatron_model_meta.loader(args, hf_config) + if (mpu.get_pipeline_model_parallel_world_size() > 1 and args.virtual_pipeline_model_parallel_size is not None): + models = [] + for i in range(args.virtual_pipeline_model_parallel_size): + pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) + post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) + model = loader.create_model_and_load(pre_process, post_process, vp_stage=i) + models.append(model) + else: + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + models = [loader.create_model_and_load(pre_process=pre_process, post_process=post_process)] + return models From c2b385a9afb0e2d02cec719823831078df98def0 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 10:46:09 +0800 Subject: [PATCH 23/43] update --- swift/megatron/arguments/megatron_args.py | 2 ++ swift/megatron/convert.py | 17 ++++------------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index b609d7bc7a..a619ff5369 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -570,6 +570,8 @@ def __post_init__(self): self.hf_model_type = self.model_type = self.model_info.model_type self.model_dir = self.model_info.model_dir self.is_multimodal = self.model_meta.is_multimodal + self.megatron_model_meta = get_megatron_model_meta(self.model_type) + assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' 
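(Editorial aside on the get_mcore_model helper added in the register.py hunk above.) The pre_process/post_process flags mark which pipeline chunk owns the embedding and which owns the output head; with virtual pipeline parallelism the same mapping is evaluated per vp_stage. A small sketch of that mapping for a plain (non-virtual) pipeline, with an assumed world size of 4:

pipeline_world_size = 4  # assumed for illustration
for rank in range(pipeline_world_size):
    pre_process = rank == 0                          # embedding lives on the first stage
    post_process = rank == pipeline_world_size - 1   # output layer / loss head on the last stage
    print(f'stage {rank}: pre_process={pre_process}, post_process={post_process}')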
if self.apply_wd_to_qk_layernorm and self.hf_model_type != 'qwen3_next': raise ValueError('apply_wd_to_qk_layernorm is only supported for qwen3_next') if self.pipeline_model_parallel_size == 1 and (self.decoder_first_pipeline_num_layers is not None diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 9c4cbda7b4..1da3930f08 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -13,8 +13,8 @@ from swift.utils import get_logger, get_n_params_grads, is_master from .arguments import MegatronArguments from .model import get_megatron_model_meta -from .utils import (convert_hf_config, initialize_megatron, load_mcore_checkpoint, patch_torch_dist_shard, - save_mcore_checkpoint, test_convert_precision) +from .utils import (initialize_megatron, load_mcore_checkpoint, patch_torch_dist_shard, save_mcore_checkpoint, + test_convert_precision) logger = get_logger() @@ -37,17 +37,13 @@ def convert_hf2mcore(args: ExportArguments) -> None: args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB patch_torch_dist_shard(args.thread_count) - megatron_model_meta = get_megatron_model_meta(args.model_type) - assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' - kwargs = convert_hf_config(processor.model_info.config) - logger.info(f'megatron_config: {kwargs}') + hf_config = processor.model_info.config current_convert_kwargs = convert_kwargs.copy() if args.model_info.is_moe_model: current_convert_kwargs['moe_grouped_gemm'] = True megatron_args = MegatronArguments( model=args.model, model_type=args.model_type, - **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype) @@ -72,11 +68,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: _, template = prepare_model_template(args, load_model=False) processor = template.processor - megatron_model_meta = get_megatron_model_meta(args.model_type) - assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' 
hf_config = processor.model_info.config - kwargs = convert_hf_config(hf_config) - logger.info(f'megatron_config: {kwargs}') current_convert_kwargs = convert_kwargs.copy() if args.model_info.is_moe_model: current_convert_kwargs['moe_grouped_gemm'] = True @@ -85,11 +77,10 @@ def convert_mcore2hf(args: ExportArguments) -> None: extra_config['adapter_load'] = adapter_load if args.mcore_model is not None: extra_config['load'] = args.mcore_model - kwargs.update(extra_config) + current_convert_kwargs.update(extra_config) megatron_args = MegatronArguments( model=args.model, model_type=args.model_type, - **kwargs, **current_convert_kwargs, save=args.output_dir if args.to_mcore else None, torch_dtype=args.torch_dtype) From 5503a1b19008bcc3005d5290b3c11f5df6886350 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 14:13:43 +0800 Subject: [PATCH 24/43] update --- .../Megatron-SWIFT/Command-line-parameters.md | 4 +- .../Megatron-SWIFT/Command-line-parameters.md | 4 +- swift/megatron/arguments/megatron_args.py | 149 ++++++++---------- swift/megatron/convert.py | 9 +- swift/megatron/init.py | 6 +- swift/megatron/model/__init__.py | 5 +- swift/megatron/model/gpt_bridge.py | 107 +++++++------ swift/megatron/model/gpt_model.py | 24 ++- swift/megatron/model/mm_gpts/glm.py | 2 +- swift/megatron/model/mm_gpts/internvl.py | 10 +- swift/megatron/model/mm_gpts/llama4.py | 2 +- swift/megatron/model/mm_gpts/qwen.py | 4 +- swift/megatron/model/mm_gpts/qwen3_vl.py | 4 +- swift/megatron/model/mm_gpts/utils.py | 11 +- swift/megatron/model/model_config.py | 69 +++++--- swift/megatron/model/register.py | 31 ++-- swift/megatron/model/rope.py | 30 ++-- swift/megatron/pipelines/export/export.py | 2 +- swift/megatron/trainers/base.py | 2 +- swift/megatron/trainers/gkd_trainer.py | 6 +- swift/megatron/utils/config.py | 2 +- swift/megatron/utils/megatron_lm_utils.py | 2 +- 22 files changed, 247 insertions(+), 238 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 669cb960e4..9b3adc7238 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -199,9 +199,9 @@ **MoE参数**: -- num_experts: MoE的专家数,默认为None。自动从config.json读取。 +- num_moe_experts: MoE的专家数,默认为None。自动从config.json读取。 - moe_layer_freq: MoE 层与 Dense 层之间的分布频率。默认为None。从config.json中读取。 -- moe_ffn_hidden_size: 每个专家的前馈网络(ffn)的隐藏层大小。默认为None,自动从config.json读取。若未读取到且`num_experts`不为None,则设置为ffn_hidden_size。 +- moe_ffn_hidden_size: 每个专家的前馈网络(ffn)的隐藏层大小。默认为None,自动从config.json读取。若未读取到且`num_moe_experts`不为None,则设置为ffn_hidden_size。 - moe_shared_expert_intermediate_size: 共享专家的总FFN隐藏层大小。如果有多个共享专家,它应等于 `num_shared_experts * ffn_size_of_each_shared_expert`。 默认为None。自动从config.json读取。 - moe_router_topk: 每个token路由到的专家数量。默认为None。自动从config.json读取。 - moe_router_num_groups: 将专家分成的组数,用于组限制路由。参考DeepSeek-V2和DeepSeek-V3。默认为None。自动从config.json读取。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index e0899ad864..47813eb960 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -211,9 +211,9 @@ For guidance on selecting parallelization strategies, please refer to the [Train **MoE Parameters**: -- num_experts: The number of experts in MoE, default is None. Automatically read from config.json. +- num_moe_experts: The number of experts in MoE, default is None. 
Automatically read from config.json. - moe_layer_freq: Frequency distribution between MoE layers and Dense layers. Default is None. This parameter is read from config.json. -- moe_ffn_hidden_size: Hidden layer size of the feedforward network (ffn) for each expert. Default is None and will be automatically read from config.json. If not found and `num_experts` is not None, it will be set to ffn_hidden_size. +- moe_ffn_hidden_size: Hidden layer size of the feedforward network (ffn) for each expert. Default is None and will be automatically read from config.json. If not found and `num_moe_experts` is not None, it will be set to ffn_hidden_size. - moe_shared_expert_intermediate_size: The total FFN hidden layer size for shared experts. If there are multiple shared experts, it should equal `num_shared_experts * ffn_size_of_each_shared_expert`. Default is None. Automatically read from config.json. - moe_router_topk: The number of experts each token is routed to. Default is None. Automatically read from config.json. - moe_router_num_groups: Number of groups to divide experts into for group-limited routing. Refers to DeepSeek-V2 and DeepSeek-V3. Default is None. Automatically read from config.json. diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index a619ff5369..fe99cd0c8f 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -12,6 +12,7 @@ from transformers.utils.versions import require_version from swift.arguments import ModelArguments +from swift.megatron.model import get_megatron_model_meta from swift.megatron.utils import initialize_megatron from swift.model import get_model_info_meta from swift.utils import get_dist_setting, get_logger, json_parse_to_dict @@ -316,71 +317,7 @@ def __post_init__(self): @dataclass -class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): - check_model: bool = True - initialize_embedding: bool = False - torch_dtype: Optional[Union[torch.dtype, str]] = None - padding_free: bool = True - mlp_padding_free: bool = False - # mcore-bridge - model: Optional[str] = None - model_type: Optional[str] = None - load_safetensors: Optional[bool] = None - save_safetensors: bool = True - adapters: List[str] = field(default_factory=list) - ref_model: Optional[str] = None - ref_adapters: List[str] = field(default_factory=list) - use_hf: bool = False - # None: use env var `MODELSCOPE_API_TOKEN` - hub_token: Optional[str] = field( - default=None, metadata={'help': 'SDK token can be found in https://modelscope.cn/my/myaccesstoken'}) - merge_lora: Optional[bool] = None - max_shard_size: str = '5GB' - - # dataloader - train_dataloader_shuffle: bool = True - dataloader_pin_memory: bool = True - dataloader_persistent_workers: bool = True - dataloader_prefetch_factor: int = 2 - group_by_length: bool = False - - max_epochs: Optional[int] = None - enable_dft_loss: bool = False - enable_channel_loss: bool = False - save_strategy: Literal['steps', 'epoch'] = 'steps' - - report_to: Optional[Literal['wandb', 'swanlab']] = None - - # visual - vit_gradient_checkpointing: bool = True - vit_lr: Optional[float] = None - aligner_lr: Optional[float] = None - gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None - - @staticmethod - def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]: - res = {} - if ckpt_dir is None: - return res - args_path = os.path.join(ckpt_dir, 'args.json') - if os.path.exists(args_path): - with open(args_path, 'r', encoding='utf-8') as 
f: - old_args = json.load(f) - keys = list(f.name for f in fields(MegatronTunerMixin)) - # TODO: remove load/save - keys += ['load', 'padded_vocab_size', 'task_type', 'num_labels'] # TODO: padded_vocab_size - for key in keys: - old_value = old_args.get(key) - if old_value is not None: - res[key] = old_value - res.pop('adapter_load', None) - if res['tuner_type'] != 'lora': - res.pop('load', None) - return res - - -@dataclass -class MegatronArguments(ExtraMegatronArguments): +class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): # training micro_batch_size: int = 1 global_batch_size: int = 16 @@ -526,6 +463,71 @@ class MegatronArguments(ExtraMegatronArguments): num_workers: int = 4 data_sharding: bool = False + check_model: bool = True + initialize_embedding: bool = False + torch_dtype: Optional[Union[torch.dtype, str]] = None + padding_free: bool = True + mlp_padding_free: bool = False + + # mcore-bridge + model: Optional[str] = None + model_type: Optional[str] = None + load_safetensors: Optional[bool] = None + save_safetensors: bool = True + adapters: List[str] = field(default_factory=list) + ref_model: Optional[str] = None + ref_adapters: List[str] = field(default_factory=list) + use_hf: bool = False + # None: use env var `MODELSCOPE_API_TOKEN` + hub_token: Optional[str] = field( + default=None, metadata={'help': 'SDK token can be found in https://modelscope.cn/my/myaccesstoken'}) + merge_lora: Optional[bool] = None + max_shard_size: str = '5GB' + + # dataloader + train_dataloader_shuffle: bool = True + dataloader_pin_memory: bool = True + dataloader_persistent_workers: bool = True + dataloader_prefetch_factor: int = 2 + group_by_length: bool = False + + max_epochs: Optional[int] = None + enable_dft_loss: bool = False + enable_channel_loss: bool = False + task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = None + num_labels: Optional[int] = None + problem_type: Literal['regression', 'single_label_classification', 'multi_label_classification'] = None + save_strategy: Literal['steps', 'epoch'] = 'steps' + + report_to: Optional[Literal['wandb', 'swanlab']] = None + + # visual + vit_gradient_checkpointing: bool = True + vit_lr: Optional[float] = None + aligner_lr: Optional[float] = None + gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None + + @staticmethod + def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]: + res = {} + if ckpt_dir is None: + return res + args_path = os.path.join(ckpt_dir, 'args.json') + if os.path.exists(args_path): + with open(args_path, 'r', encoding='utf-8') as f: + old_args = json.load(f) + keys = list(f.name for f in fields(MegatronTunerMixin)) + # TODO: remove load/save + keys += ['load', 'padded_vocab_size', 'task_type', 'num_labels'] # TODO: padded_vocab_size + for key in keys: + old_value = old_args.get(key) + if old_value is not None: + res[key] = old_value + res.pop('adapter_load', None) + if res['tuner_type'] != 'lora': + res.pop('load', None) + return res + def _set_default(self): if self.mlp_padding_free and (self.sequence_parallel or self.context_parallel_size > 1): raise ValueError('mlp_padding_free is not compatible with sequence parallel or context parallel.') @@ -567,12 +569,12 @@ def __post_init__(self): self.model, model_type=self.model_type, use_hf=self.use_hf, hub_token=self.hub_token) # Megatron has a model_type parameter with the same name, so we need to avoid conflicts. 
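# The load_args_config helper moved above keeps only a whitelist of keys from a
# checkpoint's args.json. Below is a standalone re-implementation of just its
# filtering rules, for illustration only: filter_ckpt_args is a stand-in name,
# the real method reads args.json from disk and derives the whitelist from
# dataclasses.fields(MegatronTunerMixin), and the example values are invented.
def filter_ckpt_args(old_args: dict, tuner_keys: list) -> dict:
    keys = list(tuner_keys) + ['load', 'padded_vocab_size', 'task_type', 'num_labels']
    res = {k: old_args[k] for k in keys if old_args.get(k) is not None}
    res.pop('adapter_load', None)
    if res.get('tuner_type') != 'lora':
        res.pop('load', None)   # only LoRA checkpoints keep their mcore 'load' path
    return res

assert filter_ckpt_args(
    {'tuner_type': 'full', 'load': '/ckpt/mcore', 'task_type': 'causal_lm'},
    tuner_keys=['tuner_type'],
) == {'tuner_type': 'full', 'task_type': 'causal_lm'}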
- self.hf_model_type = self.model_type = self.model_info.model_type + self.model_type = self.model_info.model_type self.model_dir = self.model_info.model_dir self.is_multimodal = self.model_meta.is_multimodal self.megatron_model_meta = get_megatron_model_meta(self.model_type) - assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' - if self.apply_wd_to_qk_layernorm and self.hf_model_type != 'qwen3_next': + assert self.megatron_model_meta is not None, f'Model: {args.model} is not supported.' + if self.apply_wd_to_qk_layernorm and self.model_type != 'qwen3_next': raise ValueError('apply_wd_to_qk_layernorm is only supported for qwen3_next') if self.pipeline_model_parallel_size == 1 and (self.decoder_first_pipeline_num_layers is not None or self.decoder_last_pipeline_num_layers is not None): @@ -608,19 +610,8 @@ def __post_init__(self): self._load_adapter_config() self._init_mixed_precision() - self._init_apply_rope_fusion() initialize_megatron(self) - def _init_apply_rope_fusion(self): - if self.apply_rope_fusion is not None: - return - if self.multi_latent_attention or self.rotary_interleaved: - # Upgrading transformer_engine requires checking here. - self.apply_rope_fusion = False - else: - self.apply_rope_fusion = True - logger.info(f'Setting args.apply_rope_fusion: {self.apply_rope_fusion}.') - def _init_vpp_size(self): # TODO self.virtual_pipeline_model_parallel_size = None diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 1da3930f08..249f4ac75c 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -12,7 +12,7 @@ from swift.pipelines import prepare_model_template from swift.utils import get_logger, get_n_params_grads, is_master from .arguments import MegatronArguments -from .model import get_megatron_model_meta +from .model import get_mcore_model, get_megatron_model_meta from .utils import (initialize_megatron, load_mcore_checkpoint, patch_torch_dist_shard, save_mcore_checkpoint, test_convert_precision) @@ -48,11 +48,8 @@ def convert_hf2mcore(args: ExportArguments) -> None: save=args.output_dir, torch_dtype=args.torch_dtype) - mg_model = megatron_model_meta.model_provider(megatron_args) + mg_model = get_mcore_model(megatron_args, hf_config)[0] logger.info('Megatron model created successfully.') - bridge = megatron_model_meta.bridge_cls(megatron_args) - bridge.load_weights(mg_model, args.model_info.model_dir) - logger.info('Successfully transferred HF model weights to MG model.') _test_convert_precision = strtobool(os.getenv('SWIFT_TEST_CONVERT_PRECISION', '0')) if not _test_convert_precision: args.save_args() @@ -99,7 +96,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: if args.to_hf: bridge = megatron_model_meta.bridge_cls(megatron_args) logger.info('Converting weights and saving the model...') - bridge.save_weights([mg_model], args.output_dir, processor=processor, config=hf_config) + bridge.save_weights([mg_model], args.output_dir, processor=processor, hf_config=hf_config) if is_master(): args_path = os.path.join(megatron_args.adapter_load or megatron_args.load or args.model, 'args.json') if os.path.exists(args_path): diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 5004c3fef2..2a4ced5bda 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -532,8 +532,7 @@ def forward(self, *_args, **kwargs): if not mcore_013: return _origin_forward(self, *_args, **kwargs) hidden_states, context = self._forward_attention(*_args, **kwargs) - args = self.config.args - mlp_padding_free = 
args.mlp_padding_free and 'attention_mask' in kwargs + mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs mask = None if mlp_padding_free and hidden_states.shape[1] > 1: mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() @@ -697,8 +696,7 @@ def forward(self, position_ids, mrope_section: List[int], packed_seq: bool = Fal seq_expanded = seq[:, :, None, :].float() # shape (3, bs, seq_length, dim) freqs = (inv_freq_expanded @ seq_expanded).transpose(2, 3) - args = self.config.args - if args.mrope_interleaved: + if self.config.mrope_interleaved: freqs = apply_interleaved_mrope(freqs, mrope_section) emb = torch.cat((freqs, freqs), dim=-1) else: diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index b0761b4f19..b9c80b24ed 100644 --- a/swift/megatron/model/__init__.py +++ b/swift/megatron/model/__init__.py @@ -1,4 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from . import gpts, mm_gpts from .constant import MegatronModelType -from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model +from .gpt_bridge import GPTBridge +from .gpt_model import GPTModel +from .mm_gpt_model import MultimodalGPTModel +from .register import MegatronModelMeta, get_mcore_model, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index b24d8de91e..3af00395ff 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -46,11 +46,12 @@ def __init__(self, args, disable_tqmd: bool = False): self._peft_modules_to_save = set() self._is_peft_format = False self._adapter_name = 'default' + self.config = None self._init_meta_hf_model() self.hf_layers = deep_getattr(self.hf_model, self.hf_layers_prefix) self.module_mapping = {} self.mcore_014 = version.parse(megatron.core.__version__) >= version.parse('0.14.0rc0') - megatron_model_meta = get_megatron_model_meta(self.args.hf_model_type) + megatron_model_meta = get_megatron_model_meta(self.args.model_type) if self.args.is_multimodal and megatron_model_meta.visual_cls is not None: self.module_mapping = megatron_model_meta.visual_cls.module_mapping self.tp_size = self.args.tensor_model_parallel_size @@ -108,7 +109,7 @@ def _get_hf_mlp(self, layer_idx): def _init_meta_hf_model(self): with torch.device('meta'): self.hf_model, self.processor = get_model_processor( - self.args.model_dir, model_type=self.args.hf_model_type, return_dummy_model=True) + self.args.model_dir, model_type=self.args.model_type, return_dummy_model=True) def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: if mg_key is None: @@ -376,7 +377,7 @@ def _get_weight( tensor = torch.concat(tensor, dim=0) if mg_scale_inv is not None: mg_scale_inv = torch.concat(mg_scale_inv, dim=0) - num_local_experts = self.args.num_experts // self.ep_size if is_expert else 1 + num_local_experts = self.config.num_moe_experts // self.ep_size if is_expert else 1 tp_dim = self._get_tp_split_dim(mg_key) is_linear_fc1 = (mg_key is not None and mg_key.split('.', 1)[0] == 'linear_fc1' and tp_dim is not None) if tensor is not None and is_linear_fc1: @@ -396,8 +397,8 @@ def _get_weight( assert mg_scale_inv is None, f'mg_key: {mg_key}' tensor = tensor + offset is_embedding = mg_key in {'embedding.word_embeddings.weight', 'output_layer.weight'} - if is_embedding and self.args.padded_vocab_size < tensor.shape[0]: - tensor = tensor[:self.args.padded_vocab_size] + if is_embedding and 
self.config.padded_vocab_size < tensor.shape[0]: + tensor = tensor[:self.config.padded_vocab_size] if self._target_device is not None: tensor = tensor.to(device=self._target_device) if mg_scale_inv is not None: @@ -471,8 +472,8 @@ def _set_state_dict(self, hf_weight = hf_state_dict[hf_key].load() if module_key in { 'embedding.word_embeddings', 'output_layer' - } and hf_weight.shape[0] < self.args.padded_vocab_size and self.args.task_type != 'seq_cls': - hf_weight = F.pad(hf_weight, (0, 0, 0, self.args.padded_vocab_size - hf_weight.shape[0])) + } and hf_weight.shape[0] < self.config.padded_vocab_size and self.args.task_type != 'seq_cls': + hf_weight = F.pad(hf_weight, (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0])) hf_scale_inv = None if f'{hf_key}_scale_inv' in hf_state_dict: hf_scale_inv = hf_state_dict[f'{hf_key}_scale_inv'].load() @@ -518,9 +519,10 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int else: hf_state_dict = {} hf_attn = self.hf_layers[layer_idx].self_attn - args = self.args - num_query_groups = (args.num_query_groups if args.num_query_groups is not None else args.num_attention_heads) - hidden_size_block = args.hidden_size // self.fp8_block_size + config = self.config + num_query_groups = ( + config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads) + hidden_size_block = config.hidden_size // self.fp8_block_size if to_mcore: if isinstance(mg_attn.linear_qkv, LoraParallelLinear): lora_A = hf_state_dict['q_proj.lora_A.weight'].load() @@ -540,11 +542,11 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int 'linear_qkv.lora_B.weight') elif not self._is_peft_format: linear_qkv_weight = torch.cat([ - hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), - hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), - hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), + hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), + hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), ], - dim=1).reshape((-1, args.hidden_size)) + dim=1).reshape((-1, config.hidden_size)) qkv_scale_inv = None if 'q_proj.weight_scale_inv' in hf_state_dict: qkv_scale_inv = torch.cat([ @@ -589,13 +591,13 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int mg_attn_weight, scale_inv = self._get_weight( None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') if mg_attn_weight is not None: - mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) - hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone() - hf_state_dict['k_proj.weight'] = mg_attn_weight[:, - q_dim:-kv_dim, :].reshape(-1, - args.hidden_size).clone() + mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, config.hidden_size)) + hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, + config.hidden_size).clone() + hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape( + -1, config.hidden_size).clone() hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, - args.hidden_size).clone() + config.hidden_size).clone() if scale_inv is not None: scale_inv = 
scale_inv.reshape((num_query_groups, -1, hidden_size_block)) hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:, :q_block, :].reshape( @@ -606,11 +608,11 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int -1, hidden_size_block).clone() del mg_attn_weight self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) - if args.add_bias_linear: + if config.add_bias_linear: self._set_state_dict(mg_attn, 'linear_proj.bias', hf_state_dict, 'o_proj.bias', to_mcore) # Copy bias - if (args.add_bias_linear or args.add_qkv_bias) and not self._is_peft_format: + if (config.add_bias_linear or config.add_qkv_bias) and not self._is_peft_format: if to_mcore: linear_qkv_bias = torch.cat([ hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), @@ -627,9 +629,9 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone() hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() - if getattr(args, 'softmax_type', 'vanilla') == 'learnable': + if getattr(config, 'softmax_type', 'vanilla') == 'learnable': self._set_state_dict(mg_attn, 'core_attention.softmax_offset', hf_state_dict, 'sinks', to_mcore) - if args.qk_layernorm: + if config.qk_layernorm: self._set_qk_layernorm(mg_attn, hf_attn, hf_state_dict, to_mcore) if to_mcore: hf_state_dict = {} @@ -662,7 +664,7 @@ def _set_moe_state( hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - args = self.args + config = self.config hf_mlp = self._get_hf_mlp(layer_idx) if hasattr(hf_mlp, 'router'): hf_gate_key = 'router.weight' @@ -671,13 +673,13 @@ def _set_moe_state( else: hf_gate_key = 'gate.weight' self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) - if args.add_bias_linear: + if config.add_bias_linear: self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore) - if args.moe_router_enable_expert_bias: + if config.moe_router_enable_expert_bias: hf_bias_key = self.get_e_score_correction_bias_key(hf_mlp) self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, hf_bias_key, to_mcore) - if args.moe_shared_expert_intermediate_size: + if config.moe_shared_expert_intermediate_size: for key in ['shared_expert', 'shared_experts', 'shared_mlp']: if hasattr(hf_mlp, key): hf_shared_expert_prefix = f'{key}.' 
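# Worked example (hypothetical sizes) of the expert-parallel split that the
# _set_mlp_state hunk below relies on: num_local_experts = num_moe_experts // ep_size,
# and each EP rank loads or exports one contiguous block of experts.
num_moe_experts = 64            # e.g. config.num_moe_experts
ep_size = 8                     # expert_model_parallel_size
num_local_experts = num_moe_experts // ep_size       # -> 8 experts per rank

ep_rank = 3
local_experts = range(ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts)
assert list(local_experts) == list(range(24, 32))    # rank 3 owns experts 24..31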
@@ -710,7 +712,7 @@ def _set_moe_state( return hf_state_dict def _get_hf_grouped(self): - if self.args.hf_model_type in { + if self.args.model_type in { 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe', 'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe', 'qwen3_vl_moe' @@ -735,7 +737,7 @@ def _set_mlp_state(self, if is_expert: hf_grouped = not hasattr(hf_mlp.experts, '__len__') hf_mlp = hf_mlp.experts if hf_grouped else hf_mlp.experts[0] - num_local_experts = args.num_experts // self.ep_size + num_local_experts = args.num_moe_experts // self.ep_size is_gate_up = hasattr(hf_mlp, 'gate_up_proj') # transformers 5.0 compatibility if self.is_transformers_5: @@ -982,7 +984,7 @@ def _set_mlp_state(self, if 'gate_up_proj' in hf_state_dict: gate_up_proj_weight = torch.concat( [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0) - is_last_ckpt = gate_up_proj_weight.shape[0] == args.num_experts + is_last_ckpt = gate_up_proj_weight.shape[0] == args.num_moe_experts if args.llm_model_type == 'gpt_oss' and is_last_ckpt: gate_proj_weight, up_proj_weight = gate_up_proj_weight.chunk(2, dim=2) new_gate_up_proj_weight = torch.empty_like(gate_up_proj_weight) @@ -1233,7 +1235,7 @@ def _set_mla_attn_state( def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): mg_attn = None if mg_layer is None else mg_layer.self_attention - if self.args.multi_latent_attention: + if self.config.multi_latent_attention: hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore) else: @@ -1345,8 +1347,8 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) hf_state_dict = {} layer_idx = 0 - prog_bar = tqdm(range(self.args.num_layers), dynamic_ncols=True, desc=tqdm_desc, disable=self.disable_tqmd) - while layer_idx < self.args.num_layers: + prog_bar = tqdm(range(self.config.num_layers), dynamic_ncols=True, desc=tqdm_desc, disable=self.disable_tqmd) + while layer_idx < self.config.num_layers: lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model if len(lm_model.decoder.layers) > 0: start_idx = lm_model.decoder.layers[0].layer_number - 1 @@ -1377,13 +1379,13 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} - if (not to_mcore or is_pp_last_stage) and self.args.mtp_num_layers: + if (not to_mcore or is_pp_last_stage) and self.config.mtp_num_layers: lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model if to_mcore and self.pp_rank > 0: self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) layer_idx = 0 - while layer_idx < self.args.mtp_num_layers: + while layer_idx < self.config.mtp_num_layers: res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore) layer_idx += 1 if to_mcore: @@ -1407,7 +1409,7 @@ def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None if self.hf_mtp_prefix == 
self.hf_layers_prefix: - hf_layer_idx = layer_idx + self.args.num_layers + hf_layer_idx = layer_idx + self.config.num_layers else: hf_layer_idx = layer_idx hf_prefix = f'{hf_prefix}{hf_layer_idx}.' @@ -1428,10 +1430,10 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: hf_state_dict = {} self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict) transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer - if not to_mcore and not self.args.hf_model_type.startswith('qwen3_next'): + if not to_mcore and not self.args.model_type.startswith('qwen3_next'): self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', to_mcore) - if self.args.untie_embeddings_and_output_weights: + if self.config.untie_embeddings_and_output_weights: self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore)) @@ -1447,6 +1449,7 @@ def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False self._is_peft_format = is_peft_format self._adapter_name = adapter_name mg_model = unwrap_model(mg_model) + self.config = mg_model.config hf_model_dir = safe_snapshot_download(hf_model_dir, use_hf=self.args.use_hf, hub_token=self.args.hub_token) with torch.no_grad(), SafetensorLazyLoader(hf_model_dir, is_peft_format=is_peft_format) as loader: state_dict = loader.get_state_dict() @@ -1467,6 +1470,7 @@ def export_weights(self, self._peft_modules_to_save = set() hf_prefix = 'base_model.model.' if is_peft_format else '' mg_models = [unwrap_model(mg_model) for mg_model in mg_models] + self.config = mg_models[0].config with torch.no_grad(): yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc) @@ -1475,7 +1479,7 @@ def save_weights(self, output_dir: str, is_peft_format: bool = False, processor=None, - config=None) -> None: + hf_config=None) -> None: """Save the mg_model checkpoint in HF format""" torch.cuda.empty_cache() saver = StreamingSafetensorSaver( @@ -1488,9 +1492,9 @@ def save_weights(self, saver.finalize() args = self.args processor = processor if processor is not None else self.processor - if config is None: - config = self.hf_model.config - config = copy(config) + if hf_config is None: + hf_config = self.hf_model.config + hf_config = copy(hf_config) if is_last_rank(): if is_peft_format: peft_config = copy(mg_models[0].peft_config[self._adapter_name]) @@ -1512,19 +1516,20 @@ def save_weights(self, peft_config.save_pretrained(output_dir) else: if args.mtp_num_layers: - config.num_nextn_predict_layers = args.mtp_num_layers - config.vocab_size = args.padded_vocab_size + hf_config.num_nextn_predict_layers = args.mtp_num_layers + hf_config.vocab_size = args.padded_vocab_size if args.fp8 is not None and args.fp8_recipe == 'blockwise' and args.fp8_param_gather: - if getattr(config, 'quantization_config', None) is None: + if getattr(hf_config, 'quantization_config', None) is None: from transformers.utils.quantization_config import FineGrainedFP8Config modules_to_not_convert = get_modules_to_not_convert(self.hf_model) - config.quantization_config = FineGrainedFP8Config(modules_to_not_convert=modules_to_not_convert) - elif hasattr(config, 'quantization_config'): - del config.quantization_config - config.save_pretrained(output_dir) + hf_config.quantization_config = FineGrainedFP8Config( + modules_to_not_convert=modules_to_not_convert) + elif 
hasattr(hf_config, 'quantization_config'): + del hf_config.quantization_config + hf_config.save_pretrained(output_dir) if getattr(self.hf_model, '_auto_class') is not None: try: - custom_object_save(self.hf_model, output_dir, config=config) + custom_object_save(self.hf_model, output_dir, config=hf_config) except FileNotFoundError as e: logger.error(f'custom_object_save Error: {e}') save_checkpoint( diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index ab48fa958e..7ea3b48975 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -106,13 +106,12 @@ def __init__( if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'): del self.decoder.layers[i].self_attention.rotary_pos_emb self.attention_scaling = 1. - self.args = args = config.args - new_inv_freq, self.attention_scaling = get_rope_inv_freq(args) + new_inv_freq, self.attention_scaling = get_rope_inv_freq(config) self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) - if args.task_type == 'seq_cls' and self.post_process: + if self.config.task_type == 'seq_cls' and self.post_process: self.output_layer = OutputLayerLinear( config.hidden_size, - args.num_labels, + self.config.num_labels, config=config, init_method=config.init_method, bias=False, @@ -121,7 +120,7 @@ def __init__( skip_weight_param_allocation=False, ) self.output_layer.weight.average_gradients_across_tp_domain = True - elif args.task_type == 'embedding' and self.post_process: + elif self.config.task_type == 'embedding' and self.post_process: self.output_layer = None if (self.attention_scaling != 1 or config.position_embedding_type == 'mrope') and config.apply_rope_fusion: @@ -202,7 +201,7 @@ def _preprocess( attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len) if attention_scaling is not None and attention_scaling != self.attention_scaling: raise ValueError('Currently does not support changing attention_scaling during training. 
' - f'args.attention_scaling: {self.attention_scaling}, ' + f'self.attention_scaling: {self.attention_scaling}, ' f'current_attention_scaling: {attention_scaling}.') packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.position_embedding_type == 'mrope': @@ -338,8 +337,7 @@ def _postprocess( """ if not self.post_process: return hidden_states - args = self.args - labels = labels if args.task_type == 'causal_lm' else None + labels = labels if self.config.task_type == 'causal_lm' else None in_inference_mode = inference_context is not None and not self.training if in_inference_mode: assert runtime_gather_output, 'Inference must always gather TP logits' @@ -391,7 +389,7 @@ def _postprocess( else: loss_mask[:, cu_seqlens[:-1]] = 0 loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1) - if args.context_parallel_size > 1: + if self.config.context_parallel_size > 1: loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1) else: loss_mask_ = loss_mask.clone() @@ -432,16 +430,16 @@ def _postprocess( # (so that the output layer, which expects S×B×H, receives only the final token) hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) - if args.task_type in {'seq_cls', 'embedding' - } and args.sequence_parallel and args.tensor_model_parallel_size > 1: + if self.config.task_type in {'seq_cls', 'embedding' + } and self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1: hidden_states = gather_from_sequence_parallel_region(hidden_states) - if args.task_type == 'embedding': + if self.config.task_type == 'embedding': logits = F.normalize(hidden_states, p=2, dim=-1) else: logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) - if args.task_type == 'generative_reranker': + if self.config.task_type == 'generative_reranker': logits = gather_from_tensor_model_parallel_region(logits) positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes') negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no') diff --git a/swift/megatron/model/mm_gpts/glm.py b/swift/megatron/model/mm_gpts/glm.py index 55aacf082a..05c9a394c4 100644 --- a/swift/megatron/model/mm_gpts/glm.py +++ b/swift/megatron/model/mm_gpts/glm.py @@ -18,7 +18,7 @@ def __init__(self, config): super().__init__(config, Glm4vMoeTextModel) def get_inputs_embeds(self, inputs_embeds, **kwargs): - return Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) + return Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.hf_config) class Glm4vMoeLoader(MegatronModelLoader): diff --git a/swift/megatron/model/mm_gpts/internvl.py b/swift/megatron/model/mm_gpts/internvl.py index 25442e190c..469538b5dd 100644 --- a/swift/megatron/model/mm_gpts/internvl.py +++ b/swift/megatron/model/mm_gpts/internvl.py @@ -115,10 +115,10 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): pixel_values = pixel_values.to(device=device) image_features = model.model.get_image_features( pixel_values, - vision_feature_layer=self.model_config.vision_feature_layer, - vision_feature_select_strategy=self.model_config.vision_feature_select_strategy, + vision_feature_layer=self.hf_config.vision_feature_layer, + vision_feature_select_strategy=self.hf_config.vision_feature_select_strategy, ) - special_image_mask = input_ids == self.model_config.image_token_id + special_image_mask = input_ids == 
self.hf_config.image_token_id special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) @@ -126,8 +126,8 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=self.vision_tower.dtype) image_features = model.model.get_image_features( dummy_pixel_values, - vision_feature_layer=self.model_config.vision_feature_layer, - vision_feature_select_strategy=self.model_config.vision_feature_select_strategy, + vision_feature_layer=self.hf_config.vision_feature_layer, + vision_feature_select_strategy=self.hf_config.vision_feature_select_strategy, ) inputs_embeds = inputs_embeds + image_features.mean() * 0. return inputs_embeds diff --git a/swift/megatron/model/mm_gpts/llama4.py b/swift/megatron/model/mm_gpts/llama4.py index 8ff23a65a7..af66272bed 100644 --- a/swift/megatron/model/mm_gpts/llama4.py +++ b/swift/megatron/model/mm_gpts/llama4.py @@ -31,7 +31,7 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): pixel_values = kwargs.get('pixel_values') input_ids = kwargs.get('input_ids') model = self._hf_model[0] - vision_feature_select_strategy = self.model_config.vision_config.vision_feature_select_strategy + vision_feature_select_strategy = self.hf_config.vision_config.vision_feature_select_strategy origin_pixel_values = pixel_values if pixel_values is None: pixel_values = torch.zeros((1, 3, 336, 336), dtype=self.vision_model.dtype, device=inputs_embeds.device) diff --git a/swift/megatron/model/mm_gpts/qwen.py b/swift/megatron/model/mm_gpts/qwen.py index dd519e8c9e..de0bc5b6fc 100644 --- a/swift/megatron/model/mm_gpts/qwen.py +++ b/swift/megatron/model/mm_gpts/qwen.py @@ -33,7 +33,7 @@ def __init__(self, config): super().__init__(config, ignore_init_model_cls) def get_inputs_embeds(self, inputs_embeds, **kwargs): - return Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) + return Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.hf_config) class Qwen2_5VLBridge(MultimodalGPTBridge): @@ -101,7 +101,7 @@ def prepare_model(self, hf_model): del self.thinker.lm_head def get_inputs_embeds(self, inputs_embeds, **kwargs): - thinker_config = self.model_config.thinker_config + thinker_config = self.hf_config.thinker_config inputs_embeds = Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.thinker.visual, self.processor, thinker_config) input_ids = kwargs['input_ids'] diff --git a/swift/megatron/model/mm_gpts/qwen3_vl.py b/swift/megatron/model/mm_gpts/qwen3_vl.py index 1f18dc59f2..c9cdd60ca7 100644 --- a/swift/megatron/model/mm_gpts/qwen3_vl.py +++ b/swift/megatron/model/mm_gpts/qwen3_vl.py @@ -158,7 +158,7 @@ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config): def get_inputs_embeds(self, inputs_embeds, **kwargs): input_ids = kwargs['input_ids'] visual = self.thinker.visual - config = self.model_config.thinker_config + config = self.hf_config.thinker_config res = self._get_inputs_embeds(inputs_embeds, kwargs, visual, self.processor, config) inputs_embeds = res['inputs_embeds'] input_features = kwargs.get('input_features') @@ -472,7 +472,7 @@ def __init__(self, config): super().__init__(config, [Qwen3VLTextModel, Qwen3VLMoeTextModel]) def get_inputs_embeds(self, inputs_embeds, **kwargs): - return 
Qwen3Omni_Vit._get_inputs_embeds(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) + return Qwen3Omni_Vit._get_inputs_embeds(inputs_embeds, kwargs, self.visual, self.processor, self.hf_config) class Qwen3VLLoader(MegatronModelLoader): diff --git a/swift/megatron/model/mm_gpts/utils.py b/swift/megatron/model/mm_gpts/utils.py index ba118b39ce..d4bc88cce3 100644 --- a/swift/megatron/model/mm_gpts/utils.py +++ b/swift/megatron/model/mm_gpts/utils.py @@ -47,19 +47,18 @@ class HuggingFaceModule(_HuggingFaceModule, ABC): def __init__(self, config, ignore_init_model_cls=None): super().__init__(config) - args = self.config.args - attn_impl = getattr(args, 'attn_impl', None) or 'flash_attn' - kwargs = {'attn_impl': attn_impl} if args.attention_backend.name == 'flash' else {} + attn_impl = getattr(config, 'attn_impl', None) or 'flash_attn' + kwargs = {'attn_impl': attn_impl} if config.attention_backend.name == 'flash' else {} ignore_init_model_cls = ignore_init_model_cls or [] if not isinstance(ignore_init_model_cls, list): ignore_init_model_cls = [ignore_init_model_cls] context_list = [patch_device_map_meta(model_cls) for model_cls in ignore_init_model_cls] context_list.append(patch_hf_initialize_weight()) - kwargs['model_type'] = args.hf_model_type + kwargs['model_type'] = config.hf_model_type with ContextManagers(context_list), disable_safe_ddp_context_use_barrier(): model, self.processor = get_model_processor( - args.model_dir, torch_dtype=args.torch_dtype, return_dummy_model=True, **kwargs) - self.model_config = model.config + config.model_dir, torch_dtype=config.torch_dtype, return_dummy_model=True, **kwargs) + self.hf_config = model.config for hf_prefix, mg_prefix in self.module_mapping.items(): setattr(self, mg_prefix, deep_getattr(model, hf_prefix)) self._hf_model = [model] diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index f0197ce378..08a98fd45a 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -1,9 +1,12 @@ -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing import List, Literal, Optional, Union +import torch import torch.nn.functional as F +from megatron.core import mpu from megatron.core.fusions.fused_bias_geglu import quick_gelu from megatron.core.transformer import TransformerConfig +from transformers.utils import is_torch_npu_available from swift.megatron.utils import convert_hf_config from swift.utils import get_logger, json_parse_to_dict @@ -57,7 +60,7 @@ class MegatronModelConfig(TransformerConfig): moe_apply_probs_on_input: Optional[bool] = None # moe - num_experts: Optional[int] = None + num_moe_experts: Optional[int] = None moe_layer_freq: str = 1 moe_ffn_hidden_size: Optional[int] = None moe_shared_expert_intermediate_size: Optional[int] = None @@ -91,49 +94,65 @@ class MegatronModelConfig(TransformerConfig): linear_conv_kernel_dim: Optional[int] = None layer_types: Optional[List[str]] = None - # apply_layernorm_1p: bool = False # TODO + layernorm_zero_centered_gamma: bool = False + + # Override + persist_layer_norm: bool = True + deallocate_pipeline_outputs: bool = True + batch_p2p_comm: bool = True + cp_comm_type: str = 'p2p' def __post_init__(self): + self.pipeline_dtype = self.torch_dtype if self.moe_router_dtype.lower() == 'none': self.moe_router_dtype = None - if self.num_experts is not None: + if self.num_moe_experts is not None: if self.moe_ffn_hidden_size is None: self.moe_ffn_hidden_size = self.ffn_hidden_size if 
self.rope_scaling is not None: self.rope_scaling = json_parse_to_dict(self.rope_scaling) if 'type' in self.rope_scaling and 'rope_type' not in self.rope_scaling: self.rope_scaling['rope_type'] = self.rope_scaling['type'] + + if self.swiglu: + self.activation_func = F.silu + self.gated_linear_unit = True + if self.quick_geglu: + assert not self.swiglu + self.gated_linear_unit = True + self.activation_func = quick_gelu super().__post_init__() + self._check_npu() self.variable_seq_lengths = True + def _check_npu(self): + MAX_NPU_EXPERTS_PER_EP = 128 + num_experts = self.num_moe_experts + expert_model_parallel_size = mpu.get_expert_model_parallel_world_size() + if is_torch_npu_available() and num_experts > MAX_NPU_EXPERTS_PER_EP: + required_ep = (num_experts + MAX_NPU_EXPERTS_PER_EP - 1) // MAX_NPU_EXPERTS_PER_EP + if expert_model_parallel_size < required_ep: + logger.warning(f'{">" * 20} WARNING {"<" * 20}\n' + f'MindSpeed on NPU supports up to {MAX_NPU_EXPERTS_PER_EP} experts per EP group. ' + f'num_experts={num_experts}, ' + f'expert_model_parallel_size={expert_model_parallel_size}. ' + f'Please set expert_model_parallel_size (EP) to {required_ep} ' + f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.') + def create_mcore_model_config(args, hf_config): # Translate args to core transformer configuration kw_args = convert_hf_config(hf_config) - kw_args['persist_layer_norm'] = True - # TODO: apply_layernorm_1p - kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p - kw_args['deallocate_pipeline_outputs'] = True - kw_args['pipeline_dtype'] = args.torch_dtype - kw_args['batch_p2p_comm'] = True - kw_args['num_moe_experts'] = args.num_experts - kw_args['rotary_interleaved'] = args.rotary_interleaved + for f in fields(MegatronModelConfig): + if hasattr(args, f.name): + kw_args[f.name] = getattr(args, f.name) kw_args['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers kw_args['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers kw_args['fp8_param'] = args.fp8_param_gather - if args.swiglu: - kw_args['activation_func'] = F.silu - kw_args['gated_linear_unit'] = True - kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion - else: - kw_args['bias_activation_fusion'] = args.bias_gelu_fusion - if args.quick_geglu: - assert not args.swiglu - kw_args['gated_linear_unit'] = True - kw_args['activation_func'] = quick_gelu - kw_args['cp_comm_type'] = 'p2p' kw_args['inference_sampling_seed'] = args.seed - kw_args['variable_seq_lengths'] = True + swiglu = kw_args.get('swiglu', True) + kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion if swiglu else args.bias_gelu_fusion config = MegatronModelConfig(**kw_args) - config.args = config + config.hf_config = hf_config + config.args = args return config diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index bd441dc4b2..4c71f307d3 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -71,18 +71,18 @@ class MegatronModelLoader: bridge_cls = GPTBridge def __init__(self, args, hf_config): + from swift.megatron.model import GPTModel, MultimodalGPTModel self.args = args self.hf_config = hf_config self.config = create_mcore_model_config(args, hf_config) self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') if self.model_cls is None: self.model_cls = MultimodalGPTModel if self.args.is_multimodal else GPTModel - self._check_npu() + self._init_config() self.bridge = self.bridge_cls(args) def 
get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - args = self.args - if self.config.num_experts: + if self.config.num_moe_experts: kwargs = {'qk_l2_norm': self.config.qk_l2_norm, 'vp_stage': vp_stage} if self.mcore_013 else {} transformer_layer_spec = get_gpt_decoder_block_spec( self.config, use_transformer_engine=True, normalization=self.config.normalization, **kwargs) @@ -94,7 +94,7 @@ def _get_transformer_layer_spec(self): config = self.config kwargs = {'qk_l2_norm': config.qk_l2_norm} if self.mcore_013 else {} transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - config.num_experts, + config.num_moe_experts, self.args.moe_grouped_gemm, config.qk_layernorm, config.multi_latent_attention, @@ -141,18 +141,17 @@ def create_model_and_load( self.bridge.load_weights(model, self.args.model_dir) return model - def _check_npu(self): - MAX_NPU_EXPERTS_PER_EP = 128 - num_experts = self.config.num_experts - if is_torch_npu_available() and num_experts > MAX_NPU_EXPERTS_PER_EP: - required_ep = (num_experts + MAX_NPU_EXPERTS_PER_EP - 1) // MAX_NPU_EXPERTS_PER_EP - if self.args.expert_model_parallel_size < required_ep: - logger.warning(f'{">" * 20} WARNING {"<" * 20}\n' - f'MindSpeed on NPU supports up to {MAX_NPU_EXPERTS_PER_EP} experts per EP group. ' - f'num_experts={num_experts}, ' - f'expert_model_parallel_size={self.args.expert_model_parallel_size}. ' - f'Please set expert_model_parallel_size (EP) to {required_ep} ' - f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.') + def _init_config(self): + config = self.config + # apply_rope_fusion + if config.apply_rope_fusion is not None: + return + if config.multi_latent_attention or config.rotary_interleaved: + # Upgrading transformer_engine requires checking here. + config.apply_rope_fusion = False + else: + config.apply_rope_fusion = True + logger.info(f'Setting config.apply_rope_fusion: {config.apply_rope_fusion}.') def _create_model(self, transformer_layer_spec, diff --git a/swift/megatron/model/rope.py b/swift/megatron/model/rope.py index 9943c33b16..55cc1c43e0 100644 --- a/swift/megatron/model/rope.py +++ b/swift/megatron/model/rope.py @@ -25,24 +25,24 @@ def __init__(self, **kwargs): setattr(self, k, v) -def _get_dummy_config(args): +def _get_dummy_config(config): dummy_config = DummyConfig( - rope_scaling=args.rope_scaling, - rope_theta=args.rotary_base, - max_position_embeddings=args.max_position_embeddings, - head_dim=args.qk_pos_emb_head_dim if args.multi_latent_attention else args.kv_channels, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, + rope_scaling=config.rope_scaling, + rope_theta=config.rotary_base, + max_position_embeddings=config.max_position_embeddings, + head_dim=config.qk_pos_emb_head_dim if config.multi_latent_attention else config.kv_channels, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, ) - original_max_position_embeddings = args.original_max_position_embeddings or ( - args.rope_scaling or {}).get('original_max_position_embeddings') + original_max_position_embeddings = config.original_max_position_embeddings or ( + config.rope_scaling or {}).get('original_max_position_embeddings') if original_max_position_embeddings is not None: dummy_config.original_max_position_embeddings = original_max_position_embeddings - if args.partial_rotary_factor is not None: - dummy_config.partial_rotary_factor = args.partial_rotary_factor + if config.partial_rotary_factor is not None: + dummy_config.partial_rotary_factor = 
config.partial_rotary_factor if transformers_5: rope_parameters = getattr(dummy_config, 'rope_parameters', None) or {} - rope_parameters.update(args.rope_scaling or {}) + rope_parameters.update(config.rope_scaling or {}) dummy_config.rope_parameters = rope_parameters return dummy_config @@ -107,11 +107,11 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]): return rope_type -def get_rope_inv_freq(args, seq_len=None): +def get_rope_inv_freq(config, seq_len=None): from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS) - dummy_config = _get_dummy_config(args) - rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(args.rope_scaling)] + dummy_config = _get_dummy_config(config) + rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)] inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len) if attention_scaling is None: attention_scaling = 1. diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index 46febb9a9d..ae943c3e00 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -61,7 +61,7 @@ def convert_mcore2hf(self) -> None: args.save, is_peft_format=save_peft_format, processor=self.processor, - config=hf_config) + hf_config=hf_config) args_path = os.path.join(args.adapter_load or args.load or args.model, 'args.json') if os.path.exists(args_path): if is_last_rank(): diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 56ce46819a..5cc698d411 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1180,7 +1180,7 @@ def save_checkpoint(self, iteration, model, *_args, **kwargs): output_dir, is_peft_format=save_peft_format, processor=self.template.processor, - config=self.template.config) + hf_config=self.template.config) if args.tuner_type == 'lora' and args.merge_lora: self.unmerge_lora_adapters() diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index e4ce8ba62e..b856e2f16b 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -130,7 +130,7 @@ def _load_teacher_model(self, teacher_model_path: str, model_type: str): # When student is MoE and teacher is Dense (or vice versa), these keys need to be # properly reset to ensure correct model architecture creation. moe_related_keys = { - 'num_experts', + 'num_moe_experts', 'moe_ffn_hidden_size', 'moe_shared_expert_intermediate_size', 'moe_router_topk', @@ -172,7 +172,7 @@ def _load_teacher_model(self, teacher_model_path: str, model_type: str): # Reset MoE-related keys that are not in teacher config to None. # This ensures Dense teacher doesn't inherit MoE settings from MoE student, # and MoE teacher gets its own settings without interference from Dense student. - teacher_is_moe = teacher_megatron_config.get('num_experts') is not None + teacher_is_moe = teacher_megatron_config.get('num_moe_experts') is not None for key in moe_related_keys: if key not in teacher_megatron_config and hasattr(megatron_args, key): setattr(megatron_args, key, None) @@ -693,7 +693,7 @@ def patched_validate_args(self, args, *_args, **kwargs): This is called before Megatron's validate_args, allowing us to reset EP to 1 when student is Dense but EP > 1 was configured (for MoE teacher). 
""" - student_is_moe = getattr(args, 'num_experts', None) is not None + student_is_moe = getattr(args, 'num_moe_experts', None) is not None if not student_is_moe: # Reset EP to 1 in Megatron args for Dense student self._original_ep_size = args.expert_model_parallel_size diff --git a/swift/megatron/utils/config.py b/swift/megatron/utils/config.py index fa0ce5f3d6..ae12b22083 100644 --- a/swift/megatron/utils/config.py +++ b/swift/megatron/utils/config.py @@ -28,7 +28,7 @@ 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'], 'moe_router_num_groups': ['n_group'], 'moe_router_group_topk': ['topk_group'], - 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'], + 'num_moe_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'], 'moe_router_pre_softmax': ['norm_topk_prob'], # deepseek 'q_lora_rank': ['q_lora_rank'], diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 42797b8945..3bf090bbb3 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -101,7 +101,7 @@ def initialize_megatron(args): _set_random_seed(args.seed) # Setup MoE aux loss scale value. - if args.num_experts is not None: + if args.model_info.is_moe_model: from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler MoEAuxLossAutoScaler.set_loss_scale(torch.ones(1, device=torch.cuda.current_device())) From a7a1118a1a2f9f3afffa3f63f727a9eaed92141d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 14:31:08 +0800 Subject: [PATCH 25/43] update --- swift/megatron/model/gpt_bridge.py | 43 +++++++++++++------------- swift/megatron/model/gpt_model.py | 15 ++++----- swift/megatron/model/mm_gpt_model.py | 2 +- swift/megatron/model/mm_gpts/utils.py | 7 +++-- swift/megatron/model/model_config.py | 20 ++++++------ swift/megatron/model/register.py | 25 ++++++++------- swift/megatron/trainers/gkd_trainer.py | 2 +- swift/megatron/utils/convert_utils.py | 2 +- swift/megatron/utils/utils.py | 4 +-- 9 files changed, 63 insertions(+), 57 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 3af00395ff..dc5f9a0e1b 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -733,11 +733,11 @@ def _set_mlp_state(self, is_expert = ep_rank is not None num_local_experts = 1 hf_grouped = False - args = self.args + config = self.config if is_expert: hf_grouped = not hasattr(hf_mlp.experts, '__len__') hf_mlp = hf_mlp.experts if hf_grouped else hf_mlp.experts[0] - num_local_experts = args.num_moe_experts // self.ep_size + num_local_experts = config.num_moe_experts // self.ep_size is_gate_up = hasattr(hf_mlp, 'gate_up_proj') # transformers 5.0 compatibility if self.is_transformers_5: @@ -807,7 +807,7 @@ def _set_mlp_state(self, fc1_weight = [getattr(mg_mlp.linear_fc1, f'weight{i}') for i in range(num_local_experts)] if is_expert else mg_mlp.linear_fc1.weight fc1_bias = None - if args.add_bias_linear: + if config.add_bias_linear: assert is_expert and not has_scale_inv, 'not support' # TODO fc1_bias = [getattr(mg_mlp.linear_fc1, f'bias{i}') for i in range(num_local_experts)] gate_up_scale_inv = None @@ -831,7 +831,7 @@ def _set_mlp_state(self, gate_up_proj_bias = hf_state_dict['gate_up_proj_bias'].load() gate_up_proj_bias = gate_up_proj_bias[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] - if args.llm_model_type == 'gpt_oss': + if config.llm_model_type == 
'gpt_oss': gate_proj_weight = gate_up_proj_weight[:, ::2] up_proj_weight = gate_up_proj_weight[:, 1::2] gate_proj_bias, up_proj_bias = gate_up_proj_bias[:, ::2], gate_up_proj_bias[:, 1::2] @@ -967,13 +967,13 @@ def _set_mlp_state(self, if isinstance(linear_fc1, LoraParallelLinear): linear_fc1 = linear_fc1.base_layer fc1_weight = [getattr(linear_fc1, f'weight{i}') for i in range(num_local_experts)] - if args.add_bias_linear: + if config.add_bias_linear: fc1_bias = [getattr(linear_fc1, f'bias{i}') for i in range(num_local_experts)] else: fc1_weight = mg_mlp.linear_fc1.weight gate_up_proj_weight, scale_inv = self._get_weight(fc1_weight, 'linear_fc1.weight', is_expert=is_expert) gate_up_proj_bias = None - if args.add_bias_linear: + if config.add_bias_linear: gate_up_proj_bias, _ = self._get_weight(fc1_bias, 'linear_fc1.bias', is_expert=is_expert) del fc1_weight if gate_up_proj_weight is not None: @@ -984,8 +984,8 @@ def _set_mlp_state(self, if 'gate_up_proj' in hf_state_dict: gate_up_proj_weight = torch.concat( [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0) - is_last_ckpt = gate_up_proj_weight.shape[0] == args.num_moe_experts - if args.llm_model_type == 'gpt_oss' and is_last_ckpt: + is_last_ckpt = gate_up_proj_weight.shape[0] == config.num_moe_experts + if config.llm_model_type == 'gpt_oss' and is_last_ckpt: gate_proj_weight, up_proj_weight = gate_up_proj_weight.chunk(2, dim=2) new_gate_up_proj_weight = torch.empty_like(gate_up_proj_weight) new_gate_up_proj_weight[..., ::2] = gate_proj_weight @@ -1004,7 +1004,7 @@ def _set_mlp_state(self, if 'gate_up_proj_bias' in hf_state_dict: gate_up_proj_bias = torch.concat( [hf_state_dict['gate_up_proj_bias'], gate_up_proj_bias], dim=0) - if args.llm_model_type == 'gpt_oss' and is_last_ckpt: + if config.llm_model_type == 'gpt_oss' and is_last_ckpt: gate_proj_bias, up_proj_bias = gate_up_proj_bias.chunk(2, dim=1) new_gate_up_proj_bias = torch.empty_like(gate_up_proj_bias) new_gate_up_proj_bias[:, ::2] = gate_proj_bias @@ -1076,7 +1076,7 @@ def _set_mlp_state(self, fc2_weight = [getattr(mg_mlp.linear_fc2, f'weight{i}') for i in range(num_local_experts)] if is_expert else mg_mlp.linear_fc2.weight fc2_bias = None - if args.add_bias_linear: + if config.add_bias_linear: fc2_bias = [getattr(mg_mlp.linear_fc2, f'bias{i}') for i in range(num_local_experts)] down_scale_inv = None if hf_grouped: @@ -1163,10 +1163,10 @@ def _set_mlp_state(self, if isinstance(linear_fc2, LoraParallelLinear): linear_fc2 = linear_fc2.base_layer fc2_weight = [getattr(linear_fc2, f'weight{i}') for i in range(num_local_experts)] - if args.add_bias_linear: + if config.add_bias_linear: fc2_bias = [getattr(linear_fc2, f'bias{i}') for i in range(num_local_experts)] down_proj_weight, scale_inv = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=is_expert) - if args.add_bias_linear: + if config.add_bias_linear: down_proj_bias, _ = self._get_weight(fc2_bias, 'linear_fc2.bias', is_expert=is_expert) del fc2_weight, fc2_bias if down_proj_weight is not None: @@ -1180,7 +1180,7 @@ def _set_mlp_state(self, if 'down_proj_scale_inv' in hf_state_dict: scale_inv = torch.concat([hf_state_dict['down_proj_scale_inv'], scale_inv], dim=0) hf_state_dict['down_proj_scale_inv'] = scale_inv.clone() - if args.add_bias_linear: + if config.add_bias_linear: if 'down_proj_bias' in hf_state_dict: down_proj_bias = torch.concat([hf_state_dict['down_proj_bias'], down_proj_bias], dim=0) @@ -1213,7 +1213,7 @@ def _set_mla_attn_state( else: hf_state_dict = {} self._set_state_dict(mg_attn, 
'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) - if self.args.q_lora_rank is None: + if self.config.q_lora_rank is None: self._set_state_dict(mg_attn, 'linear_q_proj.weight', hf_state_dict, 'q_proj.weight', to_mcore) else: self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'q_a_proj.weight', to_mcore) @@ -1221,8 +1221,8 @@ def _set_mla_attn_state( self._set_state_dict(mg_attn, 'linear_kv_down_proj.weight', hf_state_dict, 'kv_a_proj_with_mqa.weight', to_mcore) self._set_state_dict(mg_attn, 'linear_kv_up_proj.weight', hf_state_dict, 'kv_b_proj.weight', to_mcore) - if self.args.qk_layernorm: - if self.args.q_lora_rank is not None: + if self.config.qk_layernorm: + if self.config.q_lora_rank is not None: self._set_state_dict(mg_attn, 'linear_q_up_proj.layer_norm_weight', hf_state_dict, 'q_a_layernorm.weight', to_mcore) self._set_state_dict(mg_attn, 'linear_kv_up_proj.layer_norm_weight', hf_state_dict, 'kv_a_layernorm.weight', @@ -1297,7 +1297,7 @@ def _convert_post_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcor hf_state_dict = {} lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model if self.args.task_type != 'embedding': - if self.args.untie_embeddings_and_output_weights: + if self.config.untie_embeddings_and_output_weights: hf_lm_head_key = self.hf_lm_head_key if self.args.task_type == 'seq_cls': hf_lm_head_key = self.hf_score_key @@ -1515,10 +1515,11 @@ def save_weights(self, peft_config.modules_to_save = self._peft_modules_to_save peft_config.save_pretrained(output_dir) else: - if args.mtp_num_layers: - hf_config.num_nextn_predict_layers = args.mtp_num_layers - hf_config.vocab_size = args.padded_vocab_size - if args.fp8 is not None and args.fp8_recipe == 'blockwise' and args.fp8_param_gather: + config = self.config + if config.mtp_num_layers: + hf_config.num_nextn_predict_layers = config.mtp_num_layers + hf_config.vocab_size = config.padded_vocab_size + if config.fp8 is not None and config.fp8_recipe == 'blockwise' and config.fp8_param_gather: if getattr(hf_config, 'quantization_config', None) is None: from transformers.utils.quantization_config import FineGrainedFP8Config modules_to_not_convert = get_modules_to_not_convert(self.hf_model) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 7ea3b48975..9084a30db4 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -108,7 +108,8 @@ def __init__( self.attention_scaling = 1. 
new_inv_freq, self.attention_scaling = get_rope_inv_freq(config) self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) - if self.config.task_type == 'seq_cls' and self.post_process: + self.args = args = config.args + if self.args.task_type == 'seq_cls' and self.post_process: self.output_layer = OutputLayerLinear( config.hidden_size, self.config.num_labels, @@ -120,7 +121,7 @@ def __init__( skip_weight_param_allocation=False, ) self.output_layer.weight.average_gradients_across_tp_domain = True - elif self.config.task_type == 'embedding' and self.post_process: + elif self.args.task_type == 'embedding' and self.post_process: self.output_layer = None if (self.attention_scaling != 1 or config.position_embedding_type == 'mrope') and config.apply_rope_fusion: @@ -337,7 +338,7 @@ def _postprocess( """ if not self.post_process: return hidden_states - labels = labels if self.config.task_type == 'causal_lm' else None + labels = labels if self.args.task_type == 'causal_lm' else None in_inference_mode = inference_context is not None and not self.training if in_inference_mode: assert runtime_gather_output, 'Inference must always gather TP logits' @@ -430,16 +431,16 @@ def _postprocess( # (so that the output layer, which expects S×B×H, receives only the final token) hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) - if self.config.task_type in {'seq_cls', 'embedding' - } and self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1: + if self.args.task_type in {'seq_cls', 'embedding' + } and self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1: hidden_states = gather_from_sequence_parallel_region(hidden_states) - if self.config.task_type == 'embedding': + if self.args.task_type == 'embedding': logits = F.normalize(hidden_states, p=2, dim=-1) else: logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) - if self.config.task_type == 'generative_reranker': + if self.args.task_type == 'generative_reranker': logits = gather_from_tensor_model_parallel_region(logits) positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes') negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no') diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index 1db8de8734..fd95a9cd63 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -35,7 +35,7 @@ def __init__(self, post_process, *args, **kwargs) self.vp_stage = self.language_model.vp_stage self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights - self.megatron_model_meta = get_megatron_model_meta(args.hf_model_type) + self.megatron_model_meta = get_megatron_model_meta(args.model_type) self.visual = None if args.mtp_num_layers: raise ValueError('MTP currently does not support multimodal models.') diff --git a/swift/megatron/model/mm_gpts/utils.py b/swift/megatron/model/mm_gpts/utils.py index d4bc88cce3..03a68cb906 100644 --- a/swift/megatron/model/mm_gpts/utils.py +++ b/swift/megatron/model/mm_gpts/utils.py @@ -47,17 +47,18 @@ class HuggingFaceModule(_HuggingFaceModule, ABC): def __init__(self, config, ignore_init_model_cls=None): super().__init__(config) - attn_impl = getattr(config, 'attn_impl', None) or 'flash_attn' + attn_impl = getattr(args, 'attn_impl', None) or 'flash_attn' kwargs = {'attn_impl': attn_impl} if 
config.attention_backend.name == 'flash' else {} ignore_init_model_cls = ignore_init_model_cls or [] if not isinstance(ignore_init_model_cls, list): ignore_init_model_cls = [ignore_init_model_cls] context_list = [patch_device_map_meta(model_cls) for model_cls in ignore_init_model_cls] context_list.append(patch_hf_initialize_weight()) - kwargs['model_type'] = config.hf_model_type + args = config.args + kwargs['model_type'] = args.model_type with ContextManagers(context_list), disable_safe_ddp_context_use_barrier(): model, self.processor = get_model_processor( - config.model_dir, torch_dtype=config.torch_dtype, return_dummy_model=True, **kwargs) + args.model_dir, torch_dtype=args.torch_dtype, return_dummy_model=True, **kwargs) self.hf_config = model.config for hf_prefix, mg_prefix in self.module_mapping.items(): setattr(self, mg_prefix, deep_getattr(model, hf_prefix)) diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index 08a98fd45a..620d83d918 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -103,7 +103,6 @@ class MegatronModelConfig(TransformerConfig): cp_comm_type: str = 'p2p' def __post_init__(self): - self.pipeline_dtype = self.torch_dtype if self.moe_router_dtype.lower() == 'none': self.moe_router_dtype = None if self.num_moe_experts is not None: @@ -142,17 +141,18 @@ def _check_npu(self): def create_mcore_model_config(args, hf_config): # Translate args to core transformer configuration - kw_args = convert_hf_config(hf_config) + kwargs = convert_hf_config(hf_config) for f in fields(MegatronModelConfig): if hasattr(args, f.name): - kw_args[f.name] = getattr(args, f.name) - kw_args['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers - kw_args['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers - kw_args['fp8_param'] = args.fp8_param_gather - kw_args['inference_sampling_seed'] = args.seed - swiglu = kw_args.get('swiglu', True) - kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion if swiglu else args.bias_gelu_fusion - config = MegatronModelConfig(**kw_args) + kwargs[f.name] = getattr(args, f.name) + kwargs['pipeline_dtype'] = args.torch_dtype + kwargs['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers + kwargs['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers + kwargs['fp8_param'] = args.fp8_param_gather + kwargs['inference_sampling_seed'] = args.seed + swiglu = kwargs.get('swiglu', True) + kwargs['bias_activation_fusion'] = args.bias_swiglu_fusion if swiglu else args.bias_gelu_fusion + config = MegatronModelConfig(**kwargs) config.hf_config = hf_config config.args = args return config diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 4c71f307d3..49f4fb9686 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -121,24 +121,26 @@ def _set_shared_expert_gate(self, transformer_layer_spec): if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} - def create_model_and_load( + def build_model( self, pre_process=True, post_process=True, vp_stage: Optional[int] = None, + load_weights: bool = True, ) -> Union['GPTModel', 'MultimodalGPTModel']: transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) self._set_shared_expert_gate(transformer_layer_spec) mtp_block_spec = None if self.args.mtp_num_layers is not None: 
mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) - model = self._create_model( + model = self._init_model( transformer_layer_spec, mtp_block_spec, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) - self.bridge.load_weights(model, self.args.model_dir) + if load_weights: + self.bridge.load_weights(model, self.args.model_dir) return model def _init_config(self): @@ -153,12 +155,12 @@ def _init_config(self): config.apply_rope_fusion = True logger.info(f'Setting config.apply_rope_fusion: {config.apply_rope_fusion}.') - def _create_model(self, - transformer_layer_spec, - mtp_block_spec, - pre_process=True, - post_process=True, - vp_stage: Optional[int] = None): + def _init_model(self, + transformer_layer_spec, + mtp_block_spec, + pre_process=True, + post_process=True, + vp_stage: Optional[int] = None): return self.model_cls( config=self.config, transformer_layer_spec=transformer_layer_spec, @@ -172,6 +174,7 @@ def _create_model(self, def get_mcore_model( args, hf_config, + load_weights: bool = True, ): loader = args.megatron_model_meta.loader(args, hf_config) if (mpu.get_pipeline_model_parallel_world_size() > 1 and args.virtual_pipeline_model_parallel_size is not None): @@ -179,10 +182,10 @@ def get_mcore_model( for i in range(args.virtual_pipeline_model_parallel_size): pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) - model = loader.create_model_and_load(pre_process, post_process, vp_stage=i) + model = loader.build_model(pre_process, post_process, vp_stage=i, load_weights=load_weights) models.append(model) else: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() - models = [loader.create_model_and_load(pre_process=pre_process, post_process=post_process)] + models = [loader.build_model(pre_process=pre_process, post_process=post_process, load_weights=load_weights)] return models diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index b856e2f16b..7ecb623cb7 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -166,7 +166,7 @@ def _load_teacher_model(self, teacher_model_path: str, model_type: str): # Apply teacher config to global Megatron args for key, value in teacher_megatron_config.items(): setattr(megatron_args, key, value) - megatron_args.hf_model_type = teacher_model_type + megatron_args.model_type = teacher_model_type megatron_args.model_dir = teacher_model_info.model_dir # Reset MoE-related keys that are not in teacher config to None. 
diff --git a/swift/megatron/utils/convert_utils.py b/swift/megatron/utils/convert_utils.py index c4a51a3842..de0a291038 100644 --- a/swift/megatron/utils/convert_utils.py +++ b/swift/megatron/utils/convert_utils.py @@ -190,7 +190,7 @@ def test_convert_precision(args, hf_model, mg_model, template, test_convert_dtyp _param = next(mg_language_model.parameters()) mg_dtype = _param.dtype mg_device = _param.device - if args.hf_model_type == 'minimax_m2': + if args.model_type == 'minimax_m2': # router to bfloat16 for n, m in mg_language_model.named_modules(): if n.endswith('router'): diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 85bdbfcba2..fff8147b2d 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -57,7 +57,7 @@ def get_multimodal_target_regex( include_router: bool = False, ) -> str: from ..model import get_megatron_model_meta - megatron_model_meta = get_megatron_model_meta(args.hf_model_type) + megatron_model_meta = get_megatron_model_meta(args.model_type) modules = [] visual_cls = megatron_model_meta.visual_cls vision_tower = [f'visual.{vit}' for vit in visual_cls._vision_tower] @@ -192,7 +192,7 @@ def prepare_adapter(args, model): for m in model.modules(): if isinstance(m, LoraLinear): # just check - assert args.is_multimodal or args.hf_model_type == 'qwen3_next' + assert args.is_multimodal or args.model_type == 'qwen3_next' assert not isinstance(m, LoraParallelLinear) for p in m.parameters(): if p.requires_grad: From 69b8c2544a998097a444872ff78aca7a167f7005 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 16:03:11 +0800 Subject: [PATCH 26/43] update --- swift/megatron/arguments/megatron_args.py | 3 ++- .../megatron/arguments/megatron_base_args.py | 20 +------------- swift/megatron/init.py | 3 ++- swift/megatron/model/gpt_bridge.py | 27 ++++++++++++------- swift/megatron/model/model_config.py | 9 +++++++ swift/megatron/model/register.py | 7 ++++- swift/megatron/pipelines/export/export.py | 26 +++--------------- swift/megatron/pipelines/train/sft.py | 1 - swift/megatron/utils/utils.py | 15 ++++++++--- swift/model/register.py | 19 +++++++------ 10 files changed, 62 insertions(+), 68 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index fe99cd0c8f..07b77dc2b3 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -573,7 +573,8 @@ def __post_init__(self): self.model_dir = self.model_info.model_dir self.is_multimodal = self.model_meta.is_multimodal self.megatron_model_meta = get_megatron_model_meta(self.model_type) - assert self.megatron_model_meta is not None, f'Model: {args.model} is not supported.' 
+ if self.megatron_model_meta is None: + raise ValueError(f'Model: {self.model} is not supported.') if self.apply_wd_to_qk_layernorm and self.model_type != 'qwen3_next': raise ValueError('apply_wd_to_qk_layernorm is only supported for qwen3_next') if self.pipeline_model_parallel_size == 1 and (self.decoder_first_pipeline_num_layers is not None diff --git a/swift/megatron/arguments/megatron_base_args.py b/swift/megatron/arguments/megatron_base_args.py index 44cebac027..d122e000a8 100644 --- a/swift/megatron/arguments/megatron_base_args.py +++ b/swift/megatron/arguments/megatron_base_args.py @@ -18,27 +18,9 @@ def __post_init__(self): if self.packing: self.padding_free = True BaseArguments.__post_init__(self) - self.megatron_model_meta = get_megatron_model_meta(self.model_type) - assert self.megatron_model_meta is not None, f'Model: {self.model} is not supported.' - self.seq_length = self.seq_length or self.packing_length or self.max_length + MegatronArguments.__post_init__(self) if self.streaming: self.dataloader_type = 'external' if self.num_workers > 1: self.num_workers = 1 logger.info('Using streaming dataset, setting args.num_workers to 1.') - - def init_model_args(self, tokenizer, config): - if self.task_type == 'seq_cls': - self.problem_type = self.problem_type or getattr(config, 'problem_type', None) - logger.info(f'args.problem_type: {self.problem_type}') - kwargs = convert_hf_config(config) - if tokenizer is not None and self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer): - kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128 - self.initialize_embedding = True - if self.task_type == 'seq_cls': - self.initialize_embedding = True - logger.info(f'megatron_config: {kwargs}') - for k, v in kwargs.items(): - if getattr(self, k) is None: - setattr(self, k, v) - MegatronArguments.__post_init__(self) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 2a4ced5bda..2bc418a812 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -532,7 +532,8 @@ def forward(self, *_args, **kwargs): if not mcore_013: return _origin_forward(self, *_args, **kwargs) hidden_states, context = self._forward_attention(*_args, **kwargs) - mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs + args = self.config.args + mlp_padding_free = args.mlp_padding_free and 'attention_mask' in kwargs mask = None if mlp_padding_free and hidden_states.shape[1] > 1: mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index dc5f9a0e1b..16bcae7cf6 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -24,6 +24,10 @@ mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +EP_PP_SIZE = None +EP_PP_GROUP = None +EP_PP_RANK = None + # Some ideas for LoRA conversion are referenced from: https://github.com/modelscope/ms-swift/pull/6225 class GPTBridge: @@ -87,15 +91,20 @@ def __init__(self, args, disable_tqmd: bool = False): rank_offset=0, ) rank = dist.get_rank() - for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'): - group = mpu.create_group( - ranks, - group_desc='EP-PP-GROUP', - ) - if rank in ranks: - self.ep_pp_size = self.ep_size * self.pp_size - self.ep_pp_group = group - self.ep_pp_rank = dist.get_rank(group) + global EP_PP_GROUP, EP_PP_RANK, EP_PP_SIZE + if EP_PP_GROUP is None: + for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'): + group = mpu.create_group( + 
ranks, + group_desc='EP-PP-GROUP', + ) + if rank in ranks: + EP_PP_SIZE = self.ep_size * self.pp_size + EP_PP_GROUP = group + EP_PP_RANK = dist.get_rank(group) + self.ep_pp_size = EP_PP_SIZE + self.ep_pp_group = EP_PP_GROUP + self.ep_pp_rank = EP_PP_RANK def get_hf_mlp_prefix(self, layer_idx): if hasattr(self.hf_layers[layer_idx], 'feed_forward'): diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index 620d83d918..d0422ce59f 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -16,6 +16,10 @@ @dataclass class MegatronModelConfig(TransformerConfig): + """ + During Megatron training, multiple models may be created. This class is used to + distinguish the configurations of different models. + """ hf_model_type: Optional[str] = None llm_model_type: Optional[str] = None padded_vocab_size: Optional[int] = None @@ -145,6 +149,11 @@ def create_mcore_model_config(args, hf_config): for f in fields(MegatronModelConfig): if hasattr(args, f.name): kwargs[f.name] = getattr(args, f.name) + + if args.task_type == 'seq_cls': + args.problem_type = args.problem_type or getattr(hf_config, 'problem_type', None) + logger.info(f'args.problem_type: {args.problem_type}') + kwargs['pipeline_dtype'] = args.torch_dtype kwargs['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers kwargs['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 49f4fb9686..04226d119c 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -140,7 +140,12 @@ def build_model( post_process=post_process, vp_stage=vp_stage) if load_weights: - self.bridge.load_weights(model, self.args.model_dir) + if self.args.load is not None: + load_mcore_checkpoint([mg_model], None, None, strict=True) + elif self.args.model is not None: + self.bridge.load_weights(model, self.args.model_dir) + else: + raise ValueError('Please specify `--load` or `--model`.') return model def _init_config(self): diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index ae943c3e00..3b89abea11 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -9,6 +9,7 @@ from swift.megatron.arguments import MegatronExportArguments from swift.megatron.convert import test_convert_precision +from swift.megatron.model import get_mcore_model from swift.megatron.utils import (adapter_state_dict_context, initialize_megatron, load_mcore_checkpoint, prepare_mcore_model, save_mcore_checkpoint) from swift.pipelines import SwiftPipeline, prepare_model_template @@ -35,7 +36,6 @@ def convert_mcore2hf(self) -> None: _, template = prepare_model_template(args, load_model=False, download_model=download_model) self.processor = template.processor hf_config = self.processor.model_info.config - args.init_model_args(self.tokenizer, hf_config) megatron_model_meta = args.megatron_model_meta pre_process = mpu.is_pipeline_first_stage() @@ -86,33 +86,15 @@ def convert_hf2mcore(self) -> None: download_model = args.model is not None _, template = prepare_model_template(args, load_model=False, download_model=download_model) self.processor = template.processor - args.init_model_args(self.tokenizer, self.processor.model_info.config) - megatron_model_meta = args.megatron_model_meta - - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - mg_model 
= megatron_model_meta.model_provider(args, pre_process=pre_process, post_process=post_process) + hf_config = self.processor.model_info.config + mg_model = get_mcore_model(args, hf_config)[0] logger.info('Megatron model created successfully.') - bridge = megatron_model_meta.bridge_cls(args) - if args.model is not None: - bridge.load_weights(mg_model, args.model_info.model_dir) - elif args.load is not None: - load_mcore_checkpoint([mg_model], None, None, strict=True) - else: - raise ValueError('Please specify `--load` or `--model`.') dist.barrier() if args.adapters or args.adapter_load is not None: - peft_model = prepare_mcore_model(mg_model) - if args.adapters: - assert len(args.adapters) == 1, 'Currently only support one adapter' - bridge.load_weights(mg_model, args.adapters[0], is_peft_format=True) - elif args.adapter_load is not None: - with adapter_state_dict_context(): - load_mcore_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) + peft_model = prepare_mcore_model(mg_model, load_adapters=True) if args.merge_lora: logger.info('Merge LoRA...') mg_model = peft_model.merge_and_unload() - logger.info('Successfully transferred HF model weights to MG model.') _test_convert_precision = strtobool(os.getenv('SWIFT_TEST_CONVERT_PRECISION', '0')) if not _test_convert_precision: args.save_args(args.save) diff --git a/swift/megatron/pipelines/train/sft.py b/swift/megatron/pipelines/train/sft.py index 0f0dbcb8e2..52ce25cac1 100644 --- a/swift/megatron/pipelines/train/sft.py +++ b/swift/megatron/pipelines/train/sft.py @@ -54,7 +54,6 @@ def __init__(self, args: Optional[Union[List[str], MegatronSftArguments]] = None with torch.device('meta'): self.model, self.processor = args.get_model_processor(**kwargs, download_model=args.load is None) self._prepare_template() - args.init_model_args(self.tokenizer, self.processor.model_info.config) args.save_args(args.save) self.template.use_megatron = True self.trainer = self.prepare_trainer() diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index fff8147b2d..7250f4d9aa 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -200,7 +200,7 @@ def prepare_adapter(args, model): return model -def prepare_mcore_model(args, model): +def prepare_mcore_model(args, model, load_adapters=True): if args.tuner_type == 'full': freeze_parameters(model, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) if args.trainable_parameters or args.trainable_parameters_regex: @@ -208,6 +208,12 @@ def prepare_mcore_model(args, model): elif args.tuner_type == 'lora': model.prepare_inputs_for_generation = None # fix error model = prepare_adapter(model) + if args.adapters: + assert len(args.adapters) == 1, 'Currently only support one adapter' + bridge.load_weights(mg_model, args.adapters[0], is_peft_format=True) + elif args.adapter_load is not None: + with adapter_state_dict_context(): + load_mcore_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) logger.info(f'model: {model}') logger.info_if( f'[rank{dist.get_rank()}] model_parameter_info: {get_model_parameter_info(model)}', @@ -300,14 +306,15 @@ def copy_ref_adapter_weight(model, ref_adapter_name: str): def forward_step_helper(args, model, inputs, dtype=None): + config = model.config if mpu.is_pipeline_first_stage(): micro_batch_size = 1 # use qkv_format 'thd' if not args.padding_free: micro_batch_size = args.micro_batch_size seq_length = inputs['position_ids'].shape[-1] - if args.sequence_parallel: + if 
config.sequence_parallel: seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], + recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, config.hidden_size], device=torch.cuda.current_device(), dtype=torch.int64) else: @@ -318,7 +325,7 @@ def forward_step_helper(args, model, inputs, dtype=None): shape = recv_shape_buffer.tolist() if not mpu.is_pipeline_first_stage(): - dtype = dtype or args.params_dtype + dtype = dtype or config.params_dtype recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype) recv_from_prev_pipeline_rank_(recv_buffer) model.set_input_tensor(recv_buffer) diff --git a/swift/model/register.py b/swift/model/register.py index c2fc230b37..1129d985b7 100644 --- a/swift/model/register.py +++ b/swift/model/register.py @@ -343,22 +343,21 @@ def _postprocess_model(self, model_dir, model): self._init_generation_config(model, model_dir) HfConfigFactory.set_model_config_attr(model, 'pad_token_id', self.pad_token) - def _add_new_special_tokens(self, model, processor): + def _add_new_special_tokens(self, model, processor, config): if not self.new_special_tokens: return tokenizer = self._get_tokenizer(processor) num_new_tokens = tokenizer.add_special_tokens({'additional_special_tokens': self.new_special_tokens}) if num_new_tokens > 0: logger.info(f'Added {num_new_tokens} new special tokens.') - - if model is not None and not self.return_dummy_model: - llm_model = get_lm_head_model(model, self.model_meta) - origin_vocab_size = HfConfigFactory.get_config_attr(llm_model.config, 'vocab_size') - if origin_vocab_size < len(tokenizer): - vocab_size = math.ceil(len(tokenizer) / 128) * 128 + origin_vocab_size = HfConfigFactory.get_config_attr(config, 'vocab_size') + if origin_vocab_size < len(tokenizer): + vocab_size = math.ceil(len(tokenizer) / 128) * 128 + # fix transformers==4.52.4 qwen2.5-vl + HfConfigFactory.set_config_attr(config, 'vocab_size', vocab_size) + if model is not None and not self.return_dummy_model: + llm_model = get_lm_head_model(model, self.model_meta) llm_model.resize_token_embeddings(vocab_size) - # fix transformers==4.52.4 qwen2.5-vl - HfConfigFactory.set_config_attr(llm_model.config, 'vocab_size', vocab_size) def _postprocess_processor(self, processor: Processor): tokenizer = self._get_tokenizer(processor) @@ -457,7 +456,7 @@ def load(self) -> Tuple[Optional[PreTrainedModel], Processor]: self._postprocess_processor(processor) if model: self._postprocess_model(model_dir, model) - self._add_new_special_tokens(model, processor) + self._add_new_special_tokens(model, processor, config) return model, processor From 0a2d2d6e91147c4c4fa370434cf492bfacb17e82 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 16:25:33 +0800 Subject: [PATCH 27/43] update --- swift/megatron/utils/megatron_lm_utils.py | 2 +- swift/megatron/utils/utils.py | 23 ++++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 3bf090bbb3..d26f20ffc3 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -218,7 +218,7 @@ def load_mcore_checkpoint(args, model, load_arg: str = 'load'): is_peft_format = True elif load_arg in {'load', 'ref_load'}: is_peft_format = False - model = unwrap_model(model) + model = [unwrap_model(m) for m in model] tracker_path = os.path.join(args.load, 
'latest_checkpointed_iteration.txt') iteration = _load_iteration(tracker_path) checkpoint_dir = os.path.join(args.load, f'iter_{iteration:07d}') diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 7250f4d9aa..06860e9f65 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -201,19 +201,24 @@ def prepare_adapter(args, model): def prepare_mcore_model(args, model, load_adapters=True): - if args.tuner_type == 'full': - freeze_parameters(model, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) - if args.trainable_parameters or args.trainable_parameters_regex: - activate_parameters(model, args.trainable_parameters, args.trainable_parameters_regex) - elif args.tuner_type == 'lora': - model.prepare_inputs_for_generation = None # fix error - model = prepare_adapter(model) + from .megatron_lm_utils import load_mcore_checkpoint + for m in model: + if args.tuner_type == 'full': + freeze_parameters(m, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) + if args.trainable_parameters or args.trainable_parameters_regex: + activate_parameters(m, args.trainable_parameters, args.trainable_parameters_regex) + elif args.tuner_type == 'lora': + m.prepare_inputs_for_generation = None # fix error + m = prepare_adapter(args, m) + if load_adapters: if args.adapters: + bridge = args.megatron_model_meta.loader.bridge_cls(args) assert len(args.adapters) == 1, 'Currently only support one adapter' - bridge.load_weights(mg_model, args.adapters[0], is_peft_format=True) + for m in model: + bridge.load_weights(m, args.adapters[0], is_peft_format=True) elif args.adapter_load is not None: with adapter_state_dict_context(): - load_mcore_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) + load_mcore_checkpoint(args, model, load_arg='adapter_load') logger.info(f'model: {model}') logger.info_if( f'[rank{dist.get_rank()}] model_parameter_info: {get_model_parameter_info(model)}', From 6bacd38d15ed617d7e3c441d222c13b5d3d77f64 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 17:07:02 +0800 Subject: [PATCH 28/43] update --- swift/megatron/model/gpts/glm4.py | 17 +++--- swift/megatron/model/gpts/minimax_m2.py | 17 +++--- swift/megatron/model/gpts/olmoe.py | 13 +++-- swift/megatron/model/mm_gpts/glm.py | 43 +++++++------- swift/megatron/model/mm_gpts/internvl.py | 16 ++---- swift/megatron/model/mm_gpts/kimi_vl.py | 21 +++---- swift/megatron/model/mm_gpts/llama4.py | 19 ++++--- swift/megatron/model/mm_gpts/qwen.py | 71 ++++++++++-------------- swift/megatron/model/mm_gpts/qwen3_vl.py | 16 ++---- swift/megatron/model/register.py | 20 ++----- swift/megatron/trainers/base.py | 6 +- swift/megatron/utils/utils.py | 32 ++++------- 12 files changed, 122 insertions(+), 169 deletions(-) diff --git a/swift/megatron/model/gpts/glm4.py b/swift/megatron/model/gpts/glm4.py index 676d071ba7..548a0821e3 100644 --- a/swift/megatron/model/gpts/glm4.py +++ b/swift/megatron/model/gpts/glm4.py @@ -94,7 +94,6 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo class Glm4Loader(MegatronModelLoader): - bridge_cls = Glm4Bridge def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): layer_spec = self._get_transformer_layer_spec() @@ -104,10 +103,12 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): return layer_spec -register_megatron_model(MegatronModelMeta( - MegatronModelType.glm4, - [ - ModelType.glm4, - ], - Glm4Loader, -)) 
+register_megatron_model( + MegatronModelMeta( + MegatronModelType.glm4, + [ + ModelType.glm4, + ], + bridge_cls=Glm4Bridge, + loader=Glm4Loader, + )) diff --git a/swift/megatron/model/gpts/minimax_m2.py b/swift/megatron/model/gpts/minimax_m2.py index fc0b3f99c8..9fd4b739a8 100644 --- a/swift/megatron/model/gpts/minimax_m2.py +++ b/swift/megatron/model/gpts/minimax_m2.py @@ -102,7 +102,6 @@ def _set_moe_state( class MinimaxM2Loader(MegatronModelLoader): - bridge_cls = MinimaxM2Bridge def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): layer_spec = self._get_transformer_layer_spec() @@ -110,10 +109,12 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): return layer_spec -register_megatron_model(MegatronModelMeta( - MegatronModelType.minimax_m2, - [ - ModelType.minimax_m2, - ], - MinimaxM2Loader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.minimax_m2, + [ + ModelType.minimax_m2, + ], + bridge_cls=MinimaxM2Bridge, + loader=MinimaxM2Loader, + )) diff --git a/swift/megatron/model/gpts/olmoe.py b/swift/megatron/model/gpts/olmoe.py index d60812b6b0..4668c28a5d 100644 --- a/swift/megatron/model/gpts/olmoe.py +++ b/swift/megatron/model/gpts/olmoe.py @@ -218,14 +218,15 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int class OlMoELoader(MegatronModelLoader): - bridge_cls = OLMoEBridge def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): return get_olmoe_decoder_block_spec(self.config, vp_stage) -register_megatron_model(MegatronModelMeta( - MegatronModelType.olmoe, - [ModelType.olmoe], - OlMoELoader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.olmoe, + [ModelType.olmoe], + bridge_cls=OLMoEBridge, + loader=OlMoELoader, + )) diff --git a/swift/megatron/model/mm_gpts/glm.py b/swift/megatron/model/mm_gpts/glm.py index 05c9a394c4..48ea3cb6e6 100644 --- a/swift/megatron/model/mm_gpts/glm.py +++ b/swift/megatron/model/mm_gpts/glm.py @@ -21,33 +21,28 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.hf_config) -class Glm4vMoeLoader(MegatronModelLoader): - bridge_cls = MultimodalGPTBridge - visual_cls = Glm4vVit - - -register_megatron_model(MegatronModelMeta( - MegatronModelType.glm4v_moe, - [ - ModelType.glm4v_moe, - ], - Glm4vMoeLoader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.glm4v_moe, + [ + ModelType.glm4v_moe, + ], + bridge_cls=MultimodalGPTBridge, + visual_cls=Glm4vVit, + )) class Glm4vBridge(Glm4Bridge, MultimodalGPTBridge): pass -class Glm4vLoader(Glm4Loader): - bridge_cls = Glm4vBridge - visual_cls = Glm4vVit - - -register_megatron_model(MegatronModelMeta( - MegatronModelType.glm4v, - [ - ModelType.glm4v, - ], - Glm4vLoader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.glm4v, + [ + ModelType.glm4v, + ], + bridge_cls=Glm4vBridge, + visual_cls=Glm4vVit, + loader=Glm4Loader, + )) diff --git a/swift/megatron/model/mm_gpts/internvl.py b/swift/megatron/model/mm_gpts/internvl.py index 469538b5dd..d7a637f526 100644 --- a/swift/megatron/model/mm_gpts/internvl.py +++ b/swift/megatron/model/mm_gpts/internvl.py @@ -59,11 +59,6 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -class InternvlLoader(MegatronModelLoader): - bridge_cls = Internvl3Bridge - visual_cls = Internvl3Vit - - register_megatron_model( MegatronModelMeta( MegatronModelType.internvl3, @@ -72,7 +67,8 @@ class 
InternvlLoader(MegatronModelLoader): ModelType.internvl3_5, ModelType.internvl3_5_gpt, ], - InternvlLoader, + bridge_cls=Internvl3Bridge, + visual_cls=Internvl3Vit, )) @@ -133,11 +129,6 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -class InternvlHfLoader(MegatronModelLoader): - bridge_cls = InternvlHfBridge - visual_cls = InternvlHfVit - - register_megatron_model( MegatronModelMeta( MegatronModelType.internvl_hf, @@ -145,5 +136,6 @@ class InternvlHfLoader(MegatronModelLoader): ModelType.internvl_hf, ModelType.internvl_gpt_hf, ], - InternvlHfLoader, + bridge_cls=InternvlHfBridge, + visual_cls=InternvlHfVit, )) diff --git a/swift/megatron/model/mm_gpts/kimi_vl.py b/swift/megatron/model/mm_gpts/kimi_vl.py index 888eb0888a..cce3d39024 100644 --- a/swift/megatron/model/mm_gpts/kimi_vl.py +++ b/swift/megatron/model/mm_gpts/kimi_vl.py @@ -47,15 +47,12 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -class KimiLoader(MegatronModelLoader): - bridge_cls = KimiVLBridge - visual_cls = KimiVLVit - - -register_megatron_model(MegatronModelMeta( - MegatronModelType.kimi_vl, - [ - ModelType.kimi_vl, - ], - KimiLoader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.kimi_vl, + [ + ModelType.kimi_vl, + ], + bridge_cls=KimiVLBridge, + visual_cls=KimiVLVit, + )) diff --git a/swift/megatron/model/mm_gpts/llama4.py b/swift/megatron/model/mm_gpts/llama4.py index af66272bed..cd9a5329ae 100644 --- a/swift/megatron/model/mm_gpts/llama4.py +++ b/swift/megatron/model/mm_gpts/llama4.py @@ -59,8 +59,6 @@ class Llama4Bridge(GPTBridge): class Llama4Loader(MegatronModelLoader): - bridge_cls = Llama4Bridge - visual_cls = Llama4Vit def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): layer_specs = super().get_transformer_layer_spec(vp_stage) @@ -75,10 +73,13 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): return layer_specs -register_megatron_model(MegatronModelMeta( - MegatronModelType.llama4, - [ - ModelType.llama4, - ], - Llama4Loader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.llama4, + [ + ModelType.llama4, + ], + bridge_cls=Llama4Bridge, + visual_cls=Llama4Vit, + loader=Llama4Loader, + )) diff --git a/swift/megatron/model/mm_gpts/qwen.py b/swift/megatron/model/mm_gpts/qwen.py index de0bc5b6fc..0149a2763f 100644 --- a/swift/megatron/model/mm_gpts/qwen.py +++ b/swift/megatron/model/mm_gpts/qwen.py @@ -46,36 +46,30 @@ class Qwen2_5VLBridge(MultimodalGPTBridge): } -class Qwen2_5VLLoader(MegatronModelLoader): - bridge_cls = Qwen2_5VLBridge - visual_cls = Qwen2_5VL_Vit - - -register_megatron_model(MegatronModelMeta( - MegatronModelType.qwen2_5_vl, - [ - ModelType.qwen2_5_vl, - ], - Qwen2_5VLLoader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.qwen2_5_vl, + [ + ModelType.qwen2_5_vl, + ], + bridge_cls=Qwen2_5VLBridge, + visual_cls=Qwen2_5VL_Vit, + )) class Qwen2VL_Vit(Qwen2_5VL_Vit): version = 'v2' -class Qwen2VLLoader(Qwen2_5VLLoader): - bridge_cls = Qwen2_5VLBridge - visual_cls = Qwen2VL_Vit - - -register_megatron_model(MegatronModelMeta( - MegatronModelType.qwen2_vl, - [ - ModelType.qwen2_vl, - ], - Qwen2VLLoader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.qwen2_vl, + [ + ModelType.qwen2_vl, + ], + bridge_cls=Qwen2_5VLBridge, + visual_cls=Qwen2VLLoader, + )) class Qwen2_5OmniBridge(GPTBridge): @@ -122,18 +116,14 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -class 
Qwen2_5OmniLoader(MegatronModelLoader): - bridge_cls = Qwen2_5OmniBridge, - visual_cls = Qwen2_5Omni_Vit - - register_megatron_model( MegatronModelMeta( MegatronModelType.qwen2_5_omni, [ ModelType.qwen2_5_omni, ], - Qwen2_5OmniLoader, + bridge_cls=Qwen2_5OmniBridge, + visual_cls=Qwen2_5Omni_Vit, )) @@ -184,15 +174,12 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -class Ovis2_5Loader(MegatronModelLoader): - bridge_cls = Ovis2_5Bridge - visual_cls = Ovis2_5Vit - - -register_megatron_model(MegatronModelMeta( - MegatronModelType.ovis2_5, - [ - ModelType.ovis2_5, - ], - Ovis2_5Loader, -)) +register_megatron_model( + MegatronModelMeta( + MegatronModelType.ovis2_5, + [ + ModelType.ovis2_5, + ], + bridge_cls=Ovis2_5Bridge, + visual_cls=Ovis2_5Vit, + )) diff --git a/swift/megatron/model/mm_gpts/qwen3_vl.py b/swift/megatron/model/mm_gpts/qwen3_vl.py index c9cdd60ca7..c3de699275 100644 --- a/swift/megatron/model/mm_gpts/qwen3_vl.py +++ b/swift/megatron/model/mm_gpts/qwen3_vl.py @@ -476,8 +476,6 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): class Qwen3VLLoader(MegatronModelLoader): - bridge_cls = MultimodalGPTBridge - visual_cls = Qwen3VL_Vit def _patch_transformer_block(self): if hasattr(gpt_model, 'OriginTransformerBlock'): @@ -499,20 +497,18 @@ def __init__(self, args, hf_config): ModelType.qwen3_vl_emb, ModelType.qwen3_vl_reranker, ], - Qwen3VLLoader, + bridge_cls=MultimodalGPTBridge, + visual_cls=Qwen3VL_Vit, + loader=Qwen3VLLoader, )) - -class Qwen3OmniLoader(Qwen3VLLoader): - bridge_cls = Qwen3OmniBridge - visual_cls = Qwen3Omni_Vit - - register_megatron_model( MegatronModelMeta( MegatronModelType.qwen3_omni, [ ModelType.qwen3_omni_moe, ], - Qwen3OmniLoader, + bridge_cls=Qwen3OmniBridge, + visual_cls=Qwen3Omni_Vit, + loader=Qwen3VLLoader, )) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 04226d119c..de8b3f1851 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -30,8 +30,10 @@ class MegatronModelMeta: megatron_model_type: str model_types: List[str] - loader: Optional[Type['MegatronModelLoader']] = None + bridge_cls: Type[GPTBridge] = GPTBridge + visual_cls: Optional[Type[nn.Module]] = None is_multimodal: bool = False + loader: Optional[Type['MegatronModelLoader']] = None def __post_init__(self): if self.megatron_model_type in MLLMMegatronModelType.__dict__: @@ -67,8 +69,6 @@ def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]: class MegatronModelLoader: model_cls = None - visual_cls = None - bridge_cls = GPTBridge def __init__(self, args, hf_config): from swift.megatron.model import GPTModel, MultimodalGPTModel @@ -79,7 +79,6 @@ def __init__(self, args, hf_config): if self.model_cls is None: self.model_cls = MultimodalGPTModel if self.args.is_multimodal else GPTModel self._init_config() - self.bridge = self.bridge_cls(args) def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): if self.config.num_moe_experts: @@ -126,7 +125,6 @@ def build_model( pre_process=True, post_process=True, vp_stage: Optional[int] = None, - load_weights: bool = True, ) -> Union['GPTModel', 'MultimodalGPTModel']: transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) self._set_shared_expert_gate(transformer_layer_spec) @@ -139,13 +137,6 @@ def build_model( pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) - if load_weights: - if self.args.load is not None: - load_mcore_checkpoint([mg_model], None, None, strict=True) - elif 
self.args.model is not None: - self.bridge.load_weights(model, self.args.model_dir) - else: - raise ValueError('Please specify `--load` or `--model`.') return model def _init_config(self): @@ -179,7 +170,6 @@ def _init_model(self, def get_mcore_model( args, hf_config, - load_weights: bool = True, ): loader = args.megatron_model_meta.loader(args, hf_config) if (mpu.get_pipeline_model_parallel_world_size() > 1 and args.virtual_pipeline_model_parallel_size is not None): @@ -187,10 +177,10 @@ def get_mcore_model( for i in range(args.virtual_pipeline_model_parallel_size): pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) - model = loader.build_model(pre_process, post_process, vp_stage=i, load_weights=load_weights) + model = loader.build_model(pre_process, post_process, vp_stage=i) models.append(model) else: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() - models = [loader.build_model(pre_process=pre_process, post_process=post_process, load_weights=load_weights)] + models = [loader.build_model(pre_process=pre_process, post_process=post_process)] return models diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 5cc698d411..30c562a09e 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -37,6 +37,7 @@ from packaging import version from tqdm.auto import tqdm +from swift.megatron.model import get_mcore_model from swift.megatron.tuners import LoraParallelLinear from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, get_padding_to, load_mcore_checkpoint, patch_merge_fn, prepare_mcore_model) @@ -66,9 +67,10 @@ class BaseMegatronTrainer(ABC): def __init__(self, args, template: Template): self.args = args self.template = template - self.unwrapped_models = [] + hf_config = template.config + self.unwrapped_models = get_mcore_model(args, hf_config) + self.peft_models = prepare_mcore_model(args, self.unwrapped_models) self.wrapped_models = [] - self.peft_models = [] self._bridge = None self.eval_metrics = None logging_path = os.path.join(args.save, 'logging.jsonl') diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 06860e9f65..3ada1c516d 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -200,30 +200,20 @@ def prepare_adapter(args, model): return model -def prepare_mcore_model(args, model, load_adapters=True): +def prepare_mcore_model(args, model): from .megatron_lm_utils import load_mcore_checkpoint - for m in model: - if args.tuner_type == 'full': - freeze_parameters(m, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) - if args.trainable_parameters or args.trainable_parameters_regex: - activate_parameters(m, args.trainable_parameters, args.trainable_parameters_regex) - elif args.tuner_type == 'lora': - m.prepare_inputs_for_generation = None # fix error - m = prepare_adapter(args, m) - if load_adapters: - if args.adapters: - bridge = args.megatron_model_meta.loader.bridge_cls(args) - assert len(args.adapters) == 1, 'Currently only support one adapter' - for m in model: - bridge.load_weights(m, args.adapters[0], is_peft_format=True) - elif args.adapter_load is not None: - with adapter_state_dict_context(): - load_mcore_checkpoint(args, model, load_arg='adapter_load') - logger.info(f'model: {model}') + if args.tuner_type == 'full': + freeze_parameters(m, 
args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) + if args.trainable_parameters or args.trainable_parameters_regex: + activate_parameters(m, args.trainable_parameters, args.trainable_parameters_regex) + elif args.tuner_type == 'lora': + m.prepare_inputs_for_generation = None # fix error + m = prepare_adapter(args, m) + logger.info(f'model: {m}') logger.info_if( - f'[rank{dist.get_rank()}] model_parameter_info: {get_model_parameter_info(model)}', + f'[rank{dist.get_rank()}] model_parameter_info: {get_model_parameter_info(m)}', cond=mpu.get_data_parallel_rank() == 0) - return model + return model @contextmanager From 12adf2a8957b389ffeec39b8928a2e16f7b1dd93 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 17:31:22 +0800 Subject: [PATCH 29/43] fix --- swift/megatron/convert.py | 9 ++++++--- swift/megatron/model/mm_gpts/qwen.py | 2 +- swift/megatron/model/register.py | 2 +- swift/megatron/utils/utils.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 249f4ac75c..a82e44f2b2 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -50,11 +50,15 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = get_mcore_model(megatron_args, hf_config)[0] logger.info('Megatron model created successfully.') + bridge = megatron_args.megatron_model_meta.bridge_cls(megatron_args) + bridge.load_weights(mg_model, args.model_info.model_dir) + logger.info('Successfully transferred HF model weights to MG model.') _test_convert_precision = strtobool(os.getenv('SWIFT_TEST_CONVERT_PRECISION', '0')) if not _test_convert_precision: args.save_args() logger.info('Saving the model...') save_mcore_checkpoint(megatron_args, [mg_model]) + logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') # Place it at the end to avoid test_convert_precision affecting precision. 
if args.test_convert_precision: test_convert_precision(megatron_args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype) @@ -82,19 +86,18 @@ def convert_mcore2hf(args: ExportArguments) -> None: save=args.output_dir if args.to_mcore else None, torch_dtype=args.torch_dtype) - mg_model = megatron_model_meta.model_provider(megatron_args) + mg_model = get_mcore_model(megatron_args, hf_config)[0] if megatron_args.load is None: raise ValueError('Please specify `--mcore_model`.') load_mcore_checkpoint(megatron_args, [mg_model], load_arg='load') if megatron_args.adapter_load is not None: peft_model = prepare_mcore_model(mg_model) - # with adapter_state_dict_context(): load_mcore_checkpoint(megatron_args, [mg_model], load_arg='adapter_load') logger.info('Merge LoRA...') mg_model = peft_model.merge_and_unload() logger.info('Megatron model created successfully.') if args.to_hf: - bridge = megatron_model_meta.bridge_cls(megatron_args) + bridge = megatron_args.megatron_model_meta.bridge_cls(megatron_args) logger.info('Converting weights and saving the model...') bridge.save_weights([mg_model], args.output_dir, processor=processor, hf_config=hf_config) if is_master(): diff --git a/swift/megatron/model/mm_gpts/qwen.py b/swift/megatron/model/mm_gpts/qwen.py index 0149a2763f..a9f75b02de 100644 --- a/swift/megatron/model/mm_gpts/qwen.py +++ b/swift/megatron/model/mm_gpts/qwen.py @@ -68,7 +68,7 @@ class Qwen2VL_Vit(Qwen2_5VL_Vit): ModelType.qwen2_vl, ], bridge_cls=Qwen2_5VLBridge, - visual_cls=Qwen2VLLoader, + visual_cls=Qwen2VL_Vit, )) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index de8b3f1851..b43105f90c 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -2,7 +2,7 @@ import math from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Type, Union - +from torch import nn import megatron.core from megatron.core import mpu from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 3ada1c516d..7e0976edea 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -213,7 +213,7 @@ def prepare_mcore_model(args, model): logger.info_if( f'[rank{dist.get_rank()}] model_parameter_info: {get_model_parameter_info(m)}', cond=mpu.get_data_parallel_rank() == 0) - return model + return model @contextmanager From 198f5492daf6261185d3cb193f921b4de77c52ac Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 22:41:07 +0800 Subject: [PATCH 30/43] update --- swift/megatron/model/register.py | 3 +- swift/megatron/pipelines/export/export.py | 19 ++++- swift/megatron/trainers/base.py | 90 ++++++++++++++++++++++- swift/megatron/utils/utils.py | 12 +-- 4 files changed, 111 insertions(+), 13 deletions(-) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index b43105f90c..4af2334bea 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -2,13 +2,14 @@ import math from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Type, Union -from torch import nn + import megatron.core from megatron.core import mpu from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec) from packaging import version +from torch import nn from transformers.utils import is_torch_npu_available from swift.model import 
MODEL_MAPPING diff --git a/swift/megatron/pipelines/export/export.py b/swift/megatron/pipelines/export/export.py index 3b89abea11..30d9357968 100644 --- a/swift/megatron/pipelines/export/export.py +++ b/swift/megatron/pipelines/export/export.py @@ -43,7 +43,7 @@ def convert_mcore2hf(self) -> None: mg_model = megatron_model_meta.model_provider(args, pre_process=pre_process, post_process=post_process) bridge = megatron_model_meta.bridge_cls(args) if args.load is not None: - load_mcore_checkpoint([mg_model], None, None, strict=True) + load_mcore_checkpoint(args, [mg_model], load_arg='load') elif args.model is not None: bridge.load_weights(mg_model, args.model_info.model_dir) else: @@ -51,7 +51,7 @@ def convert_mcore2hf(self) -> None: if args.adapter_load is not None: peft_model = prepare_mcore_model(mg_model) with adapter_state_dict_context(): - load_mcore_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) + load_mcore_checkpoint(args, [mg_model], load_arg='adapter_load') if args.merge_lora: logger.info('Merge LoRA...') mg_model = peft_model.merge_and_unload() @@ -89,12 +89,27 @@ def convert_hf2mcore(self) -> None: hf_config = self.processor.model_info.config mg_model = get_mcore_model(args, hf_config)[0] logger.info('Megatron model created successfully.') + bridge = args.megatron_model_meta.bridge_cls(args) + if args.model is not None: + bridge.load_weights(mg_model, args.model_info.model_dir) + elif args.load is not None: + with patch_load_base_checkpoint(): + load_mcore_checkpoint(args, [mg_model], load_arg='load') + else: + raise ValueError('Please specify `--load` or `--model`.') dist.barrier() if args.adapters or args.adapter_load is not None: peft_model = prepare_mcore_model(mg_model, load_adapters=True) + if args.adapters: + assert len(args.adapters) == 1, 'Currently only support one adapter' + bridge.load_weights(mg_model, args.adapters[0], is_peft_format=True) + elif args.adapter_load is not None: + with adapter_state_dict_context(): + load_mcore_checkpoint(args, [mg_model], load_arg='adapter_load') if args.merge_lora: logger.info('Merge LoRA...') mg_model = peft_model.merge_and_unload() + logger.info('Successfully transferred HF model weights to MG model.') _test_convert_precision = strtobool(os.getenv('SWIFT_TEST_CONVERT_PRECISION', '0')) if not _test_convert_precision: args.save_args(args.save) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 30c562a09e..364c87f6a8 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -16,12 +16,13 @@ import torch.nn from megatron.core import mpu from megatron.core.datasets.utils import Split +from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import get_num_microbatches, update_num_microbatches from megatron.core.optimizer import _update_min_and_max_lr_in_param_groups from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine -from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.module import Float16Module, MegatronModule from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from megatron.core.utils import StragglerDetector, unwrap_model @@ -67,11 +68,12 @@ class BaseMegatronTrainer(ABC): def __init__(self, args, template: 
Template): self.args = args self.template = template + self._bridge = None hf_config = template.config - self.unwrapped_models = get_mcore_model(args, hf_config) - self.peft_models = prepare_mcore_model(args, self.unwrapped_models) + self.unwrapped_models = [] + self.peft_models = [] self.wrapped_models = [] - self._bridge = None + self.prepare_model() self.eval_metrics = None logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') @@ -98,6 +100,86 @@ def _get_mean_metric(): } self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + def prepare_model(self): + self.unwrapped_models = get_mcore_model(args, hf_config) + for model in self.unwrapped_models: + if args.load is None: + self.bridge.load_weights(model, args.model_dir) + + model = prepare_mcore_model(args, model) + if args.tuner_type == 'lora': + if args.adapters and args.adapter_load is None: + assert len(args.adapters) == 1, 'Currently only support one adapter.' + self.bridge.load_weights(model, args.adapters[0], is_peft_format=True, adapter_name='default') + if args.ref_adapters and args.ref_adapter_load is None: + assert len(args.ref_adapters) == 1, 'Currently only support one adapter.' + self.bridge.load_weights( + model, args.ref_adapters[0], is_peft_format=True, adapter_name='ref_adapter') + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for param in model.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + if args.use_cpu_initialization: + model.cuda(torch.cuda.current_device()) + self.peft_models.append(model) + # Fp16 + if args.fp16 or args.bf16: + config = get_model_config(model[0]) + model = [Float16Module(config, model_module) for model_module in model] + + # DDP + kwargs = {} + for f in dataclasses.fields(DistributedDataParallelConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 + kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad + kwargs['check_for_large_grads'] = args.check_for_large_grads + if args.ddp_num_buckets is not None: + assert args.ddp_bucket_size is None, \ + 'Cannot specify both --ddp-num-buckets and --ddp-bucket-size' + assert args.ddp_num_buckets > 0, \ + '--ddp-num-buckets must be greater than 0' + kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets + else: + kwargs['bucket_size'] = args.ddp_bucket_size + kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw + kwargs['average_in_collective'] = args.ddp_average_in_collective + if args.use_megatron_fsdp and args.use_precision_aware_optimizer: + kwargs['preserve_fp32_weights'] = False + ddp_config = DistributedDataParallelConfig(**kwargs) + + # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. + # If bucket_size is not provided as an input, use sane default. + # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL + # ring-reduce implementations are large enough to remain bandwidth-bound rather than + # latency-bound. 
+ if ddp_config.bucket_size is None: + ddp_config.bucket_size = max(40000000, + 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)) + # Set bucket_size to infinity if overlap_grad_reduce is False. + if not ddp_config.overlap_grad_reduce: + ddp_config.bucket_size = None + + with torch.cuda.stream(torch.cuda.Stream()): + model = [ + DP( + config=config, + ddp_config=ddp_config, + module=model_chunk, + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. + disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step, + ) for (model_chunk_idx, model_chunk) in enumerate(model) + ] + + # Broadcast params from data parallel src rank to other data parallel ranks. + if args.data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + def _get_data_collator(self): data_collator = self.template.data_collator padding_to = get_padding_to(self.args) diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 7e0976edea..a82dceaaf8 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -203,15 +203,15 @@ def prepare_adapter(args, model): def prepare_mcore_model(args, model): from .megatron_lm_utils import load_mcore_checkpoint if args.tuner_type == 'full': - freeze_parameters(m, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) + freeze_parameters(model, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) if args.trainable_parameters or args.trainable_parameters_regex: - activate_parameters(m, args.trainable_parameters, args.trainable_parameters_regex) + activate_parameters(model, args.trainable_parameters, args.trainable_parameters_regex) elif args.tuner_type == 'lora': - m.prepare_inputs_for_generation = None # fix error - m = prepare_adapter(args, m) - logger.info(f'model: {m}') + model.prepare_inputs_for_generation = None # fix error + model = prepare_adapter(args, model) + logger.info(f'model: {model}') logger.info_if( - f'[rank{dist.get_rank()}] model_parameter_info: {get_model_parameter_info(m)}', + f'[rank{dist.get_rank()}] model_parameter_info: {get_model_parameter_info(model)}', cond=mpu.get_data_parallel_rank() == 0) return model From 6e515672555c130a5121ba6387df8931d19d55c2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Feb 2026 23:52:02 +0800 Subject: [PATCH 31/43] update --- swift/megatron/arguments/megatron_args.py | 61 +++++- swift/megatron/pipelines/train/sft.py | 2 +- swift/megatron/trainers/base.py | 237 +++++----------------- swift/megatron/utils/__init__.py | 3 +- swift/megatron/utils/megatron_lm_utils.py | 97 ++++++++- 5 files changed, 206 insertions(+), 194 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 07b77dc2b3..e313a2ad79 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -7,6 +7,7 @@ import json import megatron.core import torch +from megatron.core import mpu from packaging import version from transformers.utils import is_torch_npu_available from transformers.utils.versions import require_version @@ -353,6 +354,7 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): manual_gc_interval: int = 0 # learning rate + lr_warmup_init: float = 0. 
lr: Optional[float] = None lr_decay_style: Literal['cosine', 'linear', 'constant'] = 'cosine' # The default is None, which will be set to `train_iters`. @@ -363,6 +365,9 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): # regularization weight_decay: float = 0.1 + weight_decay_incr_style: Literal['constant', 'linear', 'cosine'] = 'constant' + start_weight_decay: Optional[float] = None + end_weight_decay: Optional[float] = None clip_grad: float = 1. adam_beta1: float = 0.9 adam_beta2: float = 0.95 @@ -372,7 +377,7 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): # checkpoint save: Optional[str] = None save_interval: int = 500 - # save_retain_interval: Optional[int] = None + save_retain_interval: Optional[int] = None no_save_optim: bool = False no_save_rng: bool = False load: Optional[str] = None @@ -610,6 +615,7 @@ def __post_init__(self): if self.adapters: self._load_adapter_config() self._init_mixed_precision() + self._init_multimodal_full() initialize_megatron(self) @@ -635,3 +641,56 @@ def _load_adapter_config(self): if v != getattr(self, k): setattr(self, k, v) logger.info(f'Setting {k}: {v}') + + def init_iters(self, train_dataset, val_dataset): + data_parallel_size = mpu.get_data_parallel_world_size() + step_batch_size = self.micro_batch_size * data_parallel_size + num_generations = self.num_generations if self.rlhf_type == 'grpo' else 1 + if self.save_strategy == 'epoch': + if hasattr(train_dataset, '__len__'): + dataset_sample = len(train_dataset) // step_batch_size * step_batch_size * num_generations + self.save_interval = dataset_sample // self.global_batch_size + self.eval_interval = self.save_interval + # TODO + if getattr(self, 'save_retain_interval', None) is not None: + self.save_retain_interval *= self.save_interval + else: + raise ValueError('streaming dataset is not supported with `--save_strategy epoch`.') + if self.max_epochs is not None: + if hasattr(train_dataset, '__len__'): + dataset_sample = len(train_dataset) // step_batch_size * step_batch_size * num_generations + self.train_iters = dataset_sample * self.max_epochs // self.global_batch_size + elif self.train_iters is None: + raise ValueError( + 'You are using a streaming training dataset. Please explicitly specify `--train_iters`.') + if self.eval_iters < 0: + if val_dataset is None: + self.eval_iters = 0 + elif hasattr(val_dataset, '__len__'): + dataset_sample = len(val_dataset) // step_batch_size * step_batch_size + dataset_sample = dataset_sample * num_generations + self.eval_iters = max(dataset_sample // self.global_batch_size, 1) + else: + raise ValueError( + 'You are using a streaming validation dataset. 
Please explicitly specify `--eval_iters`.') + logger.info(f'Setting args.eval_iters: {self.eval_iters}') + + def _init_multimodal_full(self): + visual_cls = self.megatron_model_meta.visual_cls + if self.tuner_type == 'full' and self.is_multimodal and visual_cls is not None: + vision_tower = [f'visual.{vit}' for vit in getattr(visual_cls, '_vision_tower', [])] + aligner = [f'visual.{aligner}' for aligner in getattr(visual_cls, '_aligner', [])] + generator = [f'visual.{generator}' for generator in getattr(visual_cls, '_generator', [])] + if self.freeze_llm: + self.freeze_parameters.append('language_model') + if self.freeze_vit: + self.freeze_parameters += vision_tower + if self.freeze_aligner: + self.freeze_parameters += aligner + else: + self.trainable_parameters += aligner + self.freeze_parameters += generator + if self.freeze_parameters: + logger.info(f'freeze_parameters: {self.freeze_parameters}') + if self.trainable_parameters: + logger.info(f'additional trainable_parameters: {self.trainable_parameters}') diff --git a/swift/megatron/pipelines/train/sft.py b/swift/megatron/pipelines/train/sft.py index 52ce25cac1..f50c0e2cc7 100644 --- a/swift/megatron/pipelines/train/sft.py +++ b/swift/megatron/pipelines/train/sft.py @@ -61,7 +61,7 @@ def __init__(self, args: Optional[Union[List[str], MegatronSftArguments]] = None def run(self): args = self.args train_dataset, val_dataset = self._prepare_dataset() - + args.init_iters(train_dataset, val_dataset) # if args.streaming: # train_dataset = build_streaming_dataloader(args, train_dataset, data_collator) # if val_dataset is not None: diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 364c87f6a8..8ba36a92db 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
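For a concrete sense of the `init_iters` arithmetic introduced above, here is a small worked example; every number below is hypothetical:

    # Hypothetical sizes: 9_990 training samples, micro_batch_size=2, data-parallel size=8,
    # global_batch_size=128, max_epochs=3, no GRPO (num_generations=1).
    micro_batch_size, dp_size, global_batch_size, max_epochs = 2, 8, 128, 3
    dataset_len, num_generations = 9_990, 1
    step_batch_size = micro_batch_size * dp_size                                          # 16
    dataset_sample = dataset_len // step_batch_size * step_batch_size * num_generations   # 9_984
    train_iters = dataset_sample * max_epochs // global_batch_size                        # 234
    print(step_batch_size, dataset_sample, train_iters)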
import collections +import dataclasses import logging import math import os @@ -14,18 +15,16 @@ import megatron.core import torch import torch.nn -from megatron.core import mpu +from megatron.core import mpu, tensor_parallel from megatron.core.datasets.utils import Split -from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import get_num_microbatches, update_num_microbatches -from megatron.core.optimizer import _update_min_and_max_lr_in_param_groups +from megatron.core.optimizer import OptimizerConfig, _update_min_and_max_lr_in_param_groups, get_megatron_optimizer from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine from megatron.core.transformer.module import Float16Module, MegatronModule from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper -from megatron.core.utils import StragglerDetector, unwrap_model # from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, # get_wandb_writer, initialize, is_last_rank, one_logger_utils, pretrain, print_rank_0, # print_rank_last, training) @@ -40,8 +39,9 @@ from swift.megatron.model import get_mcore_model from swift.megatron.tuners import LoraParallelLinear -from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, get_padding_to, - load_mcore_checkpoint, patch_merge_fn, prepare_mcore_model) +from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, + get_optimizer_param_scheduler, get_padding_to, load_mcore_checkpoint, patch_merge_fn, + prepare_mcore_model, wrap_model) from swift.metrics import MeanMetric from swift.template import Template from swift.trainers import SwiftMixin, dynamic_gradient_checkpointing @@ -68,17 +68,33 @@ class BaseMegatronTrainer(ABC): def __init__(self, args, template: Template): self.args = args self.template = template - self._bridge = None - hf_config = template.config - self.unwrapped_models = [] - self.peft_models = [] - self.wrapped_models = [] + self.bridge = args.megatron_model_meta.bridge_cls(args) self.prepare_model() + self.optimizer, self.scheduler = self.get_optimizer_and_scheduler() + self.data_collator = self._get_data_collator() + args.iteration = 0 + args.num_floating_point_operations_so_far = 0 + if args.initialize_embedding: + for m in self.unwrapped_models: + self._initialize_embedding(m) + if args.tuner_type != 'full' and args.modules_to_save: + for m in self.unwrapped_models: + copy_original_module_weight(m) + if args.ref_adapter_load is not None: + with self._patch_load_state_dict(self._load_adapter_base_checkpoint): + load_mcore_checkpoint(args, self.wrapped_models, load_arg='ref_adapter_load') + if args.adapter_load is not None: + with adapter_state_dict_context(): + args.iteration, args.num_floating_point_operations_so_far = load_mcore_checkpoint( + args, self.wrapped_models, self.optimizer, self.scheduler, load_arg='adapter_load', strict=False) + if args.is_multimodal: + for m in self.unwrapped_models: + self._prepare_vit_gradient_checkpointing(m) + self.eval_metrics = None logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') self.jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last') # for evaluate - # 
self._patch_megatron() if args.check_model and hasattr(args, 'model_dir'): with ms_logger_context(logging.CRITICAL), patch_modelscope_hub_timeout(): @@ -101,6 +117,10 @@ def _get_mean_metric(): self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') def prepare_model(self): + args = self.args + hf_config = self.template.config + self.peft_models = [] + self.wrapped_models = [] self.unwrapped_models = get_mcore_model(args, hf_config) for model in self.unwrapped_models: if args.load is None: @@ -115,70 +135,22 @@ def prepare_model(self): assert len(args.ref_adapters) == 1, 'Currently only support one adapter.' self.bridge.load_weights( model, args.ref_adapters[0], is_peft_format=True, adapter_name='ref_adapter') - # Set tensor model parallel attributes if not set. - # Only parameters that are already tensor model parallel have these - # attributes set for them. We should make sure the default attributes - # are set for all params so the optimizer can use them. - for param in model.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) - if args.use_cpu_initialization: - model.cuda(torch.cuda.current_device()) self.peft_models.append(model) - # Fp16 - if args.fp16 or args.bf16: - config = get_model_config(model[0]) - model = [Float16Module(config, model_module) for model_module in model] - - # DDP - kwargs = {} - for f in dataclasses.fields(DistributedDataParallelConfig): - if hasattr(args, f.name): - kwargs[f.name] = getattr(args, f.name) - kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 - kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad - kwargs['check_for_large_grads'] = args.check_for_large_grads - if args.ddp_num_buckets is not None: - assert args.ddp_bucket_size is None, \ - 'Cannot specify both --ddp-num-buckets and --ddp-bucket-size' - assert args.ddp_num_buckets > 0, \ - '--ddp-num-buckets must be greater than 0' - kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets - else: - kwargs['bucket_size'] = args.ddp_bucket_size - kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw - kwargs['average_in_collective'] = args.ddp_average_in_collective - if args.use_megatron_fsdp and args.use_precision_aware_optimizer: - kwargs['preserve_fp32_weights'] = False - ddp_config = DistributedDataParallelConfig(**kwargs) - - # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. - # If bucket_size is not provided as an input, use sane default. - # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL - # ring-reduce implementations are large enough to remain bandwidth-bound rather than - # latency-bound. - if ddp_config.bucket_size is None: - ddp_config.bucket_size = max(40000000, - 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)) - # Set bucket_size to infinity if overlap_grad_reduce is False. - if not ddp_config.overlap_grad_reduce: - ddp_config.bucket_size = None - - with torch.cuda.stream(torch.cuda.Stream()): - model = [ - DP( - config=config, - ddp_config=ddp_config, - module=model_chunk, - # Turn off bucketing for model_chunk 2 onwards, since communication for these - # model chunks is overlapped with compute anyway. - disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step, - ) for (model_chunk_idx, model_chunk) in enumerate(model) - ] - - # Broadcast params from data parallel src rank to other data parallel ranks. 
- if args.data_parallel_random_init: - for model_module in model: - model_module.broadcast_params() + self.wrapped_models.append(wrap_model(args, model)) + + def get_optimizer_and_scheduler(self): + args = self.args + kwargs = {} + for f in dataclasses.fields(OptimizerConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + config = OptimizerConfig(**kwargs) + optimizer = get_megatron_optimizer( + config, + self.wrapped_models, + ) + scheduler = get_optimizer_param_scheduler(args, optimizer) + return optimizer, scheduler def _get_data_collator(self): data_collator = self.template.data_collator @@ -187,62 +159,6 @@ def _get_data_collator(self): data_collator = partial(data_collator, padding_to=padding_to) return data_collator - @property - def bridge(self): - if self._bridge is None: - self._bridge = self.args.megatron_model_meta.bridge_cls(self.args) - return self._bridge - - @contextmanager - def _get_iters(self, train_dataset, val_dataset): - origin_initialize_megatron = training.initialize_megatron - origin_validate_args = initialize.validate_args - - def initialize_megatron(*_args, **kwargs): - res = origin_initialize_megatron(*_args, **kwargs) - args = self.args - data_parallel_size = mpu.get_data_parallel_world_size() - step_batch_size = args.micro_batch_size * data_parallel_size - num_generations = args.num_generations if args.rlhf_type == 'grpo' else 1 - if args.save_strategy == 'epoch': - if hasattr(train_dataset, '__len__'): - dataset_sample = len(train_dataset) // step_batch_size * step_batch_size * num_generations - args.save_interval = dataset_sample // args.global_batch_size - args.eval_interval = args.save_interval - if getattr(args, 'save_retain_interval', None) is not None: - args.save_retain_interval *= args.save_interval - else: - raise ValueError('streaming dataset is not supported with `--save_strategy epoch`.') - if args.max_epochs is not None: - if hasattr(train_dataset, '__len__'): - dataset_sample = len(train_dataset) // step_batch_size * step_batch_size * num_generations - args.train_iters = dataset_sample * args.max_epochs // args.global_batch_size - elif args.train_iters is None: - raise ValueError( - 'You are using a streaming training dataset. Please explicitly specify `--train_iters`.') - if args.eval_iters < 0: - if val_dataset is None: - args.eval_iters = 0 - elif hasattr(val_dataset, '__len__'): - dataset_sample = len(val_dataset) // step_batch_size * step_batch_size - dataset_sample = dataset_sample * num_generations - args.eval_iters = max(dataset_sample // args.global_batch_size, 1) - else: - raise ValueError( - 'You are using a streaming validation dataset. 
Please explicitly specify `--eval_iters`.') - logger.info(f'Setting args.eval_iters: {args.eval_iters}') - return res - - self._origin_validate_args = origin_validate_args - - training.initialize_megatron = initialize_megatron - initialize.validate_args = self.patched_validate_args - try: - yield - finally: - training.initialize_megatron = origin_initialize_megatron - initialize.validate_args = self._origin_validate_args - def new_cyclic_iter(self, iterable): training = self.unwrapped_models[0].training if not training: @@ -566,25 +482,6 @@ def _load_iteration(self): def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - def new_model_provider_func(*_args, **kwargs): - model = model_provider_func(self.args, *_args, **kwargs) - if args.load is None: - self.bridge.load_weights(model, args.model_dir) - self.unwrapped_models.append(model) - peft_model = prepare_mcore_model(model) - if args.tuner_type == 'lora': - if args.adapters and args.adapter_load is None: - assert len(args.adapters) == 1, 'Currently only support one adapter.' - self.bridge.load_weights(model, args.adapters[0], is_peft_format=True, adapter_name='default') - if args.ref_adapters and args.ref_adapter_load is None: - assert len(args.ref_adapters) == 1, 'Currently only support one adapter.' - self.bridge.load_weights( - model, args.ref_adapters[0], is_peft_format=True, adapter_name='ref_adapter') - - self.peft_models.append(peft_model) - return model - - self._init_multimodal_full() # read iteration args = self.args if not args.finetune: @@ -598,22 +495,6 @@ def new_model_provider_func(*_args, **kwargs): model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer( new_model_provider_func, model_type, *_args, **kwargs) self.wrapped_models = model - if args.initialize_embedding: - for m in self.unwrapped_models: - self._initialize_embedding(m) - if args.tuner_type != 'full' and args.modules_to_save: - for m in self.unwrapped_models: - copy_original_module_weight(m) - if args.ref_adapter_load is not None: - with self._patch_load_state_dict(self._load_adapter_base_checkpoint): - load_mcore_checkpoint(model, optimizer, opt_param_scheduler, load_arg='ref_adapter_load', strict=False) - if args.adapter_load is not None: - with adapter_state_dict_context(): - args.iteration, args.num_floating_point_operations_so_far = load_mcore_checkpoint( - model, optimizer, opt_param_scheduler, load_arg='adapter_load', strict=False) - if args.is_multimodal: - for m in self.unwrapped_models: - self._prepare_vit_gradient_checkpointing(m) return model, optimizer, opt_param_scheduler def _prepare_vit_gradient_checkpointing(self, model): @@ -1287,27 +1168,6 @@ def _patch_megatron(self): self._origin_save_checkpoint = training.save_checkpoint training.save_checkpoint = self.save_checkpoint - def _init_multimodal_full(self): - args = self.args - visual_cls = self.args.megatron_model_meta.visual_cls - if args.tuner_type == 'full' and args.is_multimodal and visual_cls is not None: - vision_tower = [f'visual.{vit}' for vit in getattr(visual_cls, '_vision_tower', [])] - aligner = [f'visual.{aligner}' for aligner in getattr(visual_cls, '_aligner', [])] - generator = [f'visual.{generator}' for generator in getattr(visual_cls, '_generator', [])] - if args.freeze_llm: - args.freeze_parameters.append('language_model') - if args.freeze_vit: - args.freeze_parameters += vision_tower - if args.freeze_aligner: - args.freeze_parameters += aligner - else: - args.trainable_parameters += aligner - args.freeze_parameters 
+= generator - if args.freeze_parameters: - logger.info(f'freeze_parameters: {args.freeze_parameters}') - if args.trainable_parameters: - logger.info(f'additional trainable_parameters: {args.trainable_parameters}') - def train(self, train_dataset, val_dataset): args = self.args datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) @@ -1453,6 +1313,3 @@ def get_last_tokens(self, output_tensor, packed_seq_params=None, attention_mask= last_token_idx = packed_seq_params.cu_seqlens_q[1:num_samples + 1] - 1 last_tokens = output_tensor[0, last_token_idx] return last_tokens - - def patched_validate_args(self, args, *_args, **kwargs): - return self._origin_validate_args(args, *_args, **kwargs) diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 9c02183206..226f315623 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -2,7 +2,8 @@ from .config import convert_hf_config from .convert_utils import test_convert_precision -from .megatron_lm_utils import initialize_megatron, load_mcore_checkpoint, save_mcore_checkpoint +from .megatron_lm_utils import (get_optimizer_param_scheduler, initialize_megatron, load_mcore_checkpoint, + save_mcore_checkpoint, wrap_model) from .patcher import patch_merge_fn, patch_torch_dist_shard from .utils import (MegatronTrainerState, adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index d26f20ffc3..1755d4330d 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -15,8 +15,12 @@ get_default_save_sharded_strategy) from megatron.core.dist_checkpointing.strategies.fully_parallel import (FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper) +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.msc_utils import open_file from megatron.core.num_microbatches_calculator import update_num_microbatches +from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler +from megatron.core.transformer.module import Float16Module, MegatronModule from megatron.core.utils import unwrap_model from swift.utils import check_json_format, get_logger, init_process_group, is_master, seed_everything, set_device @@ -213,7 +217,7 @@ def _load_iteration(tracker_path: str): return iteration -def load_mcore_checkpoint(args, model, load_arg: str = 'load'): +def load_mcore_checkpoint(args, model, optimizer, scheduler, load_arg: str = 'load'): if load_arg in {'adapter_load', 'ref_adapter_load'}: is_peft_format = True elif load_arg in {'load', 'ref_load'}: @@ -282,3 +286,94 @@ def load_mcore_checkpoint(args, model, load_arg: str = 'load'): logger.info(f'Successfully loaded Megatron model weights from: {args.load}') return iteration, num_floating_point_operations_so_far + + +def wrap_model(args, model, wrap_with_ddp: bool = True): + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. 
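The loop that follows relies on Megatron's attribute-defaulting helper; roughly, its effect is the sketch below. The attribute names mirror Megatron-LM's usual defaults, but treat this as an approximation rather than the actual implementation:

    # Rough sketch, not the Megatron implementation: parameters that were not created by a
    # tensor-parallel layer get marked as non-parallel, so the optimizer and checkpointing
    # code can read the same attributes on every parameter.
    _TP_DEFAULTS = {'tensor_model_parallel': False, 'partition_dim': -1, 'partition_stride': 1}

    def set_tp_defaults(param):
        for name, value in _TP_DEFAULTS.items():
            if not hasattr(param, name):
                setattr(param, name, value)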
+ for param in model.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + if args.use_cpu_initialization: + model.cuda(torch.cuda.current_device()) + # Fp16 + config = model[0].config + if args.fp16 or args.bf16: + model = [Float16Module(config, model_module) for model_module in model] + + # DDP + if not wrap_with_ddp: + return + kwargs = {} + for f in dataclasses.fields(DistributedDataParallelConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 + kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad + kwargs['check_for_large_grads'] = args.check_for_large_grads + kwargs['bucket_size'] = args.ddp_bucket_size + kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw + kwargs['average_in_collective'] = args.ddp_average_in_collective + ddp_config = DistributedDataParallelConfig(**kwargs) + + # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. + # If bucket_size is not provided as an input, use sane default. + # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL + # ring-reduce implementations are large enough to remain bandwidth-bound rather than + # latency-bound. + if ddp_config.bucket_size is None: + ddp_config.bucket_size = max(40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)) + # Set bucket_size to infinity if overlap_grad_reduce is False. + if not ddp_config.overlap_grad_reduce: + ddp_config.bucket_size = None + + with torch.cuda.stream(torch.cuda.Stream()): + model = [ + DDP( + config=config, + ddp_config=ddp_config, + module=model_chunk, + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. + disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step, + ) for (model_chunk_idx, model_chunk) in enumerate(model) + ] + + # Broadcast params from data parallel src rank to other data parallel ranks. + if args.data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + + return model + + +def get_optimizer_param_scheduler(args, optimizer): + # Iteration-based training. 
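The `get_optimizer_param_scheduler` body below converts iteration counts into sample counts, because the scheduler is later stepped with `increment=<samples per optimizer step>`. A worked example with made-up values:

    # Illustrative numbers only: LR warmup/decay horizons expressed in samples, not iterations.
    train_iters, global_batch_size = 1_000, 256
    lr_decay_iters = train_iters                             # default when --lr_decay_iters is unset
    lr_decay_steps = lr_decay_iters * global_batch_size      # 256_000 samples
    lr_warmup_fraction = 0.03
    lr_warmup_steps = lr_warmup_fraction * lr_decay_steps    # 7_680 samples, i.e. 30 iterations
    print(lr_decay_steps, lr_warmup_steps)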
+ if args.lr_decay_iters is None: + args.lr_decay_iters = args.train_iters + lr_decay_steps = args.lr_decay_iters * args.global_batch_size + wd_incr_steps = args.train_iters * args.global_batch_size + wsd_decay_steps = None + if args.lr_wsd_decay_iters is not None: + wsd_decay_steps = args.lr_wsd_decay_iters * args.global_batch_size + if args.lr_warmup_fraction is not None: + lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps + else: + lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size + + opt_param_scheduler = OptimizerParamScheduler( + optimizer, + init_lr=args.lr_warmup_init, + max_lr=args.lr, + min_lr=args.min_lr, + lr_warmup_steps=lr_warmup_steps, + lr_decay_steps=lr_decay_steps, + lr_decay_style=args.lr_decay_style, + start_wd=args.start_weight_decay, + end_wd=args.end_weight_decay, + wd_incr_steps=wd_incr_steps, + wd_incr_style=args.weight_decay_incr_style, + ) + + return opt_param_scheduler From cb0b21936d0962687e01dea8e56373447446d31e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Feb 2026 01:07:00 +0800 Subject: [PATCH 32/43] update --- swift/megatron/arguments/megatron_args.py | 1 + swift/megatron/trainers/base.py | 2 +- swift/megatron/utils/megatron_lm_utils.py | 33 +++++++++++------------ 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index e313a2ad79..88df75eb1a 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -465,6 +465,7 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): # other seed: int = 42 + data_parallel_random_init: Optional[bool] = False num_workers: int = 4 data_sharding: bool = False diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 8ba36a92db..8d01e16a67 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -136,7 +136,7 @@ def prepare_model(self): self.bridge.load_weights( model, args.ref_adapters[0], is_peft_format=True, adapter_name='ref_adapter') self.peft_models.append(model) - self.wrapped_models.append(wrap_model(args, model)) + self.wrapped_models = wrap_model(args, self.peft_models) def get_optimizer_and_scheduler(self): args = self.args diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 1755d4330d..d36ccd1d92 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -22,6 +22,7 @@ from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.module import Float16Module, MegatronModule from megatron.core.utils import unwrap_model +from peft import PeftModel from swift.utils import check_json_format, get_logger, init_process_group, is_master, seed_everything, set_device from .patcher import patch_merge_fn @@ -102,7 +103,7 @@ def initialize_megatron(args): # Random seeds for reproducibility. logger.info(f'Setting random seeds to {args.seed}.') - _set_random_seed(args.seed) + _set_random_seed(args.seed, args.data_parallel_random_init) # Setup MoE aux loss scale value. 
if args.model_info.is_moe_model: @@ -136,7 +137,7 @@ def _get_rng_state(): return rng_state_list -def _generate_state_dict(args, model, iteration=None, model_sd_kwargs=None): +def _generate_state_dict(args, model: list, iteration=None, model_sd_kwargs=None): model_sd_kwargs = model_sd_kwargs or {} state_dict = {'args': Namespace(**check_json_format(args.__dict__))} if iteration is not None: @@ -150,7 +151,7 @@ def _generate_state_dict(args, model, iteration=None, model_sd_kwargs=None): return state_dict -def save_mcore_checkpoint(args, model, iteration=1): +def save_mcore_checkpoint(args, model: list, iteration=1): model = unwrap_model(model) rng_state = _get_rng_state() checkpoint_dir = os.path.join(args.save, f'iter_{iteration:07d}') @@ -217,7 +218,7 @@ def _load_iteration(tracker_path: str): return iteration -def load_mcore_checkpoint(args, model, optimizer, scheduler, load_arg: str = 'load'): +def load_mcore_checkpoint(args, model: list, optimizer, scheduler, load_arg: str = 'load'): if load_arg in {'adapter_load', 'ref_adapter_load'}: is_peft_format = True elif load_arg in {'load', 'ref_load'}: @@ -288,15 +289,16 @@ def load_mcore_checkpoint(args, model, optimizer, scheduler, load_arg: str = 'lo return iteration, num_floating_point_operations_so_far -def wrap_model(args, model, wrap_with_ddp: bool = True): +def wrap_model(args, model: list, wrap_with_ddp: bool = True): # Set tensor model parallel attributes if not set. # Only parameters that are already tensor model parallel have these # attributes set for them. We should make sure the default attributes # are set for all params so the optimizer can use them. - for param in model.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) - if args.use_cpu_initialization: - model.cuda(torch.cuda.current_device()) + for m in model: + for param in m.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + if args.use_cpu_initialization: + m.cuda(torch.cuda.current_device()) # Fp16 config = model[0].config if args.fp16 or args.bf16: @@ -309,12 +311,7 @@ def wrap_model(args, model, wrap_with_ddp: bool = True): for f in dataclasses.fields(DistributedDataParallelConfig): if hasattr(args, f.name): kwargs[f.name] = getattr(args, f.name) - kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 - kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad - kwargs['check_for_large_grads'] = args.check_for_large_grads - kwargs['bucket_size'] = args.ddp_bucket_size - kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw - kwargs['average_in_collective'] = args.ddp_average_in_collective + kwargs['check_for_nan_in_grad'] = True ddp_config = DistributedDataParallelConfig(**kwargs) # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. @@ -336,14 +333,14 @@ def wrap_model(args, model, wrap_with_ddp: bool = True): module=model_chunk, # Turn off bucketing for model_chunk 2 onwards, since communication for these # model chunks is overlapped with compute anyway. - disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step, + disable_bucketing=model_chunk_idx > 0, ) for (model_chunk_idx, model_chunk) in enumerate(model) ] # Broadcast params from data parallel src rank to other data parallel ranks. 
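The broadcast below only matters when `data_parallel_random_init` is set, i.e. when each data-parallel rank seeded its weights differently. Conceptually it amounts to the sketch that follows; Megatron's `DistributedDataParallel.broadcast_params` is the real implementation, and the helper here is hypothetical:

    # Conceptual sketch only: make the first rank of the data-parallel group the source
    # of truth and copy its parameters to every other rank in that group.
    import torch.distributed as dist

    def broadcast_params_sketch(module, dp_group):
        src = dist.get_global_rank(dp_group, 0)          # global rank of DP-group rank 0
        for param in module.parameters():
            dist.broadcast(param.data, src=src, group=dp_group)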
if args.data_parallel_random_init: - for model_module in model: - model_module.broadcast_params() + for m in model: + m.broadcast_params() return model From c9a7e167a174464ada8ce77f8bf33e21a2a241f3 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Feb 2026 10:38:28 +0800 Subject: [PATCH 33/43] update --- swift/megatron/arguments/megatron_args.py | 22 ++++++++++++++++++++++ swift/megatron/trainers/base.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 88df75eb1a..b7f0262fff 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -617,6 +617,8 @@ def __post_init__(self): self._load_adapter_config() self._init_mixed_precision() self._init_multimodal_full() + self._map_dtype() + self._init_weigh_decay() initialize_megatron(self) @@ -695,3 +697,23 @@ def _init_multimodal_full(self): logger.info(f'freeze_parameters: {self.freeze_parameters}') if self.trainable_parameters: logger.info(f'additional trainable_parameters: {self.trainable_parameters}') + + + def _map_dtype(self): + dtype_map = { + 'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8, + } + self.main_grads_dtype = dtype_map[self.main_grads_dtype] + self.main_params_dtype = dtype_map[self.main_params_dtype] + self.exp_avg_dtype = dtype_map[self.exp_avg_dtype] + self.exp_avg_sq_dtype = dtype_map[self.exp_avg_sq_dtype] + + + def _init_weigh_decay(self): + if self.weight_decay_incr_style == 'constant': + assert self.start_weight_decay is None + assert self.end_weight_decay is None + self.start_weight_decay = self.end_weight_decay = self.weight_decay + else: + assert self.start_weight_decay is not None + assert self.end_weight_decay is not None diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 8d01e16a67..0831ecd5aa 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -142,7 +142,7 @@ def get_optimizer_and_scheduler(self): args = self.args kwargs = {} for f in dataclasses.fields(OptimizerConfig): - if hasattr(args, f.name): + if hasattr(args, f.name) and f.name != 'loss_scale': kwargs[f.name] = getattr(args, f.name) config = OptimizerConfig(**kwargs) optimizer = get_megatron_optimizer( From 96b41692654eaff1d5e9f3b73ca7d569274f9f10 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Feb 2026 14:41:30 +0800 Subject: [PATCH 34/43] update --- swift/megatron/arguments/megatron_args.py | 17 +- swift/megatron/model/gpt_model.py | 4 +- swift/megatron/trainers/base.py | 256 ++++++++++++---------- swift/megatron/trainers/batch_sampler.py | 147 +++++++++++++ swift/megatron/trainers/utils.py | 144 +++--------- swift/megatron/utils/megatron_lm_utils.py | 3 +- 6 files changed, 334 insertions(+), 237 deletions(-) create mode 100644 swift/megatron/trainers/batch_sampler.py diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index b7f0262fff..2f5af77fdf 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -466,7 +466,6 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): # other seed: int = 42 data_parallel_random_init: Optional[bool] = False - num_workers: int = 4 data_sharding: bool = False check_model: bool = True @@ -492,6 +491,7 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): # dataloader train_dataloader_shuffle: bool = True + 
dataloader_num_workers: int = 4 dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = True dataloader_prefetch_factor: int = 2 @@ -678,6 +678,14 @@ def init_iters(self, train_dataset, val_dataset): 'You are using a streaming validation dataset. Please explicitly specify `--eval_iters`.') logger.info(f'Setting args.eval_iters: {self.eval_iters}') + data_parallel_size = mpu.get_data_parallel_world_size() + step_batch_size = self.micro_batch_size * data_parallel_size + # To avoid errors caused by the validation set being insufficient to complete a single step. + if val_dataset is not None and hasattr(val_dataset, '__len__') and len(val_dataset) < step_batch_size: + val_dataset = None + if val_dataset is None: + self.eval_iters = 0 + def _init_multimodal_full(self): visual_cls = self.megatron_model_meta.visual_cls if self.tuner_type == 'full' and self.is_multimodal and visual_cls is not None: @@ -698,17 +706,18 @@ def _init_multimodal_full(self): if self.trainable_parameters: logger.info(f'additional trainable_parameters: {self.trainable_parameters}') - def _map_dtype(self): dtype_map = { - 'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8, + 'fp32': torch.float32, + 'bf16': torch.bfloat16, + 'fp16': torch.float16, + 'fp8': torch.uint8, } self.main_grads_dtype = dtype_map[self.main_grads_dtype] self.main_params_dtype = dtype_map[self.main_params_dtype] self.exp_avg_dtype = dtype_map[self.exp_avg_dtype] self.exp_avg_sq_dtype = dtype_map[self.exp_avg_sq_dtype] - def _init_weigh_decay(self): if self.weight_decay_incr_style == 'constant': assert self.start_weight_decay is None diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 9084a30db4..cc2bcac8b2 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -108,11 +108,11 @@ def __init__( self.attention_scaling = 1. 
new_inv_freq, self.attention_scaling = get_rope_inv_freq(config) self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) - self.args = args = config.args + self.args = config.args if self.args.task_type == 'seq_cls' and self.post_process: self.output_layer = OutputLayerLinear( config.hidden_size, - self.config.num_labels, + self.args.num_labels, config=config, init_method=config.init_method, bias=False, diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 0831ecd5aa..ad4ab814a1 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -17,6 +17,7 @@ import torch.nn from megatron.core import mpu, tensor_parallel from megatron.core.datasets.utils import Split +from megatron.core.distributed import finalize_model_grads from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import get_num_microbatches, update_num_microbatches from megatron.core.optimizer import OptimizerConfig, _update_min_and_max_lr_in_param_groups, get_megatron_optimizer @@ -29,9 +30,7 @@ # get_wandb_writer, initialize, is_last_rank, one_logger_utils, pretrain, print_rank_0, # print_rank_last, training) # from megatron.training.checkpointing import check_checkpoint_args, set_checkpoint_version -# from megatron.training.dist_signal_handler import DistributedSignalHandler # from megatron.training.theoretical_memory_usage import report_theoretical_memory -# from megatron.training.training import num_floating_point_operations # from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory from modelscope import check_local_model_is_latest from packaging import version @@ -46,9 +45,11 @@ from swift.template import Template from swift.trainers import SwiftMixin, dynamic_gradient_checkpointing from swift.trainers.utils import patch_modelscope_hub_timeout -from swift.utils import JsonlWriter, deep_getattr, format_time, get_last_valid_indices, get_logger, ms_logger_context -from .utils import (MegatronPretrainingRandomSampler, get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, - get_packed_seq_params, get_swift_datasets_provider) +from swift.utils import (JsonlWriter, deep_getattr, format_time, get_last_valid_indices, get_logger, is_last_rank, + ms_logger_context) +from .batch_sampler import MegatronPretrainingRandomSampler, MegatronPretrainingSampler +from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, + logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group) # try: # from megatron.training.datasets.data_samplers import MegatronPretrainingSampler @@ -70,10 +71,10 @@ def __init__(self, args, template: Template): self.template = template self.bridge = args.megatron_model_meta.bridge_cls(args) self.prepare_model() + self.config = self.unwrapped_models[0].config self.optimizer, self.scheduler = self.get_optimizer_and_scheduler() self.data_collator = self._get_data_collator() args.iteration = 0 - args.num_floating_point_operations_so_far = 0 if args.initialize_embedding: for m in self.unwrapped_models: self._initialize_embedding(m) @@ -85,11 +86,8 @@ def __init__(self, args, template: Template): load_mcore_checkpoint(args, self.wrapped_models, load_arg='ref_adapter_load') if args.adapter_load is not None: with adapter_state_dict_context(): - args.iteration, args.num_floating_point_operations_so_far = load_mcore_checkpoint( - args, self.wrapped_models, self.optimizer, self.scheduler, load_arg='adapter_load', 
strict=False) - if args.is_multimodal: - for m in self.unwrapped_models: - self._prepare_vit_gradient_checkpointing(m) + args.iteration = load_mcore_checkpoint( + args, self.wrapped_models, self.optimizer, self.scheduler, load_arg='adapter_load') self.eval_metrics = None logging_path = os.path.join(args.save, 'logging.jsonl') @@ -467,7 +465,6 @@ def _load_iteration(self): state_dict = torch.load(common_path) set_checkpoint_version(state_dict.get('checkpoint_version', 0)) - num_floating_point_operations_so_far = state_dict.get('num_floating_point_operations_so_far', 0) if 'args' in state_dict and not args.finetune: checkpoint_args = state_dict['args'] check_checkpoint_args(checkpoint_args) @@ -478,14 +475,14 @@ def _load_iteration(self): else: print_rank_0('could not find arguments in the checkpoint ...') - return iteration, num_floating_point_operations_so_far + return iteration def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): # read iteration args = self.args if not args.finetune: - args.iteration, args.num_floating_point_operations_so_far = self._load_iteration() + args.iteration = self._load_iteration() if args.apply_wd_to_qk_layernorm or self.args.vit_lr is not None or self.args.aligner_lr is not None: param_groups_context = self._patch_get_param_groups() @@ -540,14 +537,8 @@ def _all_reduce_metric(self, torch.distributed.all_reduce(reporting_metric, reduction, group=mpu.get_data_parallel_group()) return {k: reporting_metric[i] for i, k in enumerate(metric.keys())} - def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args, - **kwargs): - new_data_iterator = self._replace_data_iterator(data_iterator, model) - return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, - config, *args, **kwargs) - # Code borrowed from NVIDIA/Megatron-LM - def evaluate( + def _evaluate( self, forward_step_func, data_iterator, @@ -612,7 +603,6 @@ def evaluate( num_microbatches=eval_num_microbatches, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, forward_only=True, ) ft_integration.on_eval_step_end() @@ -675,7 +665,6 @@ def evaluate( num_microbatches=get_num_microbatches(), seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, forward_only=True, collect_non_loss_data=True, ) @@ -703,7 +692,7 @@ def evaluate( self.jsonl_writer.append(logs) return total_loss_dict, collected_non_loss_data, False - def evaluate_and_print_results( + def _evaluate_and_print_results( self, prefix, forward_step_func, @@ -822,8 +811,8 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval'], iteration= wandb_writer.log(metrics, iteration) # Code borrowed from NVIDIA/Megatron-LM - def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, - report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad): + def _training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, + report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad): """Log training information such as losses, timing, ....""" args = self.args timers = get_timers() @@ -1110,7 +1099,7 @@ def copy_path(src_path: str, tgt_path: str): else: raise ValueError(f'Source path is neither a file nor a directory: {src_path}') - def save_checkpoint(self, iteration, model, 
*_args, **kwargs): + def _save_checkpoint(self, iteration, model, *_args, **kwargs): args = self.args output_dir = os.path.join(args.save, f'checkpoint-{iteration}') os.makedirs(output_dir, exist_ok=True) @@ -1149,116 +1138,157 @@ def save_checkpoint(self, iteration, model, *_args, **kwargs): if args.tuner_type == 'lora' and args.merge_lora: self.unmerge_lora_adapters() - def _patch_megatron(self): - # support max_epochs - self._origin_train_step = training.train_step - training.train_step = self.train_step - self._origin_cyclic_iter = training.cyclic_iter - training.cyclic_iter = self.new_cyclic_iter - # patch training_log - self._origin_training_log = training.training_log - training.training_log = self.training_log - # patch evaluate - self._origin_evaluate_and_print_results = training.evaluate_and_print_results - training.evaluate_and_print_results = self.evaluate_and_print_results - # patch model and optimizer - self._origin_setup_model_and_optimizer = training.setup_model_and_optimizer - training.setup_model_and_optimizer = self.setup_model_and_optimizer - # patch save_checkpoint - self._origin_save_checkpoint = training.save_checkpoint - training.save_checkpoint = self.save_checkpoint - def train(self, train_dataset, val_dataset): args = self.args - datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) - datasets_provider.is_distributed = True - with self.patch_megatron_data_collator(data_collator), self._get_iters(train_dataset, val_dataset): - pretrain(datasets_provider, args.megatron_model_meta.model_provider, ModelType.encoder_or_decoder, - self.forward_step) + train_dataloader, val_dataloader = self.prepare_dataloader(train_dataset, val_dataset) + for m in self.wrapped_models: + m.train() - # Code borrowed from NVIDIA/Megatron-LM - def build_pretraining_data_loader(self, dataset, consumed_samples, data_collator=None): - """Build dataloader given an input dataset.""" + if args.is_multimodal: + for m in self.unwrapped_models: + self._prepare_vit_gradient_checkpointing(m) + + self.config.finalize_model_grads_func = finalize_model_grads + # TODO: manual_gc + train_metrics = {} + while args.iteration < args.train_iters: + metrics, grad_norm = self.train_step(train_dataloader) + if mpu.is_pipeline_last_stage(ignore_virtual=True): + self.aggregated_metrics(metrics, train_metrics) + self.training_log(train_metrics, grad_norm) - if dataset is None: - return None + if args.eval_interval and args.iteration % args.eval_interval == 0: + self.evaluate(val_dataloader) + if args.save and args.save_interval and args.iteration % args.save_interval == 0: + self.save_checkpoint() + + def save_checkpoint(self): + print + + def training_log(self, metrics, grad_norm): + learning_rate = None + for param_group in self.optimizer.param_groups: + if len(param_group['params']) == 0: + continue + learning_rate = param_group['lr'] + logger.info(f'metrics: {metrics}, grad_norm: {grad_norm}, learning_rate: {learning_rate}') + + def evaluate(self, val_dataloader): + # TODO: 兼容transformers callback, eval_metrics等 args = self.args - if args.dataloader_type == 'external': - # External dataloaders are passed through. User is expected to provide a - # torch-compatible dataloader and define samplers, if needed. 
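The new `train`/`evaluate`/`train_step` loops above guard metric aggregation with `mpu.is_pipeline_last_stage`, because Megatron's forward-backward schedule only collects the loss function's outputs on the last pipeline stage. The returned object is a list with one dict per micro-batch; a hypothetical example of what the trainer receives there:

    # Hypothetical shape of forward_backward_func's return value on the last pipeline stage
    # (empty list on all other stages); keys and values depend on the trainer's loss function.
    import torch
    metrics = [
        {'loss': torch.tensor([2.31])},   # micro-batch 0
        {'loss': torch.tensor([2.27])},   # micro-batch 1
    ]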
- return dataset - - if hasattr(dataset, 'split'): - split = dataset.split - elif hasattr(dataset, 'index_split'): - split = dataset.index_split - else: - split = None + for m in self.wrapped_models: + m.eval() + eval_metrics = {} + forward_backward_func = get_forward_backward_func() + + with torch.no_grad(), tqdm( + total=val_dataloader, dynamic_ncols=True, disable=not is_last_rank(), desc='Evaluate: ') as prog_bar: + iteration = 0 + while iteration < args.eval_iters: + prog_bar.update() + iteration += 1 + + metrics = forward_backward_func( + forward_step_func=self.forward_step, + data_iterator=val_dataloader, + model=self.wrapped_models, + num_microbatches=get_num_microbatches(), + seq_length=args.max_length, + micro_batch_size=args.micro_batch_size, + forward_only=True, + ) + if mpu.is_pipeline_last_stage(ignore_virtual=True): + self.aggregated_metrics(metrics, eval_metrics) + # TODO: log metrics + logger.info(f'eval_metrics: {eval_metrics}') + for m in self.wrapped_models: + m.train() - is_val_dataset = getattr(dataset, 'dataset_type', None) == 'validation' + def train_step(self, train_dataloader): + args = self.args + forward_backward_func = get_forward_backward_func() + for m in self.wrapped_models: + m.zero_grad_buffer() + self.optimizer.zero_grad() + metrics = forward_backward_func( + forward_step_func=self.forward_step, + data_iterator=train_dataloader, + model=self.wrapped_models, + num_microbatches=get_num_microbatches(), + seq_length=args.max_length, + micro_batch_size=args.micro_batch_size, + forward_only=False, + ) + + update_successful, grad_norm, _ = self.optimizer.step() + update_successful = logical_and_across_model_parallel_group(update_successful) + grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) + + # Update learning rate. 
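The learning-rate update below advances the scheduler by the number of samples consumed in this optimizer step. With made-up parallelism sizes the increment works out as follows:

    # Illustration only: samples consumed by one optimizer step, which equals the global batch size.
    num_microbatches, micro_batch_size, data_parallel_size = 8, 2, 16
    increment = num_microbatches * micro_batch_size * data_parallel_size   # 256
    print(increment)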
+ if update_successful: + increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size + args.iteration += 1 + args.consumed_train_samples += increment + self.scheduler.step(increment=increment) + + return metrics, grad_norm + + def aggregated_metrics(self, metrics, total_metrics): + for key in metrics[0].keys(): + if key not in total_metrics: + total_metrics[key] = torch.tensor([0.0], dtype=torch.float32, device=torch.cuda.current_device()) + val = [x[key].view(-1) for x in metrics] + val = torch.concat(val, dim=0) + if val[0].numel() == 2: + val = val.sum(dim=0) + total_metrics[key] += val[0] / val[1] + elif val[0].numel() == 1: + total_metrics[key] += val.sum() + else: + raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}') - if split == Split.valid and args.full_validation: - batch_sampler = MegatronPretrainingSampler( - total_samples=len(dataset), + def prepare_dataloader(self, train_dataset, val_dataset): + args = self.args + + train_batch_sampler = MegatronPretrainingRandomSampler( + train_dataset, + total_samples=len(train_dataset), + consumed_samples=args.consumed_train_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + data_sharding=args.data_sharding, + shuffle=args.train_dataloader_shuffle, + group_by_length=args.group_by_length, + ) + train_dataloader = self._create_dataloader(train_dataset, train_batch_sampler) + val_dataloader = None + if val_dataset is not None: + val_batch_sampler = MegatronPretrainingSampler( + total_samples=len(val_dataset), consumed_samples=0, micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size(), ) - elif args.dataloader_type == 'single' or is_val_dataset: - if is_val_dataset: - consumed_samples = 0 - # Megatron sampler - batch_sampler = MegatronPretrainingSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size(), - ) - elif args.dataloader_type == 'cyclic': - batch_sampler = MegatronPretrainingRandomSampler( - dataset, - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size(), - data_sharding=args.data_sharding, - shuffle=args.train_dataloader_shuffle, - group_by_length=args.group_by_length, - ) - else: - raise Exception('{} dataloader type is not supported.'.format(args.dataloader_type)) + val_dataloader = self._create_dataloader(val_dataset, val_batch_sampler) + return train_dataloader, val_dataloader - def worker_init_fn(_): - DistributedSignalHandler(args.exit_signal).__enter__() + def _create_dataloader(self, dataset, batch_sampler): + args = self.args - maybe_worker_init_fn = (worker_init_fn if args.exit_signal_handler and args.num_workers > 0 else None) - # Torch dataloader. 
dataloader = torch.utils.data.DataLoader( dataset, batch_sampler=batch_sampler, - num_workers=args.num_workers, + num_workers=args.dataloader_num_workers, pin_memory=args.dataloader_pin_memory, persistent_workers=args.dataloader_persistent_workers if args.num_workers > 0 else False, prefetch_factor=args.dataloader_prefetch_factor if args.num_workers > 0 else None, - worker_init_fn=maybe_worker_init_fn, - collate_fn=data_collator, + collate_fn=self.data_collator, ) return dataloader - @contextmanager - def patch_megatron_data_collator(self, data_collator): - origin_build_pretraining_data_loader = training.build_pretraining_data_loader - training.build_pretraining_data_loader = partial( - self.build_pretraining_data_loader, data_collator=data_collator) - try: - yield - finally: - training.build_pretraining_data_loader = origin_build_pretraining_data_loader - @abstractmethod def forward_step(self, data_iterator, model): pass diff --git a/swift/megatron/trainers/batch_sampler.py b/swift/megatron/trainers/batch_sampler.py new file mode 100644 index 0000000000..6cc9892705 --- /dev/null +++ b/swift/megatron/trainers/batch_sampler.py @@ -0,0 +1,147 @@ +# Code borrowed from megatron-lm +class MegatronPretrainingSampler: + + def __init__(self, + total_samples, + consumed_samples, + micro_batch_size, + data_parallel_rank, + data_parallel_size, + drop_last=True): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.drop_last = drop_last + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.consumed_samples < self.total_samples, \ + 'no samples left to consume: {}, {}'.format(self.consumed_samples, + self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +# Code borrowed from megatron-lm +class MegatronPretrainingRandomSampler: + + def __init__( + self, + dataset, + total_samples, + consumed_samples, + micro_batch_size, + data_parallel_rank, + data_parallel_size, + data_sharding, + shuffle: bool = True, + group_by_length: bool = False, + ): + # Keep a copy of input params for later use. 
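To make the `MegatronPretrainingSampler` added above concrete, here is how one global step's indices are split across data-parallel ranks (sizes are hypothetical):

    # Worked example: micro_batch_size=2, data_parallel_size=4, so one global step covers 8 indices.
    micro_batch_size, dp_size = 2, 4
    global_step_indices = list(range(8))
    for dp_rank in range(dp_size):
        start = dp_rank * micro_batch_size
        print(dp_rank, global_step_indices[start:start + micro_batch_size])
    # rank 0 -> [0, 1], rank 1 -> [2, 3], rank 2 -> [4, 5], rank 3 -> [6, 7]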
+ self.dataset = dataset + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + if group_by_length: + if data_sharding: + data_sharding = False + logger.warning('`group_by_length=True` is incompatible with `data_sharding=True`. ' + 'Setting `data_sharding=False` to enable length grouping.') + if not shuffle: + raise ValueError('shuffle must be True when group_by_length is True') + self.data_sharding = data_sharding + self.shuffle = shuffle + self.group_by_length = group_by_length + self.lengths = self.dataset['lengths'] if group_by_length else None + if self.lengths is not None: + self.lengths = [max(length) if isinstance(length, list) else length for length in self.lengths] + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size + self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size + + # Sanity checks. + assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, ( + 'data_parallel_rank should be smaller than data size: {}, ' + '{}'.format(self.data_parallel_rank, data_parallel_size)) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + if self.shuffle: + # data sharding and random sampling + if self.data_sharding: + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + else: + full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size + full_bucket_offset = current_epoch_samples + g = torch.Generator() + g.manual_seed(self.epoch) + if self.group_by_length: + from transformers.trainer_pt_utils import get_length_grouped_indices + idx_range_total = get_length_grouped_indices( + self.lengths, self.micro_batch_times_data_parallel_size, generator=g) + else: + idx_range_total = torch.randperm(full_bucket_size, generator=g).tolist() + idx_range_active = idx_range_total[full_bucket_offset:] + idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size] + else: + full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size + full_bucket_offset = current_epoch_samples + idx_range = range(full_bucket_offset + self.data_parallel_rank, full_bucket_size, self.data_parallel_size) + + batch = [] + # Last batch if not complete will be dropped. 
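# Illustrative sketch (assumed numbers, hypothetical generator `g`; not code from the patch):
# the non-sharded shuffle branch above, with total_samples=8, micro_batch_size=2,
# data_parallel_size=2 and consumed_samples=0, behaves roughly as follows:
#   full_bucket_size = (8 // 2) * 2 = 8
#   idx_range_total  = torch.randperm(8, generator=g).tolist()  # e.g. [5, 2, 7, 0, 1, 6, 3, 4]
#   rank 0 keeps idx_range_total[0::2] -> [5, 7, 1, 3]
#   rank 1 keeps idx_range_total[1::2] -> [2, 0, 6, 4]
# Each rank then yields consecutive chunks of micro_batch_size indices, and every
# yield advances consumed_samples by micro_batch_size * data_parallel_size = 4.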
+ for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 1b23d7a13b..c426d2b838 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -25,29 +25,8 @@ logger = get_logger() -def get_swift_datasets_provider(train_dataset, val_dataset): - - def swift_datasets_provider(train_val_test_num_samples, vp_stage=None): - nonlocal val_dataset - args = get_args() - data_parallel_size = mpu.get_data_parallel_world_size() - step_batch_size = args.micro_batch_size * data_parallel_size - # To avoid errors caused by the validation set being insufficient to complete a single step. - if val_dataset is not None and hasattr(val_dataset, '__len__') and len(val_dataset) < step_batch_size: - val_dataset = None - if val_dataset is None: - args.eval_iters = 0 - else: - val_dataset.dataset_type = 'validation' - return train_dataset, val_dataset, None - - return swift_datasets_provider - - # Code borrowed from NVIDIA/Megatron-LM -def get_batch_on_this_tp_rank(data, vp_stage=None): - args = get_args() - +def get_batch_on_this_tp_rank(args, data, vp_stage=None): if args.task_type == 'causal_lm': data['labels'] = torch.roll(data['labels'], -1, dims=-1) if 'loss_scale' in data: @@ -109,7 +88,7 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di return torch.cat(new_inputs, dim=dim) -def get_batch_on_this_cp_rank(batch: Dict[str, Any]): +def get_batch_on_this_cp_rank(args, batch: Dict[str, Any]): """Slice batch input along sequence dimension into multiple chunks, which are parallelized across GPUs in a context parallel group. """ @@ -122,7 +101,6 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): # that we can get balanced workload among GPUs in a context parallel group. cp_size = mpu.get_context_parallel_world_size() if cp_size > 1: - args = get_args() keys = ['labels', 'position_ids', 'loss_scale'] if not args.is_multimodal: # Multimodal models will handle CP in input_embeds. @@ -382,95 +360,29 @@ def log_gpu_memory(prefix: str = '', info_once: bool = False): logger.info(log_msg) -# Code borrowed from megatron-lm -class MegatronPretrainingRandomSampler: - - def __init__( - self, - dataset, - total_samples, - consumed_samples, - micro_batch_size, - data_parallel_rank, - data_parallel_size, - data_sharding, - shuffle: bool = True, - group_by_length: bool = False, - ): - # Keep a copy of input params for later use. - self.dataset = dataset - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self.micro_batch_size = micro_batch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - if group_by_length: - if data_sharding: - data_sharding = False - logger.warning('`group_by_length=True` is incompatible with `data_sharding=True`. 
' - 'Setting `data_sharding=False` to enable length grouping.') - if not shuffle: - raise ValueError('shuffle must be True when group_by_length is True') - self.data_sharding = data_sharding - self.shuffle = shuffle - self.group_by_length = group_by_length - self.lengths = self.dataset['lengths'] if group_by_length else None - if self.lengths is not None: - self.lengths = [max(length) if isinstance(length, list) else length for length in self.lengths] - self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size - self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size - - # Sanity checks. - assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples) - assert self.micro_batch_size > 0 - assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, ( - 'data_parallel_rank should be smaller than data size: {}, ' - '{}'.format(self.data_parallel_rank, data_parallel_size)) - - def __len__(self): - return self.total_samples - - def __iter__(self): - active_total_samples = self.total_samples - self.last_batch_size - self.epoch = self.consumed_samples // active_total_samples - current_epoch_samples = self.consumed_samples % active_total_samples - assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 - - if self.shuffle: - # data sharding and random sampling - if self.data_sharding: - bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size - bucket_offset = current_epoch_samples // self.data_parallel_size - start_idx = self.data_parallel_rank * bucket_size - - g = torch.Generator() - g.manual_seed(self.epoch) - random_idx = torch.randperm(bucket_size, generator=g).tolist() - idx_range = [start_idx + x for x in random_idx[bucket_offset:]] - else: - full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size - full_bucket_offset = current_epoch_samples - g = torch.Generator() - g.manual_seed(self.epoch) - if self.group_by_length: - from transformers.trainer_pt_utils import get_length_grouped_indices - idx_range_total = get_length_grouped_indices( - self.lengths, self.micro_batch_times_data_parallel_size, generator=g) - else: - idx_range_total = torch.randperm(full_bucket_size, generator=g).tolist() - idx_range_active = idx_range_total[full_bucket_offset:] - idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size] - else: - full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size - full_bucket_offset = current_epoch_samples - idx_range = range(full_bucket_offset + self.data_parallel_rank, full_bucket_size, self.data_parallel_size) - - batch = [] - # Last batch if not complete will be dropped. - for idx in idx_range: - batch.append(idx) - if len(batch) == self.micro_batch_size: - self.consumed_samples += self.micro_batch_times_data_parallel_size - yield batch - batch = [] +def reduce_max_stat_across_model_parallel_group(stat: float) -> float: + """ + Ranks without an optimizer will have no grad_norm or num_zeros_in_grad stats. + We need to ensure the logging and writer rank has those values. + This function reduces a stat tensor across the model parallel group. 
+ + We use an all_reduce max since the values have already been summed across optimizer ranks where possible + """ + if stat is None: + stat = -1.0 + stat = torch.tensor([stat], dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.all_reduce(stat, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()) + if stat.item() == -1.0: + return None + else: + return stat.item() + + +def logical_and_across_model_parallel_group(input: bool) -> bool: + """ + This function gathers a bool value across the model parallel group + """ + input = int(bool(input)) + input = torch.tensor([input], dtype=torch.int, device=torch.cuda.current_device()) + torch.distributed.all_reduce(input, op=torch.distributed.ReduceOp.MIN, group=mpu.get_model_parallel_group()) + return bool(input.item()) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index d36ccd1d92..a5d381e8ab 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -266,7 +266,6 @@ def load_mcore_checkpoint(args, model: list, optimizer, scheduler, load_arg: str if args.finetune: iteration = 0 - num_floating_point_operations_so_far = state_dict.get('num_floating_point_operations_so_far', 0) if 'args' in state_dict and not args.finetune: checkpoint_args = state_dict['args'] args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) @@ -286,7 +285,7 @@ def load_mcore_checkpoint(args, model: list, optimizer, scheduler, load_arg: str torch.distributed.barrier() logger.info(f'Successfully loaded Megatron model weights from: {args.load}') - return iteration, num_floating_point_operations_so_far + return iteration def wrap_model(args, model: list, wrap_with_ddp: bool = True): From 39ac20fd089631897f715948295da9973be88620 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Feb 2026 14:50:24 +0800 Subject: [PATCH 35/43] update --- swift/megatron/trainers/base.py | 7 +++++-- swift/megatron/utils/megatron_lm_utils.py | 10 ++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 0831ecd5aa..32ee3ad45e 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -72,8 +72,10 @@ def __init__(self, args, template: Template): self.prepare_model() self.optimizer, self.scheduler = self.get_optimizer_and_scheduler() self.data_collator = self._get_data_collator() + # TODO: resume_from_checkpoint args.iteration = 0 args.num_floating_point_operations_so_far = 0 + args.consumed_train_samples = 0 if args.initialize_embedding: for m in self.unwrapped_models: self._initialize_embedding(m) @@ -474,7 +476,8 @@ def _load_iteration(self): args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) args.skipped_train_samples = getattr(checkpoint_args, 'skipped_train_samples', 0) update_num_microbatches(consumed_samples=args.consumed_train_samples, verbose=True) - args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) + # TODO: ignore + # args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) else: print_rank_0('could not find arguments in the checkpoint ...') @@ -652,7 +655,7 @@ def evaluate( else: total_loss_dict[key][0] += val total_loss_dict[key][1] += 1 - args.consumed_valid_samples += eval_batch_size + # args.consumed_valid_samples += eval_batch_size if args.exit_duration_in_mins: train_time = (time.time() - 
training._TRAIN_START_TIME) / 60.0 diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index d36ccd1d92..92f942e415 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -272,7 +272,8 @@ def load_mcore_checkpoint(args, model: list, optimizer, scheduler, load_arg: str args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) args.skipped_train_samples = getattr(checkpoint_args, 'skipped_train_samples', 0) update_num_microbatches(consumed_samples=args.consumed_train_samples, verbose=True) - args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) + # TODO: + # args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) if len(model) == 1: model[0].load_state_dict(state_dict['model']) @@ -351,9 +352,10 @@ def get_optimizer_param_scheduler(args, optimizer): args.lr_decay_iters = args.train_iters lr_decay_steps = args.lr_decay_iters * args.global_batch_size wd_incr_steps = args.train_iters * args.global_batch_size - wsd_decay_steps = None - if args.lr_wsd_decay_iters is not None: - wsd_decay_steps = args.lr_wsd_decay_iters * args.global_batch_size + # TODO + # wsd_decay_steps = None + # if args.lr_wsd_decay_iters is not None: + # wsd_decay_steps = args.lr_wsd_decay_iters * args.global_batch_size if args.lr_warmup_fraction is not None: lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps else: From 2eee590808fa082513d55c0f4debfa6341e289d8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Feb 2026 15:05:35 +0800 Subject: [PATCH 36/43] update --- swift/megatron/arguments/megatron_base_args.py | 6 +++--- swift/megatron/pipelines/train/utils.py | 6 +++--- swift/megatron/trainers/base.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/swift/megatron/arguments/megatron_base_args.py b/swift/megatron/arguments/megatron_base_args.py index d122e000a8..3dc2a4f58d 100644 --- a/swift/megatron/arguments/megatron_base_args.py +++ b/swift/megatron/arguments/megatron_base_args.py @@ -21,6 +21,6 @@ def __post_init__(self): MegatronArguments.__post_init__(self) if self.streaming: self.dataloader_type = 'external' - if self.num_workers > 1: - self.num_workers = 1 - logger.info('Using streaming dataset, setting args.num_workers to 1.') + if self.dataloader_num_workers > 1: + self.dataloader_num_workers = 1 + logger.info('Using streaming dataset, setting args.dataloader_num_workers to 1.') diff --git a/swift/megatron/pipelines/train/utils.py b/swift/megatron/pipelines/train/utils.py index 9771868d3f..456348a7c2 100644 --- a/swift/megatron/pipelines/train/utils.py +++ b/swift/megatron/pipelines/train/utils.py @@ -16,11 +16,11 @@ def build_streaming_dataloader(args, dataset, collate_fn): from megatron.training.training import cyclic_iter base_dataloader = torch.utils.data.DataLoader( dataset, - num_workers=args.num_workers, + num_workers=args.dataloader_num_workers, pin_memory=args.dataloader_pin_memory, collate_fn=collate_fn, batch_size=args.micro_batch_size, - prefetch_factor=args.dataloader_prefetch_factor if args.num_workers > 0 else None, - persistent_workers=args.dataloader_persistent_workers if args.num_workers > 0 else False, + prefetch_factor=args.dataloader_prefetch_factor if args.dataloader_num_workers > 0 else None, + persistent_workers=args.dataloader_persistent_workers if args.dataloader_num_workers > 0 else False, ) return iter(cyclic_iter(MegatronDataLoaderDispatcher(base_dataloader))) diff 
--git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 2a562f6bce..b58012be75 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -76,6 +76,7 @@ def __init__(self, args, template: Template): self.data_collator = self._get_data_collator() # TODO: resume_from_checkpoint args.iteration = 0 + args.consumed_train_samples = 0 if args.initialize_embedding: for m in self.unwrapped_models: self._initialize_embedding(m) @@ -470,7 +471,6 @@ def _load_iteration(self): checkpoint_args = state_dict['args'] check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) - args.skipped_train_samples = getattr(checkpoint_args, 'skipped_train_samples', 0) update_num_microbatches(consumed_samples=args.consumed_train_samples, verbose=True) # TODO: ignore # args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) @@ -1285,8 +1285,8 @@ def _create_dataloader(self, dataset, batch_sampler): batch_sampler=batch_sampler, num_workers=args.dataloader_num_workers, pin_memory=args.dataloader_pin_memory, - persistent_workers=args.dataloader_persistent_workers if args.num_workers > 0 else False, - prefetch_factor=args.dataloader_prefetch_factor if args.num_workers > 0 else None, + persistent_workers=args.dataloader_persistent_workers if args.dataloader_num_workers > 0 else False, + prefetch_factor=args.dataloader_prefetch_factor if args.dataloader_num_workers > 0 else None, collate_fn=self.data_collator, ) return dataloader From c6238eabc12c3232b901a89475b610fa0f86188a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Feb 2026 17:08:02 +0800 Subject: [PATCH 37/43] update --- swift/megatron/arguments/megatron_args.py | 2 ++ swift/megatron/trainers/base.py | 8 ++++++++ swift/megatron/trainers/batch_sampler.py | 4 ++++ 3 files changed, 14 insertions(+) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 2f5af77fdf..e5e4e45e9c 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -621,6 +621,8 @@ def __post_init__(self): self._init_weigh_decay() initialize_megatron(self) + total_model_size = self.tensor_model_parallel_size * self.pipeline_model_parallel_size * self.context_parallel_size + self.data_parallel_size = self.world_size // total_model_size def _init_vpp_size(self): # TODO diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index b58012be75..9dc63e5cf7 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -24,6 +24,7 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine from megatron.core.transformer.module import Float16Module, MegatronModule +from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper # from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, @@ -74,6 +75,13 @@ def __init__(self, args, template: Template): self.config = self.unwrapped_models[0].config self.optimizer, self.scheduler = self.get_optimizer_and_scheduler() self.data_collator = self._get_data_collator() + init_num_microbatches_calculator( + args.rank, + None, + args.global_batch_size, + 
args.micro_batch_size, + args.data_parallel_size, + ) # TODO: resume_from_checkpoint args.iteration = 0 args.consumed_train_samples = 0 diff --git a/swift/megatron/trainers/batch_sampler.py b/swift/megatron/trainers/batch_sampler.py index 6cc9892705..35656169d8 100644 --- a/swift/megatron/trainers/batch_sampler.py +++ b/swift/megatron/trainers/batch_sampler.py @@ -1,3 +1,7 @@ +import torch +from swift.utils import get_logger +logger = get_logger() + # Code borrowed from megatron-lm class MegatronPretrainingSampler: From ba23ee18dc5d507553eb42eda43736ce41ed5bf6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Feb 2026 17:09:04 +0800 Subject: [PATCH 38/43] update --- swift/megatron/trainers/base.py | 36 ++++++++++++------------ swift/megatron/trainers/batch_sampler.py | 1 + swift/megatron/trainers/trainer.py | 33 ---------------------- swift/megatron/trainers/utils.py | 7 +---- 4 files changed, 20 insertions(+), 57 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 9dc63e5cf7..f88932831f 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -134,7 +134,7 @@ def prepare_model(self): if args.load is None: self.bridge.load_weights(model, args.model_dir) - model = prepare_mcore_model(args, model) + peft_model = prepare_mcore_model(args, model) if args.tuner_type == 'lora': if args.adapters and args.adapter_load is None: assert len(args.adapters) == 1, 'Currently only support one adapter.' @@ -143,8 +143,8 @@ def prepare_model(self): assert len(args.ref_adapters) == 1, 'Currently only support one adapter.' self.bridge.load_weights( model, args.ref_adapters[0], is_peft_format=True, adapter_name='ref_adapter') - self.peft_models.append(model) - self.wrapped_models = wrap_model(args, self.peft_models) + self.peft_models.append(peft_model) + self.wrapped_models = wrap_model(args, self.unwrapped_models) def get_optimizer_and_scheduler(self): args = self.args @@ -167,11 +167,9 @@ def _get_data_collator(self): data_collator = partial(data_collator, padding_to=padding_to) return data_collator - def new_cyclic_iter(self, iterable): + def cyclic_iter(self, iterable): training = self.unwrapped_models[0].training - if not training: - yield from self._origin_cyclic_iter(iterable) - return + assert training, 'training must be True' args = self.args n_epoch = 0 @@ -181,11 +179,11 @@ def new_cyclic_iter(self, iterable): logger.info(f'The training of Epoch {n_epoch} starts...') for x in iterable: yield x + # streaming if training and args.max_epochs and n_epoch >= args.max_epochs - 1: is_finished = True n_epoch += 1 if is_finished: - # streaming # Note that this approach will train for one additional step. 
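# Rough worked example of the note above (numbers are assumptions, not from the patch):
# with max_epochs=1 and an epoch that spans 100 optimizer steps, is_finished is only
# observed when the iterator is pulled again at the start of step 101, so train_iters
# is pinned to the current iteration + 1 and that in-flight step still completes using
# data drawn from a fresh pass over the dataset; training then stops after 101 steps
# rather than 100.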
logger.info(f'Training of {n_epoch} epochs has been completed, the training has finished.') args.train_iters = args.curr_iteration + 1 @@ -1161,8 +1159,9 @@ def train(self, train_dataset, val_dataset): self.config.finalize_model_grads_func = finalize_model_grads # TODO: manual_gc train_metrics = {} + train_data_iterator = iter(self.cyclic_iter(train_dataloader)) while args.iteration < args.train_iters: - metrics, grad_norm = self.train_step(train_dataloader) + metrics = self.train_step(train_data_iterator) if mpu.is_pipeline_last_stage(ignore_virtual=True): self.aggregated_metrics(metrics, train_metrics) self.training_log(train_metrics, grad_norm) @@ -1191,9 +1190,9 @@ def evaluate(self, val_dataloader): m.eval() eval_metrics = {} forward_backward_func = get_forward_backward_func() - + val_data_iterator = iter(val_dataloader) with torch.no_grad(), tqdm( - total=val_dataloader, dynamic_ncols=True, disable=not is_last_rank(), desc='Evaluate: ') as prog_bar: + total=args.eval_iters, dynamic_ncols=True, disable=not is_last_rank(), desc='Evaluate: ') as prog_bar: iteration = 0 while iteration < args.eval_iters: prog_bar.update() @@ -1201,7 +1200,7 @@ def evaluate(self, val_dataloader): metrics = forward_backward_func( forward_step_func=self.forward_step, - data_iterator=val_dataloader, + data_iterator=val_data_iterator, model=self.wrapped_models, num_microbatches=get_num_microbatches(), seq_length=args.max_length, @@ -1215,7 +1214,7 @@ def evaluate(self, val_dataloader): for m in self.wrapped_models: m.train() - def train_step(self, train_dataloader): + def train_step(self, train_data_iterator): args = self.args forward_backward_func = get_forward_backward_func() for m in self.wrapped_models: @@ -1223,7 +1222,7 @@ def train_step(self, train_dataloader): self.optimizer.zero_grad() metrics = forward_backward_func( forward_step_func=self.forward_step, - data_iterator=train_dataloader, + data_iterator=train_data_iterator, model=self.wrapped_models, num_microbatches=get_num_microbatches(), seq_length=args.max_length, @@ -1242,14 +1241,15 @@ def train_step(self, train_dataloader): args.consumed_train_samples += increment self.scheduler.step(increment=increment) - return metrics, grad_norm + metrics['grad_norm'] = torch.tensor([grad_norm], dtype=torch.float32, device=torch.cuda.current_device()) + return metrics def aggregated_metrics(self, metrics, total_metrics): for key in metrics[0].keys(): if key not in total_metrics: total_metrics[key] = torch.tensor([0.0], dtype=torch.float32, device=torch.cuda.current_device()) val = [x[key].view(-1) for x in metrics] - val = torch.concat(val, dim=0) + val = torch.stack(val, dim=0) if val[0].numel() == 2: val = val.sum(dim=0) total_metrics[key] += val[0] / val[1] @@ -1304,7 +1304,7 @@ def forward_step(self, data_iterator, model): pass def _prepare_batch(self, data, vp_stage=None, num_samples=None): - batch = get_batch_on_this_tp_rank(data, vp_stage=vp_stage) + batch = get_batch_on_this_tp_rank(self.args, data, vp_stage=vp_stage) if num_samples is None: num_samples = batch.pop('num_samples') args = self.args @@ -1316,7 +1316,7 @@ def _prepare_batch(self, data, vp_stage=None, num_samples=None): batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) batch['packed_seq_params'].num_samples = num_samples # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) + batch = get_batch_on_this_cp_rank(args, batch) return batch def get_batch(self, data_iterator, vp_stage=None): diff --git 
a/swift/megatron/trainers/batch_sampler.py b/swift/megatron/trainers/batch_sampler.py index 35656169d8..6c981df9a4 100644 --- a/swift/megatron/trainers/batch_sampler.py +++ b/swift/megatron/trainers/batch_sampler.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import torch from swift.utils import get_logger logger = get_logger() diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 6808f76cc0..5d7c8ef987 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -5,7 +5,6 @@ import torch import torch.nn from megatron.core import mpu -from megatron.core.rerun_state_machine import get_rerun_state_machine from torch.distributed.nn import all_reduce from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -80,38 +79,6 @@ def loss_func(self, if args.context_parallel_size > 1 and not self.mcore_013: loss = all_reduce(loss, group=mpu.get_context_parallel_group()) - # Check individual rank losses are not NaN prior to DP all-reduce. - rerun_state_machine = get_rerun_state_machine() - if args.check_for_nan_in_loss_and_grad: - rerun_state_machine.validate_result( - result=loss[0], - rejection_func=torch.isnan, - message='found NaN in local forward loss calculation', - tolerance=0.0, # forward pass calculations are determinisic - fatal=True, - ) - rerun_state_machine.validate_result( - result=loss[0], - rejection_func=torch.isinf, - message='found Inf in local forward loss calculation', - tolerance=0.0, # forward pass calculations are determinisic - fatal=True, - ) - # Check for spiky loss - if args.check_for_spiky_loss: - # define spiky loss as a loss that's 10x the max loss observed - SPIKY_LOSS_FACTOR = 10 - rerun_state_machine.validate_result( - result=loss[0], - rejection_func=partial( - rerun_state_machine.is_unexpectedly_large, - threshold=SPIKY_LOSS_FACTOR, - context='loss', - ), - message='Spiky loss', - tolerance=0.0, # forward pass calculations are determinisic - fatal=False, - ) # Reduce loss for logging. 
reporting_loss = loss.detach().clone() lm_loss = loss[0] diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index c426d2b838..4fa07b1d86 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -368,14 +368,9 @@ def reduce_max_stat_across_model_parallel_group(stat: float) -> float: We use an all_reduce max since the values have already been summed across optimizer ranks where possible """ - if stat is None: - stat = -1.0 stat = torch.tensor([stat], dtype=torch.float32, device=torch.cuda.current_device()) torch.distributed.all_reduce(stat, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()) - if stat.item() == -1.0: - return None - else: - return stat.item() + return stat def logical_and_across_model_parallel_group(input: bool) -> bool: From c17b863e9408a8f8da914650332ca51f00b79603 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Feb 2026 17:39:13 +0800 Subject: [PATCH 39/43] update --- swift/megatron/arguments/megatron_args.py | 79 ++++++++++++++++++++++- swift/megatron/trainers/base.py | 2 +- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 2f5af77fdf..c4af50c52c 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -623,9 +623,86 @@ def __post_init__(self): initialize_megatron(self) def _init_vpp_size(self): - # TODO + # TODO: self.virtual_pipeline_model_parallel_size = None + if self.pipeline_model_parallel_layout is not None: + # Parse the input flattened layout to a list and get the vpp size. + # We will validate the layout more carefully in the TransformerConfig constructor. + num_stages = PipelineParallelLayerLayout.get_num_stages_from_str(args.pipeline_model_parallel_layout) + assert num_stages % self.pipeline_model_parallel_size == 0, ( + f"The length of pipeline_model_parallel_layout must be divisible" + f" by pipeline_model_parallel_size ({num_stages=}," + f" {self.pipeline_model_parallel_size=})" + ) + self.virtual_pipeline_model_parallel_size = num_stages // self.pipeline_model_parallel_size + if self.virtual_pipeline_model_parallel_size == 1: + self.virtual_pipeline_model_parallel_size = None + elif self.num_layers_per_virtual_pipeline_stage is not None or self.num_virtual_stages_per_pipeline_rank is not None: + if self.num_virtual_stages_per_pipeline_rank is None: + assert self.decoder_first_pipeline_num_layers is None and self.decoder_last_pipeline_num_layers is None, \ + 'please use --num-virtual-stages-per-pipeline-rank to specify virtual pipeline parallel degree when enable uneven pipeline parallelism' + if self.num_layers is not None: + num_layers = self.num_layers + else: + num_layers = self.decoder_num_layers + + if args.account_for_embedding_in_pipeline_split: + num_layers += 1 + + if args.account_for_loss_in_pipeline_split: + num_layers += 1 + + assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \ + 'number of layers of the model must be divisible pipeline model parallel size' + num_layers_per_pipeline_stage = num_layers // args.transformer_pipeline_model_parallel_size + + assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \ + 'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage' + args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \ + args.num_layers_per_virtual_pipeline_stage + else: + 
args.virtual_pipeline_model_parallel_size = args.num_virtual_stages_per_pipeline_rank + if args.virtual_pipeline_model_parallel_size == 1: + args.virtual_pipeline_model_parallel_size = None + else: + args.virtual_pipeline_model_parallel_size = None + + if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None: + # Divisibility check not applicable for T5 models which specify encoder_num_layers + # and decoder_num_layers. + if args.num_layers is not None: + num_layers = args.num_layers + + if args.account_for_embedding_in_pipeline_split: + num_layers += 1 + + if args.account_for_loss_in_pipeline_split: + num_layers += 1 + + assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \ + 'Number of layers should be divisible by the pipeline-model-parallel size' + + if args.virtual_pipeline_model_parallel_size is not None: + if args.overlap_p2p_comm: + assert args.pipeline_model_parallel_size > 1, \ + 'When interleaved schedule is used, pipeline-model-parallel size ' \ + 'should be greater than 1' + else: + assert args.pipeline_model_parallel_size > 2, \ + 'When interleaved schedule is used and p2p communication overlap is disabled, ' \ + 'pipeline-model-parallel size should be greater than 2 to avoid having multiple ' \ + 'p2p sends and recvs between same 2 ranks per communication batch' + else: + # Overlap P2P communication is disabled if not using the interleaved schedule. + args.overlap_p2p_comm = False + args.align_param_gather = False + # Only print warning if PP size > 1. + if args.rank == 0 and args.pipeline_model_parallel_size > 1: + print('WARNING: Setting args.overlap_p2p_comm and args.align_param_gather to False ' + 'since non-interleaved schedule does not support overlapping p2p communication ' + 'and aligned param AG') + def _load_adapter_config(self): assert len(self.adapters) == 1, 'Currently only support one adapter' adapter_path = self.adapters[0] diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index ad4ab814a1..7acd33d418 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1164,7 +1164,7 @@ def train(self, train_dataset, val_dataset): self.save_checkpoint() def save_checkpoint(self): - print + pass def training_log(self, metrics, grad_norm): learning_rate = None From 17a2df6404c70fd0389c38b1cb02eed9a1769ca6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 7 Feb 2026 14:15:53 +0800 Subject: [PATCH 40/43] update --- swift/megatron/arguments/megatron_args.py | 10 +++--- swift/megatron/callbacks/__init__.py | 1 + swift/megatron/callbacks/base.py | 40 +++++++++++++++++++++++ swift/megatron/callbacks/default_flow.py | 1 + swift/megatron/callbacks/mapping.py | 4 +++ swift/megatron/callbacks/print.py | 25 ++++++++++++++ swift/megatron/callbacks/swanlab.py | 0 swift/megatron/callbacks/wandb.py | 0 swift/megatron/trainers/__init__.py | 2 ++ swift/megatron/trainers/base.py | 38 ++++++++++++--------- swift/megatron/trainers/batch_sampler.py | 3 ++ swift/megatron/trainers/utils.py | 13 ++++++++ swift/trainers/mixin.py | 2 +- 13 files changed, 117 insertions(+), 22 deletions(-) create mode 100644 swift/megatron/callbacks/__init__.py create mode 100644 swift/megatron/callbacks/base.py create mode 100644 swift/megatron/callbacks/default_flow.py create mode 100644 swift/megatron/callbacks/mapping.py create mode 100644 swift/megatron/callbacks/print.py create mode 100644 swift/megatron/callbacks/swanlab.py create mode 100644 swift/megatron/callbacks/wandb.py diff 
--git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index abd29d936a..c540250811 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -504,6 +504,7 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): num_labels: Optional[int] = None problem_type: Literal['regression', 'single_label_classification', 'multi_label_classification'] = None save_strategy: Literal['steps', 'epoch'] = 'steps' + callbacks: List[str] = field(default_factory=list) report_to: Optional[Literal['wandb', 'swanlab']] = None @@ -627,16 +628,15 @@ def __post_init__(self): def _init_vpp_size(self): # TODO: self.virtual_pipeline_model_parallel_size = None - + return if self.pipeline_model_parallel_layout is not None: # Parse the input flattened layout to a list and get the vpp size. # We will validate the layout more carefully in the TransformerConfig constructor. num_stages = PipelineParallelLayerLayout.get_num_stages_from_str(args.pipeline_model_parallel_layout) assert num_stages % self.pipeline_model_parallel_size == 0, ( - f"The length of pipeline_model_parallel_layout must be divisible" - f" by pipeline_model_parallel_size ({num_stages=}," - f" {self.pipeline_model_parallel_size=})" - ) + f'The length of pipeline_model_parallel_layout must be divisible' + f' by pipeline_model_parallel_size ({num_stages=},' + f' {self.pipeline_model_parallel_size=})') self.virtual_pipeline_model_parallel_size = num_stages // self.pipeline_model_parallel_size if self.virtual_pipeline_model_parallel_size == 1: self.virtual_pipeline_model_parallel_size = None diff --git a/swift/megatron/callbacks/__init__.py b/swift/megatron/callbacks/__init__.py new file mode 100644 index 0000000000..bc6b0abe5b --- /dev/null +++ b/swift/megatron/callbacks/__init__.py @@ -0,0 +1 @@ +from .mapping import megatron_callbacks_map diff --git a/swift/megatron/callbacks/base.py b/swift/megatron/callbacks/base.py new file mode 100644 index 0000000000..402d5f1732 --- /dev/null +++ b/swift/megatron/callbacks/base.py @@ -0,0 +1,40 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from swift.megatron.trainers import BaseMegatronTrainer + from swift.megatron.arguments import MegatronArguments + + +class MegatronCallback: + + def __init__(self, trainer: 'BaseMegatronTrainer'): + self.trainer = trainer + self.args = trainer.args + self.state = trainer.state + + def on_train_begin(self): + pass + + def on_train_end(self): + pass + + def on_step_end(self): + pass + + def on_epoch_begin(self): + pass + + def on_epoch_end(self): + pass + + def on_log(self, logs): + pass + + def on_evaluate_begin(self): + pass + + def on_evaluate_end(self): + pass + + def on_evaluate_step(self): + pass diff --git a/swift/megatron/callbacks/default_flow.py b/swift/megatron/callbacks/default_flow.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/swift/megatron/callbacks/default_flow.py @@ -0,0 +1 @@ + diff --git a/swift/megatron/callbacks/mapping.py b/swift/megatron/callbacks/mapping.py new file mode 100644 index 0000000000..d20a068bcb --- /dev/null +++ b/swift/megatron/callbacks/mapping.py @@ -0,0 +1,4 @@ + +megatron_callbacks_map = { + 'print': PrintCallback, +} diff --git a/swift/megatron/callbacks/print.py b/swift/megatron/callbacks/print.py new file mode 100644 index 0000000000..f34e130036 --- /dev/null +++ b/swift/megatron/callbacks/print.py @@ -0,0 +1,25 @@ +from .base import MegatronCallback +from swift.utils import is_master + +class 
PrintCallback(MegatronCallback): + + def __init__(self, trainer): + super().__init__(trainer) + self.training_bar = None + self.eval_bar = None + + def on_train_begin(self): + self.training_bar = tqdm(total=self.args.train_iters, dynamic_ncols=True, disable=not is_master(), desc='Train: ') + + def on_train_end(self): + self.training_bar.close() + self.training_bar = None + + def on_train_step(self): + self.training_bar.update() + + def on_eval_begin(self): + self.eval_bar = tqdm(total=eval_iters, dynamic_ncols=True, disable=not is_master(), desc='Evaluate: ') + + def on_log(self, logs): + print() diff --git a/swift/megatron/callbacks/swanlab.py b/swift/megatron/callbacks/swanlab.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/swift/megatron/callbacks/wandb.py b/swift/megatron/callbacks/wandb.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/swift/megatron/trainers/__init__.py b/swift/megatron/trainers/__init__.py index 0687fbda20..e73096965d 100644 --- a/swift/megatron/trainers/__init__.py +++ b/swift/megatron/trainers/__init__.py @@ -13,6 +13,7 @@ from .embedding_trainer import MegatronEmbeddingTrainer from .reranker_trainer import MegatronRerankerTrainer from .trainer import MegatronTrainer + from .base import BaseMegatronTrainer else: _import_structure = { 'dpo_trainer': ['MegatronDPOTrainer'], @@ -24,6 +25,7 @@ 'embedding_trainer': ['MegatronEmbeddingTrainer'], 'reranker_trainer': ['MegatronRerankerTrainer'], 'trainer': ['MegatronTrainer'], + 'base': ['BaseMegatronTrainer'], } import sys diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index afb13a8a24..7a87b6a5b1 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -19,12 +19,12 @@ from megatron.core.datasets.utils import Split from megatron.core.distributed import finalize_model_grads from megatron.core.enums import ModelType -from megatron.core.num_microbatches_calculator import get_num_microbatches, update_num_microbatches +from megatron.core.num_microbatches_calculator import (get_num_microbatches, init_num_microbatches_calculator, + update_num_microbatches) from megatron.core.optimizer import OptimizerConfig, _update_min_and_max_lr_in_param_groups, get_megatron_optimizer from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine from megatron.core.transformer.module import Float16Module, MegatronModule -from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper # from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, @@ -42,6 +42,7 @@ from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, get_optimizer_param_scheduler, get_padding_to, load_mcore_checkpoint, patch_merge_fn, prepare_mcore_model, wrap_model) +from swift.megatron.callbacks import megatron_callbacks_map from swift.metrics import MeanMetric from swift.template import Template from swift.trainers import SwiftMixin, dynamic_gradient_checkpointing @@ -51,6 +52,7 @@ from .batch_sampler import MegatronPretrainingRandomSampler, MegatronPretrainingSampler from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, logical_and_across_model_parallel_group, 
reduce_max_stat_across_model_parallel_group) +from .trainer_state import TrainerState # try: # from megatron.training.datasets.data_samplers import MegatronPretrainingSampler @@ -83,8 +85,7 @@ def __init__(self, args, template: Template): args.data_parallel_size, ) # TODO: resume_from_checkpoint - args.iteration = 0 - args.consumed_train_samples = 0 + self.state = TrainerState() if args.initialize_embedding: for m in self.unwrapped_models: self._initialize_embedding(m) @@ -123,6 +124,14 @@ def _get_mean_metric(): 'eval': collections.defaultdict(_get_mean_metric) } self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + self.args.callbacks += ['print', ''] + self.callbacks = [] + for callback in self.args.callbacks: + self.callbacks(megatron_callbacks_map[callback](self)) + + def call_event(self, event): + for callback in self.callbacks: + getattr(callback, event)() def prepare_model(self): args = self.args @@ -186,7 +195,7 @@ def cyclic_iter(self, iterable): if is_finished: # Note that this approach will train for one additional step. logger.info(f'Training of {n_epoch} epochs has been completed, the training has finished.') - args.train_iters = args.curr_iteration + 1 + args.train_iters = args.iteration + 1 def _replace_data_iterator(self, data_iterator, model): return data_iterator @@ -811,7 +820,7 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval'], iteration= self._remove_log(total_loss_dict) if iteration is None: args = self.args - iteration = args.curr_iteration + 1 + iteration = args.iteration + 1 if writer: for k, v in metrics.items(): writer.add_scalar(k, v, iteration) @@ -1191,8 +1200,8 @@ def evaluate(self, val_dataloader): eval_metrics = {} forward_backward_func = get_forward_backward_func() val_data_iterator = iter(val_dataloader) - with torch.no_grad(), tqdm( - total=args.eval_iters, dynamic_ncols=True, disable=not is_last_rank(), desc='Evaluate: ') as prog_bar: + + with torch.no_grad(): iteration = 0 while iteration < args.eval_iters: prog_bar.update() @@ -1230,16 +1239,13 @@ def train_step(self, train_data_iterator): forward_only=False, ) - update_successful, grad_norm, _ = self.optimizer.step() - update_successful = logical_and_across_model_parallel_group(update_successful) + _, grad_norm, _ = self.optimizer.step() grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) - # Update learning rate. - if update_successful: - increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size - args.iteration += 1 - args.consumed_train_samples += increment - self.scheduler.step(increment=increment) + increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size + args.iteration += 1 + args.consumed_train_samples += increment + self.scheduler.step(increment=increment) metrics['grad_norm'] = torch.tensor([grad_norm], dtype=torch.float32, device=torch.cuda.current_device()) return metrics diff --git a/swift/megatron/trainers/batch_sampler.py b/swift/megatron/trainers/batch_sampler.py index 6c981df9a4..dfbc788bbe 100644 --- a/swift/megatron/trainers/batch_sampler.py +++ b/swift/megatron/trainers/batch_sampler.py @@ -1,8 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import torch + from swift.utils import get_logger + logger = get_logger() + # Code borrowed from megatron-lm class MegatronPretrainingSampler: diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 4fa07b1d86..b0349ddffc 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -15,6 +15,7 @@ from megatron.core.packed_seq_params import PackedSeqParams # from megatron.training import get_wandb_writer from packaging import version +from dataclasses import dataclass from transformers.utils import is_torch_npu_available from swift.utils import empty_cache, get_current_device, get_logger @@ -381,3 +382,15 @@ def logical_and_across_model_parallel_group(input: bool) -> bool: input = torch.tensor([input], dtype=torch.int, device=torch.cuda.current_device()) torch.distributed.all_reduce(input, op=torch.distributed.ReduceOp.MIN, group=mpu.get_model_parallel_group()) return bool(input.item()) + + +@dataclass +class TrainerState: + should_save: bool = False + should_evaluate: bool = False + should_log: bool = False + + iteration: int = 0 + eval_iteration: int = 0 + epoch: int = 0 + consumed_train_samples = 0 diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index b607d05679..24237936cb 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -148,7 +148,7 @@ def _get_data_collator(self, args, template): def _add_callbacks(self): for callback in self.args.callbacks: - self.add_callback(callbacks_map[callback](self.args, self)) + self.add_callback(callbacks_map[callback](self)) def _collect_config_info(self) -> Dict[str, str]: """ From f314806e955b38161072ebe6e8ccbd2eaff994e2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 7 Feb 2026 15:53:12 +0800 Subject: [PATCH 41/43] update --- swift/megatron/arguments/megatron_args.py | 9 ++- swift/megatron/callbacks/__init__.py | 2 + swift/megatron/callbacks/base.py | 14 ++--- swift/megatron/callbacks/default_flow.py | 22 ++++++++ swift/megatron/callbacks/mapping.py | 4 ++ swift/megatron/callbacks/print.py | 14 ++++- swift/megatron/callbacks/swanlab.py | 1 + swift/megatron/callbacks/wandb.py | 1 + swift/megatron/trainers/base.py | 69 ++++++++++------------- swift/megatron/trainers/utils.py | 4 +- 10 files changed, 84 insertions(+), 56 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index c540250811..f2ff638825 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -595,15 +595,13 @@ def __post_init__(self): self.untie_embeddings_and_output_weights = True if self.gradient_checkpointing_kwargs is not None: self.gradient_checkpointing_kwargs = json_parse_to_dict(self.gradient_checkpointing_kwargs) - if self.save_strategy == 'epoch': - self.save_interval = 1 - self.eval_interval = 1 if self.gradient_accumulation_fusion: try: import apex except ImportError: logger.warning('apex is not installed, so gradient accumulation fusion is disabled.') self.gradient_accumulation_fusion = False + self.callbacks += ['print', 'default_flow'] if isinstance(self.ref_adapters, str): self.ref_adapters = [self.ref_adapters] if self.eval_interval is None: @@ -624,6 +622,7 @@ def __post_init__(self): initialize_megatron(self) total_model_size = self.tensor_model_parallel_size * self.pipeline_model_parallel_size * self.context_parallel_size self.data_parallel_size = self.world_size // total_model_size + self.num_micro_batches = self.global_batch_size // self.data_parallel_size def 
_init_vpp_size(self): # TODO: @@ -734,8 +733,8 @@ def init_iters(self, train_dataset, val_dataset): self.save_interval = dataset_sample // self.global_batch_size self.eval_interval = self.save_interval # TODO - if getattr(self, 'save_retain_interval', None) is not None: - self.save_retain_interval *= self.save_interval + # if getattr(self, 'save_retain_interval', None) is not None: + # self.save_retain_interval *= self.save_interval else: raise ValueError('streaming dataset is not supported with `--save_strategy epoch`.') if self.max_epochs is not None: diff --git a/swift/megatron/callbacks/__init__.py b/swift/megatron/callbacks/__init__.py index bc6b0abe5b..b045ca748d 100644 --- a/swift/megatron/callbacks/__init__.py +++ b/swift/megatron/callbacks/__init__.py @@ -1 +1,3 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .base import MegatronCallback from .mapping import megatron_callbacks_map diff --git a/swift/megatron/callbacks/base.py b/swift/megatron/callbacks/base.py index 402d5f1732..f155e671e4 100644 --- a/swift/megatron/callbacks/base.py +++ b/swift/megatron/callbacks/base.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -18,23 +19,20 @@ def on_train_begin(self): def on_train_end(self): pass - def on_step_end(self): - pass - - def on_epoch_begin(self): + def on_step_begin(self): pass - def on_epoch_end(self): + def on_step_end(self): pass def on_log(self, logs): pass - def on_evaluate_begin(self): + def on_eval_begin(self): pass - def on_evaluate_end(self): + def on_eval_end(self): pass - def on_evaluate_step(self): + def on_eval_step(self): pass diff --git a/swift/megatron/callbacks/default_flow.py b/swift/megatron/callbacks/default_flow.py index 8b13789179..8ffa3d29ef 100644 --- a/swift/megatron/callbacks/default_flow.py +++ b/swift/megatron/callbacks/default_flow.py @@ -1 +1,23 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .base import MegatronCallback + +class DefaultFlowCallbacks(MegatronCallback): + + def on_step_end(self): + args = self.args + state = self.state + + state.iteration += 1 + args.consumed_train_samples += args.global_batch_size + + if state.iteration == 1 or state.iteration % args.log_interval == 0: + self.state.should_log = True + if args.eval_interval and state.iteration % args.eval_interval == 0: + self.state.should_eval = True + if args.save_interval and state.iteration % args.save_interval == 0: + self.state.should_save = True + + if state.iteration >= args.train_iters: + self.state.should_eval = True + self.state.should_save = True diff --git a/swift/megatron/callbacks/mapping.py b/swift/megatron/callbacks/mapping.py index d20a068bcb..f036200985 100644 --- a/swift/megatron/callbacks/mapping.py +++ b/swift/megatron/callbacks/mapping.py @@ -1,4 +1,8 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .default_flow import DefaultFlowCallback +from .print import PrintCallback megatron_callbacks_map = { 'print': PrintCallback, + 'default_flow': DefaultFlowCallback, } diff --git a/swift/megatron/callbacks/print.py b/swift/megatron/callbacks/print.py index f34e130036..1ce9dd4c76 100644 --- a/swift/megatron/callbacks/print.py +++ b/swift/megatron/callbacks/print.py @@ -1,5 +1,7 @@ -from .base import MegatronCallback +# Copyright (c) ModelScope Contributors. All rights reserved. 
from swift.utils import is_master +from .base import MegatronCallback + class PrintCallback(MegatronCallback): @@ -9,7 +11,8 @@ def __init__(self, trainer): self.eval_bar = None def on_train_begin(self): - self.training_bar = tqdm(total=self.args.train_iters, dynamic_ncols=True, disable=not is_master(), desc='Train: ') + self.training_bar = tqdm( + total=self.args.train_iters, dynamic_ncols=True, disable=not is_master(), desc='Train: ') def on_train_end(self): self.training_bar.close() @@ -21,5 +24,12 @@ def on_train_step(self): def on_eval_begin(self): self.eval_bar = tqdm(total=eval_iters, dynamic_ncols=True, disable=not is_master(), desc='Evaluate: ') + def on_eval_end(self): + self.eval_bar.close() + self.eval_bar = None + + def on_eval_step(self): + self.eval_bar.update() + def on_log(self, logs): print() diff --git a/swift/megatron/callbacks/swanlab.py b/swift/megatron/callbacks/swanlab.py index e69de29bb2..85b3e739d7 100644 --- a/swift/megatron/callbacks/swanlab.py +++ b/swift/megatron/callbacks/swanlab.py @@ -0,0 +1 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. diff --git a/swift/megatron/callbacks/wandb.py b/swift/megatron/callbacks/wandb.py index e69de29bb2..85b3e739d7 100644 --- a/swift/megatron/callbacks/wandb.py +++ b/swift/megatron/callbacks/wandb.py @@ -0,0 +1 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 7a87b6a5b1..036e38669b 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -19,8 +19,6 @@ from megatron.core.datasets.utils import Split from megatron.core.distributed import finalize_model_grads from megatron.core.enums import ModelType -from megatron.core.num_microbatches_calculator import (get_num_microbatches, init_num_microbatches_calculator, - update_num_microbatches) from megatron.core.optimizer import OptimizerConfig, _update_min_and_max_lr_in_param_groups, get_megatron_optimizer from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine @@ -37,12 +35,12 @@ from packaging import version from tqdm.auto import tqdm +from swift.megatron.callbacks import megatron_callbacks_map from swift.megatron.model import get_mcore_model from swift.megatron.tuners import LoraParallelLinear from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight, get_optimizer_param_scheduler, get_padding_to, load_mcore_checkpoint, patch_merge_fn, prepare_mcore_model, wrap_model) -from swift.megatron.callbacks import megatron_callbacks_map from swift.metrics import MeanMetric from swift.template import Template from swift.trainers import SwiftMixin, dynamic_gradient_checkpointing @@ -50,9 +48,9 @@ from swift.utils import (JsonlWriter, deep_getattr, format_time, get_last_valid_indices, get_logger, is_last_rank, ms_logger_context) from .batch_sampler import MegatronPretrainingRandomSampler, MegatronPretrainingSampler +from .trainer_state import TrainerState from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group) -from .trainer_state import TrainerState # try: # from megatron.training.datasets.data_samplers import MegatronPretrainingSampler @@ -77,13 +75,6 @@ def __init__(self, args, template: Template): self.config = self.unwrapped_models[0].config self.optimizer, self.scheduler = 
self.get_optimizer_and_scheduler() self.data_collator = self._get_data_collator() - init_num_microbatches_calculator( - args.rank, - None, - args.global_batch_size, - args.micro_batch_size, - args.data_parallel_size, - ) # TODO: resume_from_checkpoint self.state = TrainerState() if args.initialize_embedding: @@ -97,7 +88,7 @@ def __init__(self, args, template: Template): load_mcore_checkpoint(args, self.wrapped_models, load_arg='ref_adapter_load') if args.adapter_load is not None: with adapter_state_dict_context(): - args.iteration = load_mcore_checkpoint( + state.iteration = load_mcore_checkpoint( args, self.wrapped_models, self.optimizer, self.scheduler, load_arg='adapter_load') self.eval_metrics = None @@ -124,14 +115,13 @@ def _get_mean_metric(): 'eval': collections.defaultdict(_get_mean_metric) } self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') - self.args.callbacks += ['print', ''] self.callbacks = [] for callback in self.args.callbacks: self.callbacks(megatron_callbacks_map[callback](self)) - def call_event(self, event): + def call_event(self, event, *args, **kwargs): for callback in self.callbacks: - getattr(callback, event)() + getattr(callback, event)(*args, **kwargs) def prepare_model(self): args = self.args @@ -181,21 +171,21 @@ def cyclic_iter(self, iterable): assert training, 'training must be True' args = self.args - n_epoch = 0 + state = self.state is_finished = False while True: if not is_finished: - logger.info(f'The training of Epoch {n_epoch} starts...') + logger.info(f'The training of Epoch {state.epoch} starts...') for x in iterable: yield x # streaming - if training and args.max_epochs and n_epoch >= args.max_epochs - 1: + if training and args.max_epochs and state.epoch >= args.max_epochs - 1: is_finished = True - n_epoch += 1 + state.epoch += 1 if is_finished: # Note that this approach will train for one additional step. 
- logger.info(f'Training of {n_epoch} epochs has been completed, the training has finished.') - args.train_iters = args.iteration + 1 + logger.info(f'Training of {state.epoch} epochs has been completed, the training has finished.') + args.train_iters = state.iteration + 1 def _replace_data_iterator(self, data_iterator, model): return data_iterator @@ -486,9 +476,6 @@ def _load_iteration(self): checkpoint_args = state_dict['args'] check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) - update_num_microbatches(consumed_samples=args.consumed_train_samples, verbose=True) - # TODO: ignore - # args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) else: print_rank_0('could not find arguments in the checkpoint ...') @@ -499,7 +486,7 @@ def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **k # read iteration args = self.args if not args.finetune: - args.iteration = self._load_iteration() + state.iteration = self._load_iteration() if args.apply_wd_to_qk_layernorm or self.args.vit_lr is not None or self.args.aligner_lr is not None: param_groups_context = self._patch_get_param_groups() @@ -820,7 +807,7 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval'], iteration= self._remove_log(total_loss_dict) if iteration is None: args = self.args - iteration = args.iteration + 1 + iteration = state.iteration + 1 if writer: for k, v in metrics.items(): writer.add_scalar(k, v, iteration) @@ -1167,20 +1154,28 @@ def train(self, train_dataset, val_dataset): self.config.finalize_model_grads_func = finalize_model_grads # TODO: manual_gc + self.call_event('on_train_begin') train_metrics = {} train_data_iterator = iter(self.cyclic_iter(train_dataloader)) - while args.iteration < args.train_iters: + state = self.state + while state.iteration < args.train_iters: + self.call_event('on_step_begin') metrics = self.train_step(train_data_iterator) + self.call_event('on_step_end') if mpu.is_pipeline_last_stage(ignore_virtual=True): self.aggregated_metrics(metrics, train_metrics) - self.training_log(train_metrics, grad_norm) + self.call_event('on_log', train_metrics) - if args.eval_interval and args.iteration % args.eval_interval == 0: + if state.should_eval: + state.should_eval = False self.evaluate(val_dataloader) - if args.save and args.save_interval and args.iteration % args.save_interval == 0: + if state.should_save: + state.should_save = False self.save_checkpoint() + self.call_event('on_train_end') + def save_checkpoint(self): pass @@ -1201,27 +1196,27 @@ def evaluate(self, val_dataloader): forward_backward_func = get_forward_backward_func() val_data_iterator = iter(val_dataloader) + self.call_event('on_eval_begin') with torch.no_grad(): iteration = 0 while iteration < args.eval_iters: - prog_bar.update() - iteration += 1 - metrics = forward_backward_func( forward_step_func=self.forward_step, data_iterator=val_data_iterator, model=self.wrapped_models, - num_microbatches=get_num_microbatches(), + num_microbatches=self.num_micro_batches, seq_length=args.max_length, micro_batch_size=args.micro_batch_size, forward_only=True, ) + self.call_event('on_eval_step') if mpu.is_pipeline_last_stage(ignore_virtual=True): self.aggregated_metrics(metrics, eval_metrics) # TODO: log metrics logger.info(f'eval_metrics: {eval_metrics}') for m in self.wrapped_models: m.train() + self.call_event('on_eval_end') def train_step(self, train_data_iterator): args = self.args @@ -1233,7 +1228,7 @@ def 
train_step(self, train_data_iterator): forward_step_func=self.forward_step, data_iterator=train_data_iterator, model=self.wrapped_models, - num_microbatches=get_num_microbatches(), + num_microbatches=args.num_micro_batches, seq_length=args.max_length, micro_batch_size=args.micro_batch_size, forward_only=False, @@ -1241,10 +1236,6 @@ def train_step(self, train_data_iterator): _, grad_norm, _ = self.optimizer.step() grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) - - increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size - args.iteration += 1 - args.consumed_train_samples += increment self.scheduler.step(increment=increment) metrics['grad_norm'] = torch.tensor([grad_norm], dtype=torch.float32, device=torch.cuda.current_device()) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index b0349ddffc..68069ddf64 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -3,6 +3,7 @@ import gc import time from contextlib import contextmanager +from dataclasses import dataclass from typing import Any, Dict, Optional import megatron.core @@ -15,7 +16,6 @@ from megatron.core.packed_seq_params import PackedSeqParams # from megatron.training import get_wandb_writer from packaging import version -from dataclasses import dataclass from transformers.utils import is_torch_npu_available from swift.utils import empty_cache, get_current_device, get_logger @@ -387,7 +387,7 @@ def logical_and_across_model_parallel_group(input: bool) -> bool: @dataclass class TrainerState: should_save: bool = False - should_evaluate: bool = False + should_eval: bool = False should_log: bool = False iteration: int = 0 From 2088a44d67e09cca5accd7401651563154664f1f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 7 Feb 2026 21:31:08 +0800 Subject: [PATCH 42/43] update --- swift/megatron/callbacks/default_flow.py | 4 +-- swift/megatron/callbacks/print.py | 4 ++- swift/megatron/trainers/base.py | 35 ++++++++++++++++-------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/swift/megatron/callbacks/default_flow.py b/swift/megatron/callbacks/default_flow.py index 8ffa3d29ef..7e3d63914b 100644 --- a/swift/megatron/callbacks/default_flow.py +++ b/swift/megatron/callbacks/default_flow.py @@ -2,14 +2,14 @@ from .base import MegatronCallback -class DefaultFlowCallbacks(MegatronCallback): +class DefaultFlowCallback(MegatronCallback): def on_step_end(self): args = self.args state = self.state state.iteration += 1 - args.consumed_train_samples += args.global_batch_size + state.consumed_train_samples += args.global_batch_size if state.iteration == 1 or state.iteration % args.log_interval == 0: self.state.should_log = True diff --git a/swift/megatron/callbacks/print.py b/swift/megatron/callbacks/print.py index 1ce9dd4c76..6c0776a053 100644 --- a/swift/megatron/callbacks/print.py +++ b/swift/megatron/callbacks/print.py @@ -1,4 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+from tqdm import tqdm
+
 from swift.utils import is_master
 from .base import MegatronCallback
 
@@ -22,7 +24,7 @@ def on_train_step(self):
         self.training_bar.update()
 
     def on_eval_begin(self):
-        self.eval_bar = tqdm(total=eval_iters, dynamic_ncols=True, disable=not is_master(), desc='Evaluate: ')
+        self.eval_bar = tqdm(total=self.args.eval_iters, dynamic_ncols=True, disable=not is_master(), desc='Evaluate: ')
 
     def on_eval_end(self):
         self.eval_bar.close()
diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py
index 036e38669b..0b9d2d20cf 100644
--- a/swift/megatron/trainers/base.py
+++ b/swift/megatron/trainers/base.py
@@ -48,8 +48,7 @@
 from swift.utils import (JsonlWriter, deep_getattr, format_time, get_last_valid_indices, get_logger, is_last_rank,
                          ms_logger_context)
 from .batch_sampler import MegatronPretrainingRandomSampler, MegatronPretrainingSampler
-from .trainer_state import TrainerState
-from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params,
+from .utils import (TrainerState, get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params,
                     logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group)
 
 # try:
@@ -117,12 +116,23 @@ def _get_mean_metric():
         self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
         self.callbacks = []
         for callback in self.args.callbacks:
-            self.callbacks(megatron_callbacks_map[callback](self))
+            self.callbacks.append(megatron_callbacks_map[callback](self))
 
     def call_event(self, event, *args, **kwargs):
+        if event == 'on_log':
+            self._log_callback(*args, **kwargs)
         for callback in self.callbacks:
             getattr(callback, event)(*args, **kwargs)
 
+    def _log_callback(self, logs):
+        n_iters = logs.pop('n_iters')
+        for k, v in logs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            if isinstance(v, float):
+                v = round(v, 8)
+            logs[k] = v / n_iters
+
     def prepare_model(self):
         args = self.args
         hf_config = self.template.config
@@ -187,9 +197,6 @@ def cyclic_iter(self, iterable):
                 logger.info(f'Training of {state.epoch} epochs has been completed, the training has finished.')
                 args.train_iters = state.iteration + 1
 
-    def _replace_data_iterator(self, data_iterator, model):
-        return data_iterator
-
     def _load_adapter_base_checkpoint(self, *_args, **kwargs):
         adapter_name = kwargs.pop('adapter_name', None) or 'ref_adapter'
         sharded_state_dict = kwargs.get('sharded_state_dict')
@@ -1160,11 +1167,13 @@ def train(self, train_dataset, val_dataset):
         state = self.state
         while state.iteration < args.train_iters:
             self.call_event('on_step_begin')
-            metrics = self.train_step(train_data_iterator)
+            metrics, grad_norm = self.train_step(train_data_iterator)
             self.call_event('on_step_end')
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 self.aggregated_metrics(metrics, train_metrics)
-                self.call_event('on_log', train_metrics)
+                train_metrics['grad_norm'] = torch.tensor([grad_norm], dtype=torch.float32, device=torch.cuda.current_device())
+                if state.should_log:
+                    self.call_event('on_log', train_metrics)
 
             if state.should_eval:
                 state.should_eval = False
@@ -1236,12 +1245,14 @@ def train_step(self, train_data_iterator):
 
         _, grad_norm, _ = self.optimizer.step()
         grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
-        self.scheduler.step(increment=increment)
+        self.scheduler.step(increment=args.global_batch_size)
 
-        metrics['grad_norm'] = torch.tensor([grad_norm], dtype=torch.float32, device=torch.cuda.current_device())
-        return metrics
+        return metrics, grad_norm
 
     def aggregated_metrics(self, metrics, total_metrics):
+        if 'n_iters' not in total_metrics:
+            total_metrics['n_iters'] = torch.tensor([0], dtype=torch.int64, device=torch.cuda.current_device())
+        total_metrics['n_iters'] += 1
         for key in metrics[0].keys():
             if key not in total_metrics:
                 total_metrics[key] = torch.tensor([0.0], dtype=torch.float32, device=torch.cuda.current_device())
@@ -1261,7 +1272,7 @@ def prepare_dataloader(self, train_dataset, val_dataset):
             train_batch_sampler = MegatronPretrainingRandomSampler(
                 train_dataset,
                 total_samples=len(train_dataset),
-                consumed_samples=args.consumed_train_samples,
+                consumed_samples=self.state.consumed_train_samples,
                 micro_batch_size=args.micro_batch_size,
                 data_parallel_rank=mpu.get_data_parallel_rank(),
                 data_parallel_size=mpu.get_data_parallel_world_size(),

From 002e6d9a84fedc53e275961fc9eea096a57379ff Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sat, 7 Feb 2026 22:17:32 +0800
Subject: [PATCH 43/43] update

---
 swift/megatron/callbacks/print.py |  24 +-
 swift/megatron/trainers/base.py   | 509 +-----------------------------
 swift/trainers/patcher.py         |   5 +-
 3 files changed, 26 insertions(+), 512 deletions(-)

diff --git a/swift/megatron/callbacks/print.py b/swift/megatron/callbacks/print.py
index 6c0776a053..3b1488a5a8 100644
--- a/swift/megatron/callbacks/print.py
+++ b/swift/megatron/callbacks/print.py
@@ -1,7 +1,11 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import time
+
+import torch
 from tqdm import tqdm
 
-from swift.utils import is_master
+from swift.utils import JsonlWriter, format_time, is_master
 from .base import MegatronCallback
 
 
@@ -11,10 +15,15 @@ def __init__(self, trainer):
         super().__init__(trainer)
         self.training_bar = None
         self.eval_bar = None
+        self.jsonl_writer = None
 
     def on_train_begin(self):
         self.training_bar = tqdm(
             total=self.args.train_iters, dynamic_ncols=True, disable=not is_master(), desc='Train: ')
+        self.start_step = self.state.iteration
+        self.start_time = time.time()
+        logging_path = os.path.join(self.args.save, 'logging.jsonl')
+        self.jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last')
 
     def on_train_end(self):
         self.training_bar.close()
@@ -34,4 +43,15 @@ def on_eval_step(self):
         self.eval_bar.update()
 
     def on_log(self, logs):
-        print()
+        state = self.state
+        args = self.args
+        logs['iteration'] = f'{state.iteration}/{args.train_iters}'
+        elapsed = time.time() - self.start_time
+        logs['elapsed_time'] = format_time(elapsed)
+        n_steps = state.iteration - self.start_step
+        train_speed = elapsed / n_steps if n_steps > 0 else 0.0
+        logs['remaining_time'] = format_time((args.train_iters - state.iteration) * train_speed)
+        logs['memory(GiB)'] = round(torch.cuda.max_memory_reserved() / 1024**3, 2)
+        logs['train_speed(s/it)'] = round(train_speed, 6)
+        self.jsonl_writer.append(logs)
+        self.training_bar.write(str(logs))
diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py
index 0b9d2d20cf..81860a24d9 100644
--- a/swift/megatron/trainers/base.py
+++ b/swift/megatron/trainers/base.py
@@ -93,7 +93,6 @@ def __init__(self, args, template: Template):
         self.eval_metrics = None
         logging_path = os.path.join(args.save, 'logging.jsonl')
         logger.info(f'logging_path: {logging_path}')
-        self.jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last')
         # for evaluate
         if args.check_model and hasattr(args, 'model_dir'):
             with ms_logger_context(logging.CRITICAL), patch_modelscope_hub_timeout():
@@ -548,254 +547,6 @@ def _all_reduce_metric(self, 
torch.distributed.all_reduce(reporting_metric, reduction, group=mpu.get_data_parallel_group()) return {k: reporting_metric[i] for i, k in enumerate(metric.keys())} - # Code borrowed from NVIDIA/Megatron-LM - def _evaluate( - self, - forward_step_func, - data_iterator, - model, - process_non_loss_data_func, - config, - verbose=False, - non_loss_data_func=None, - eval_iters=None, - ): - """Evaluation.""" - args = self.args - timers = get_timers() - - timers('evaluate', log_level=0).start(barrier=True) - if args.vision_pretraining and args.vision_pretraining_type == 'dino': - from megatron.legacy.model.vision.knn_monitor import compute_feature_bank - compute_feature_bank(model) - - # Turn on evaluation mode which disables dropout. - for model_module in model: - model_module.eval() - - # Disable result validation during evaluation - rerun_state_machine = get_rerun_state_machine() - rerun_mode = rerun_state_machine.get_mode() - rerun_state_machine.set_mode(RerunMode.DISABLED) - - total_loss_dict = {} - - # make validation batch size independent from training batch size - eval_batch_size = args.global_batch_size - eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size) - forward_backward_func = get_forward_backward_func() - if args.enable_cuda_graph and args.cuda_graph_scope == 'full_iteration': - from megatron.core.full_cuda_graph import FullCudaGraphWrapper - forward_backward_func = FullCudaGraphWrapper( - forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) - - if eval_iters is None: - eval_iters = args.eval_iters - - with torch.no_grad(), tqdm( - total=eval_iters, dynamic_ncols=True, disable=not is_last_rank(), desc='Evaluate: ') as prog_bar: - iteration = 0 - if verbose: - print_rank_0(f'Evaluating on {eval_iters * eval_batch_size} samples') - while iteration < eval_iters: - iteration += 1 - prog_bar.update() - if verbose: - print_rank_0(f'Evaluating iter {iteration}/{eval_iters}') - - # Don't care about timing during evaluation - config.timers = None - ft_integration.on_eval_step_start() - new_data_iterator = self._replace_data_iterator(data_iterator, model) - loss_dicts = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=new_data_iterator, - model=model, - num_microbatches=eval_num_microbatches, - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - forward_only=True, - ) - ft_integration.on_eval_step_end() - config.timers = get_timers() - - # Empty unused memory - if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() - - if mpu.is_pipeline_last_stage(ignore_virtual=True): - if self.mcore_013: - for key in loss_dicts[0].keys(): - if key not in total_loss_dict: - total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda() - val = [x[key].view(-1) for x in loss_dicts] - if val[0].numel() == 2: - val = torch.vstack(val).sum(dim=0) - torch.distributed.all_reduce( - val, group=mpu.get_data_parallel_group(with_context_parallel=True)) - total_loss_dict[key] += val - elif val[0].numel() == 1: - val = torch.cat(val).sum() - total_loss_dict[key][0] += val - total_loss_dict[key][1] += len(loss_dicts) - else: - raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}') - else: - # Reduce across processes. 
- for loss_dict in loss_dicts: - for key in loss_dict: - if key not in total_loss_dict: - total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda() - val = loss_dict[key] - if isinstance(val, tuple) or isinstance(val, list): - total_loss_dict[key][0] += val[0] - total_loss_dict[key][1] += val[1] - else: - total_loss_dict[key][0] += val - total_loss_dict[key][1] += 1 - # args.consumed_valid_samples += eval_batch_size - - if args.exit_duration_in_mins: - train_time = (time.time() - training._TRAIN_START_TIME) / 60.0 - done_cuda = torch.tensor([train_time > args.exit_duration_in_mins], dtype=torch.int, device='cuda') - torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX) - done = done_cuda.item() - if done: - rerun_state_machine.set_mode(rerun_mode) - print_rank_0('Exiting during evaluation, timelimit reached') - return None, None, True - - collected_non_loss_data = None - if non_loss_data_func is not None: - collected_non_loss_data = non_loss_data_func(model) - elif process_non_loss_data_func is not None and is_last_rank(): - collected_non_loss_data = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iterator, - model=model, - num_microbatches=get_num_microbatches(), - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - forward_only=True, - collect_non_loss_data=True, - ) - - # Move model back to the train mode. - for model_module in model: - model_module.train() - - for key in total_loss_dict: - numerator, denominator = total_loss_dict[key] - total_loss_dict[key] = numerator / denominator - if self.eval_metrics is not None: - metric = self.eval_metrics.compute() - for k, v in metric.items(): - total_loss_dict[k] = v if isinstance(v, torch.Tensor) else torch.tensor(v) - self.eval_metrics.reset() - timers('evaluate').stop() - timers.log(['evaluate']) - self.custom_log(total_loss_dict, 'eval') - rerun_state_machine.set_mode(rerun_mode) - if is_last_rank(): - logs = {} - for key, val in total_loss_dict.items(): - logs[f'eval_{key}'] = round(val.item(), 8) - self.jsonl_writer.append(logs) - return total_loss_dict, collected_non_loss_data, False - - def _evaluate_and_print_results( - self, - prefix, - forward_step_func, - data_iterator, - model, - iteration, - process_non_loss_data_func, - config, - verbose=False, - write_to_tensorboard=True, - non_loss_data_func=None, - ): - """Helper function to evaluate and dump results on screen.""" - - args = self.args - if write_to_tensorboard: - writer = get_tensorboard_writer() - else: - writer = None - - wandb_writer = get_wandb_writer() - - data_iterators = data_iterator if args.multiple_validation_sets else [data_iterator] - - if not args.multiple_validation_sets: - eval_iters = [args.eval_iters] - else: - eval_iters = args.eval_iters - - if args.full_validation: - assert len(eval_iters) == len(data_iterators) - - # with full validation we need to distribute eval_iters to all ranks - if mpu.get_tensor_model_parallel_rank() == 0: - eval_iters = torch.tensor(args.eval_iters, dtype=torch.long, device='cuda') - else: - eval_iters = torch.tensor([0] * len(eval_iters), dtype=torch.long, device='cuda') - torch.distributed.broadcast(eval_iters, 0) - eval_iters = eval_iters.tolist() - args.eval_iters = eval_iters[0] if not args.multiple_validation_sets else eval_iters - elif not args.multiple_validation_sets: - eval_iters = [args.eval_iters] - else: - eval_iters = args.eval_iters - - for index, (iterator, iterations) in enumerate(zip(data_iterators, eval_iters)): - suffix 
= '' - if args.multiple_validation_sets: - suffix = f'-{index}' - total_loss_dict, collected_non_loss_data, timelimit = self.evaluate( - forward_step_func, - iterator, - model, - process_non_loss_data_func, - config, - verbose, - non_loss_data_func, - eval_iters=iterations, - ) - # Timelimit hit during evaluation - if timelimit: - return - string = f' validation{suffix} loss at {prefix} | ' - for key in total_loss_dict: - string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) - ppl = None - if key == 'loss': - ppl = math.exp(min(20, total_loss_dict[key].item())) - string += '{} PPL: {:.6E} | '.format(key, ppl) - if writer: - writer.add_scalar('{} validation{}'.format(key, suffix), total_loss_dict[key].item(), iteration) - writer.add_scalar( - '{} validation{} vs samples'.format(key, suffix), - total_loss_dict[key].item(), - args.consumed_train_samples, - ) - if args.log_validation_ppl_to_tensorboard and ppl is not None: - writer.add_scalar('{} validation{} ppl'.format(key, suffix), ppl, iteration) - writer.add_scalar('{} validation{} ppl vs samples'.format(key, suffix), ppl, - args.consumed_train_samples) - if wandb_writer and is_last_rank(): - wandb_writer.log({'{} validation{}'.format(key, suffix): total_loss_dict[key].item()}, - iteration) - - if process_non_loss_data_func is not None and writer and is_last_rank(): - process_non_loss_data_func(collected_non_loss_data, iteration, writer) - - length = len(string) + 1 - print_rank_last('-' * length) - print_rank_last(string) - print_rank_last('-' * length) - def _get_metrics(self, total_loss_dict, mode): advanced_iters = total_loss_dict['advanced iterations'] if mode == 'train' else 1 return { @@ -821,262 +572,6 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval'], iteration= if wandb_writer: wandb_writer.log(metrics, iteration) - # Code borrowed from NVIDIA/Megatron-LM - def _training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, - report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad): - """Log training information such as losses, timing, ....""" - args = self.args - timers = get_timers() - writer = get_tensorboard_writer() - wandb_writer = get_wandb_writer() - - # Advanced, skipped, and Nan iterations. - advanced_iters_key = 'advanced iterations' - skipped_iters_key = 'skipped iterations' - nan_iters_key = 'nan iterations' - # Advanced iterations. - if not skipped_iter: - total_loss_dict[advanced_iters_key] = total_loss_dict.get(advanced_iters_key, 0) + 1 - else: - if advanced_iters_key not in total_loss_dict: - total_loss_dict[advanced_iters_key] = 0 - # Skipped iterations. - total_loss_dict[skipped_iters_key] = total_loss_dict.get(skipped_iters_key, 0) + skipped_iter - # Update losses and set nan iterations - got_nan = False - for key in loss_dict: - if not skipped_iter: - total_loss_dict[key] = total_loss_dict.get(key, torch.tensor([0.0], dtype=torch.float, - device='cuda')) + loss_dict[key] - else: - value = loss_dict[key].float().sum().item() - is_nan = value == float('inf') or value == -float('inf') or value != value - got_nan = got_nan or is_nan - total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int(got_nan) - - # Logging. 
- timers_to_log = [ - 'forward-backward', 'forward-compute', 'backward-compute', 'batch-generator', 'forward-recv', - 'forward-send', 'backward-recv', 'backward-send', 'forward-send-forward-recv', 'forward-send-backward-recv', - 'backward-send-forward-recv', 'backward-send-backward-recv', 'forward-backward-send-forward-backward-recv', - 'layernorm-grads-all-reduce', 'embedding-grads-all-reduce', 'all-grads-sync', 'params-all-gather', - 'optimizer-copy-to-main-grad', 'optimizer-unscale-and-check-inf', 'optimizer-clip-main-grad', - 'optimizer-count-zeros', 'optimizer-inner-step', 'optimizer-copy-main-to-model-params', 'optimizer' - ] - - # Calculate batch size. - batch_size = args.micro_batch_size * args.data_parallel_size * get_num_microbatches() - - # Track app tag & app tag ID - one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length) - - total_iterations = total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key] - - # learning rate will be None on ranks without trainable params, so we must gather across mp ranks - learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate) - # Tensorboard values. - # Timer requires all the ranks to call. - if args.log_timers_to_tensorboard and (iteration % args.tensorboard_log_interval == 0): - timers.write(timers_to_log, writer, iteration, normalizer=total_iterations) - if writer and (iteration % args.tensorboard_log_interval == 0): - if wandb_writer: - wandb_writer.log({'samples vs steps': args.consumed_train_samples}, iteration) - writer.add_scalar('learning-rate', learning_rate, iteration) - writer.add_scalar('learning-rate vs samples', learning_rate, args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'learning-rate': learning_rate}, iteration) - if args.decoupled_lr is not None: - writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration) - if args.skipped_train_samples > 0: - writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration) - if wandb_writer: - wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration) - writer.add_scalar('batch-size', batch_size, iteration) - writer.add_scalar('batch-size vs samples', batch_size, args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'batch-size': batch_size}, iteration) - log_loss_dict = loss_dict.copy() - self._remove_log(log_loss_dict) - for key in log_loss_dict: - writer.add_scalar(key, loss_dict[key], iteration) - writer.add_scalar(key + ' vs samples', loss_dict[key], args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({key: loss_dict[key]}, iteration) - if args.log_loss_scale_to_tensorboard: - writer.add_scalar('loss-scale', loss_scale, iteration) - writer.add_scalar('loss-scale vs samples', loss_scale, args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'loss-scale': loss_scale}, iteration) - if args.log_world_size_to_tensorboard: - writer.add_scalar('world-size', args.world_size, iteration) - writer.add_scalar('world-size vs samples', args.world_size, args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'world-size': args.world_size}, iteration) - if grad_norm is not None: - writer.add_scalar('grad-norm', grad_norm, iteration) - writer.add_scalar('grad-norm vs samples', grad_norm, args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'grad-norm': grad_norm}, iteration) - if num_zeros_in_grad is not None: - writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) - 
writer.add_scalar('num-zeros vs samples', num_zeros_in_grad, args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration) - if params_norm is not None: - writer.add_scalar('params-norm', params_norm, iteration) - writer.add_scalar('params-norm vs samples', params_norm, args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'params-norm': params_norm}, iteration) - if args.log_memory_to_tensorboard: - mem_stats = torch.cuda.memory_stats() - writer.add_scalar( - 'mem-reserved-bytes', - mem_stats['reserved_bytes.all.current'], - iteration, - ) - writer.add_scalar( - 'mem-allocated-bytes', - mem_stats['allocated_bytes.all.current'], - iteration, - ) - writer.add_scalar( - 'mem-max-allocated-bytes', - mem_stats['allocated_bytes.all.peak'], - iteration, - ) - writer.add_scalar( - 'mem-allocated-count', - mem_stats['allocation.all.current'], - iteration, - ) - if args.num_experts is not None: - moe_loss_scale = 1 / get_num_microbatches() - track_names = [] - if args.moe_router_load_balancing_type in ['aux_loss', 'seq_aux_loss']: - track_names.append('load_balancing_loss') - if args.moe_z_loss_coeff is not None: - track_names.append('z_loss') - track_moe_kwargs = {'mtp_num_layers': args.mtp_num_layers} if self.mcore_013 else {} - track_moe_metrics( - loss_scale=moe_loss_scale, - iteration=iteration, - writer=writer, - wandb_writer=wandb_writer, - total_loss_dict=total_loss_dict, - per_layer_logging=args.moe_per_layer_logging, - force_initialize=True, - track_names=track_names, - num_layers=args.num_layers, - moe_layer_freq=args.moe_layer_freq, - **track_moe_kwargs) - if args.mtp_num_layers is not None: - mtp_loss_scale = 1 / get_num_microbatches() - MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict) - if iteration % args.log_interval == 0 or iteration == 1: - self.custom_log(total_loss_dict, 'train') - origin_total_loss_dict = total_loss_dict.copy() - - if args.record_memory_history and is_last_rank(): - snapshot = torch.cuda.memory._snapshot() - from pickle import dump - with open(args.memory_snapshot_path, 'wb') as f: - dump(snapshot, f) - - elapsed_time = timers('interval-time').elapsed(barrier=True) - elapsed_time_per_iteration = elapsed_time / total_iterations - train_percentage = iteration / args.train_iters - total_elapsed_time = timers('interval-time').active_time() - memory_GiB = round(torch.cuda.max_memory_reserved() / 1024**3, 2) - remaining_time = total_elapsed_time / train_percentage - total_elapsed_time - total_elapsed_time = format_time(total_elapsed_time) - remaining_time = format_time(remaining_time) - - throughput = num_floating_point_operations(args, batch_size) / ( - elapsed_time_per_iteration * 10**12 * args.world_size) - - one_logger_utils.track_e2e_metrics(args.log_throughput, throughput) - - if args.log_timers_to_tensorboard: - if writer: - writer.add_scalar('iteration-time', elapsed_time_per_iteration, iteration) - if wandb_writer: - wandb_writer.log({'iteration-time': elapsed_time_per_iteration}, iteration) - log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]" - log_string += ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters) - log_string += ' consumed samples: {:12d} |'.format(args.consumed_train_samples) - if args.skipped_train_samples > 0: - log_string += ' skipped samples: {:12d} |'.format(args.skipped_train_samples) - log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time_per_iteration * 1000.0) - 
log_string += (f' memory(GiB): {memory_GiB} |' - f' elapsed time: {total_elapsed_time} | remaining time: {remaining_time} |') - if args.log_throughput: - log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |' - if args.log_timers_to_tensorboard: - if writer: - writer.add_scalar('throughput', throughput, iteration) - if wandb_writer: - wandb_writer.log({'throughput': throughput}, iteration) - # Decoupled_learning_rate should be not None only on first and last pipeline stage. - log_string += f' learning rate: {learning_rate:.6E} |' - if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) - or mpu.is_pipeline_last_stage(ignore_virtual=True)): - assert decoupled_learning_rate is not None - log_string += f' decoupled learning rate: {decoupled_learning_rate:.6E} |' - else: - assert decoupled_learning_rate is None - log_string += f' global batch size: {batch_size:5d} |' - for key in total_loss_dict: - if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]: - avg = total_loss_dict[key].item() / float(max(1, total_loss_dict[advanced_iters_key])) - log_string += ' {}: {:.6E} |'.format(key, avg) - total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda') - log_string += f' loss scale: {loss_scale:.1f} |' - if grad_norm is not None: - log_string += f' grad norm: {grad_norm:.3f} |' - if num_zeros_in_grad is not None: - log_string += f' num zeros: {num_zeros_in_grad} |' - if params_norm is not None: - log_string += f' params norm: {params_norm:.3f} |' - log_string += ' number of skipped iterations: {:3d} |'.format(total_loss_dict[skipped_iters_key]) - log_string += ' number of nan iterations: {:3d} |'.format(total_loss_dict[nan_iters_key]) - total_loss_dict[advanced_iters_key] = 0 - total_loss_dict[skipped_iters_key] = 0 - total_loss_dict[nan_iters_key] = 0 - print_rank_last(log_string) - if report_memory_flag: - # Report memory after optimizer state has been initialized. 
- if torch.distributed.get_rank() == 0: - num_microbatches = get_num_microbatches() - report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True) - report_memory(f'(after {iteration} iterations)') - report_memory_flag = False - timers.log(timers_to_log, normalizer=args.log_interval) - - if is_last_rank(): - logs = {} - for key in origin_total_loss_dict: - if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]: - avg = origin_total_loss_dict[key].item() / float( - max(1, origin_total_loss_dict[advanced_iters_key])) - logs[key] = round(avg, 8) - if grad_norm is not None: - logs['grad_norm'] = round(grad_norm, 8) - if params_norm is not None: - logs['params_norm'] = round(params_norm, 8) - logs['learning_rate'] = round(learning_rate, 8) - logs['elapsed_time_per_iteration'] = round(elapsed_time_per_iteration, 8) - logs['memory(GiB)'] = memory_GiB - logs['elapsed_time'] = total_elapsed_time - logs['remaining_time'] = remaining_time - if args.log_throughput: - logs['throughput'] = round(throughput, 8) - logs['loss_scale'] = round(loss_scale, 8) - logs['consumed_samples'] = args.consumed_train_samples - logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}' - self.jsonl_writer.append(logs) - - return report_memory_flag - def merge_lora_adapters(self, adapter_name='default'): """Merge LoRA adapters into base model weights for vLLM inference.""" with torch.no_grad(): @@ -1171,7 +666,9 @@ def train(self, train_dataset, val_dataset): self.call_event('on_step_end') if mpu.is_pipeline_last_stage(ignore_virtual=True): self.aggregated_metrics(metrics, train_metrics) - train_metrics['grad_norm'] = torch.tensor([grad_norm], dtype=torch.float32, device=torch.cuda.current_device()) + train_metrics['grad_norm'] = torch.tensor([grad_norm], + dtype=torch.float32, + device=torch.cuda.current_device()) if state.should_log: self.call_event('on_log', train_metrics) diff --git a/swift/trainers/patcher.py b/swift/trainers/patcher.py index 349b71c1ca..59c5482ae9 100644 --- a/swift/trainers/patcher.py +++ b/swift/trainers/patcher.py @@ -27,14 +27,11 @@ def get_max_reserved_memory() -> float: def add_train_message(logs, state, start_time, start_step) -> None: logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}' - train_percentage = state.global_step / state.max_steps if state.max_steps else 0. - logs['percentage'] = f'{train_percentage * 100:.2f}%' elapsed = time.time() - start_time logs['elapsed_time'] = format_time(elapsed) n_steps = state.global_step - start_step train_speed = elapsed / n_steps if n_steps > 0 else 0.0 - if train_percentage != 0: - logs['remaining_time'] = format_time((state.max_steps - state.global_step) * train_speed) + logs['remaining_time'] = format_time((state.max_steps - state.global_step) * train_speed) for k, v in logs.items(): if isinstance(v, float): logs[k] = round(logs[k], 8)
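
The patch series above replaces Megatron's built-in training_log/evaluate plumbing with a small callback system: DefaultFlowCallback advances TrainerState (iteration, consumed_train_samples) and raises the should_log/should_eval/should_save flags, PrintCallback renders the progress bars and the logging.jsonl records, and the trainer only dispatches events through call_event. The standalone sketch below illustrates that control flow; it is not the ms-swift implementation — the Args namespace, the toy trainer, and the print statements are illustrative stand-ins, and only the event names and TrainerState fields follow the diffs above.

# Standalone sketch of the callback-driven training flow (illustrative, not ms-swift code).
from dataclasses import dataclass


@dataclass
class TrainerState:
    iteration: int = 0
    consumed_train_samples: int = 0
    epoch: int = 0
    should_log: bool = False
    should_eval: bool = False
    should_save: bool = False


@dataclass
class Args:
    # Hypothetical stand-in for the Megatron argument namespace.
    train_iters: int = 10
    global_batch_size: int = 8
    log_interval: int = 2
    eval_interval: int = 5
    save_interval: int = 10


class MegatronCallback:
    # Callbacks only override the events they care about; defaults are no-ops.
    def __init__(self, trainer):
        self.trainer = trainer
        self.args = trainer.args
        self.state = trainer.state

    def on_train_begin(self): pass
    def on_step_begin(self): pass
    def on_step_end(self): pass
    def on_log(self, logs): pass
    def on_train_end(self): pass


class DefaultFlowCallback(MegatronCallback):
    # Advances the state and raises the should_* flags; the trainer only reacts to flags.
    def on_step_end(self):
        args, state = self.args, self.state
        state.iteration += 1
        state.consumed_train_samples += args.global_batch_size
        if state.iteration == 1 or state.iteration % args.log_interval == 0:
            state.should_log = True
        if args.eval_interval and state.iteration % args.eval_interval == 0:
            state.should_eval = True
        if args.save_interval and state.iteration % args.save_interval == 0:
            state.should_save = True


class PrintCallback(MegatronCallback):
    def on_log(self, logs):
        print(logs)


class ToyTrainer:
    # Mirrors the shape of the train() loop in swift/megatron/trainers/base.py.
    def __init__(self, args):
        self.args = args
        self.state = TrainerState()
        self.callbacks = [DefaultFlowCallback(self), PrintCallback(self)]

    def call_event(self, event, *event_args, **kwargs):
        for callback in self.callbacks:
            getattr(callback, event)(*event_args, **kwargs)

    def train(self):
        state = self.state
        self.call_event('on_train_begin')
        while state.iteration < self.args.train_iters:
            self.call_event('on_step_begin')
            # forward/backward and the optimizer step would run here
            self.call_event('on_step_end')
            if state.should_log:
                state.should_log = False
                self.call_event('on_log', {'iteration': state.iteration,
                                           'consumed_samples': state.consumed_train_samples})
            if state.should_eval:
                state.should_eval = False
                print(f'evaluate at iteration {state.iteration}')
            if state.should_save:
                state.should_save = False
                print(f'save checkpoint at iteration {state.iteration}')
        self.call_event('on_train_end')


if __name__ == '__main__':
    ToyTrainer(Args()).train()

Keeping the interval bookkeeping in a flow callback and exposing only should_* flags to the loop is the same split that transformers makes between TrainerState/TrainerControl and its DefaultFlowCallback, which is presumably the model here.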