From d9511a53534d8ae44b76c2d076b2c0838f934481 Mon Sep 17 00:00:00 2001 From: wanggzf Date: Thu, 13 Nov 2025 17:04:37 +0800 Subject: [PATCH 1/5] add_domain_losses --- scripts/train/examples/run_train_8B_z0_b1.sh | 21 +- src/dataset/omics_dataset.py | 58 ++ src/model/omics_one.py | 6 +- src/train.py | 47 +- src/trainer/__init__.py | 5 + src/trainer/domain_loss.py | 886 +++++++++++++++++++ src/trainer/omics_trainer.py | 6 + 7 files changed, 1014 insertions(+), 15 deletions(-) create mode 100644 src/trainer/domain_loss.py diff --git a/scripts/train/examples/run_train_8B_z0_b1.sh b/scripts/train/examples/run_train_8B_z0_b1.sh index 137e7ae..48fe26e 100755 --- a/scripts/train/examples/run_train_8B_z0_b1.sh +++ b/scripts/train/examples/run_train_8B_z0_b1.sh @@ -1,6 +1,6 @@ enable_list="multimodal model.model.embed_tokens model.model.layers model.lm_head" experiment_name="Qwen3_8B_Omics_sft_1014_all_task_test" -output_path="${experiment_name}" +output_path="/mnt/shared-storage-user/ai4agr-share/wangzhefan/molly_checkpoint/${experiment_name}" export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 @@ -12,22 +12,22 @@ export MKL_NUM_THREADS=4 options="--experiment-name $experiment_name \ --output_dir $output_path \ --text-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/Qwen3-8B \ ---dna-rna-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/nucleotide-transformer/ \ +--dna-rna-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/nucleotide-transformer-v2-500m-multi-species/ \ --dna-rna-k-tokens 1024 \ --protein-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/esm2_t33_650M_UR50D/ \ --protein-k-tokens 1024 \ --device cuda \ --train-mlp \ --train-llm \ ---train-dataset-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/data/BioMLLM/train-val-test/train_all_task_standard.parquet \ ---eval-dataset-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/data/BioMLLM/train-val-test/dev_all_task_standard.parquet \ 
+--train-dataset-path /mnt/shared-storage-user/wangzhefan/multimodel/data/molly/train_all_task_merged_labelled.parquet \ +--eval-dataset-path /mnt/shared-storage-user/wangzhefan/multimodel/data/molly/dev_all_task_merged_labelled.parquet \ --max-len 3072 \ --max-src-len 3072 \ --eval-max-len 3072 \ --eval-max-src-len 3072 \ --mode sft \ ---per_device_train_batch_size 1 \ ---per_device_eval_batch_size 1 \ +--per_device_train_batch_size 4 \ +--per_device_eval_batch_size 4 \ --read-nums 10240000000 \ --eval-read-nums 10240000000 \ --num_train_epochs 5 \ @@ -39,20 +39,21 @@ options="--experiment-name $experiment_name \ --eval_steps 80000 \ --eval_strategy steps \ --logging_strategy steps \ ---logging_steps 20 \ +--logging_steps 1 \ --save_trainable False \ --save-total-limit 500 \ --warmup_ratio 0.1 \ --early-stopping-patience 1000000000 \ ---gradient-accumulation-steps 2 \ +--gradient-accumulation-steps 1 \ --save_only_model \ ---attn_impl flash_attention_3 \ +--attn_impl flash_attention_2 \ --use_liger True \ --swanlab \ --swanlab-mode local \ --swanlab-team BioMLLM_report \ --swanlab-project BioMLLM \ --seed 42 \ +--compute-domain-losses True \ " # --load_best_model_at_end \ # --save_safetensors \ @@ -61,7 +62,7 @@ options="--experiment-name $experiment_name \ # --load-pretrained \ deepspeed \ ---include localhost:0,1,2,3 \ +--include localhost:0,1 \ src/train.py \ --deepspeed_config src/configs/ds_z0_config.json \ $options diff --git a/src/dataset/omics_dataset.py b/src/dataset/omics_dataset.py index c45110c..d039098 100644 --- a/src/dataset/omics_dataset.py +++ b/src/dataset/omics_dataset.py @@ -169,6 +169,50 @@ def _pretokenize_special_tokens(self): r"\s*([ACDEFGHIKLMNPQRSTVWYBXZOU]+)\s*"), } + def convert_source_to_id(self, source:str): + if 'antibody_antigen' in source: + return 0 + elif 'cpd-prom_core' in source: + return 1 + elif 'CRISPROnTarget' in source: + return 2 + elif 'emp-H' in source: + return 3 + elif 'enhancer_activity' in source: + return 4 + elif 
'Fluorescence-Fluorescence' in source: + return 5 + elif 'FunctionEC-FunctionEC' in source: + return 6 + elif 'Isoform-Isoform' in source: + return 7 + elif 'MeanRibosomeLoading-MeanRibosomeLoading' in source: + return 8 + elif 'Modification-Modification' in source: + return 9 + elif 'NoncodingRNAFamily-NoncodingRNAFamily' in source: + return 10 + elif 'pd-prom_300' in source: + return 11 + elif 'ProgrammableRNASwitches-ProgrammableRNASwitches' in source: + return 12 + elif 'promoter_enhancer_interaction' in source: + return 13 + elif 'rna_protein_interaction' in source: + return 14 + elif 'Solubility-Solubility' in source: + return 15 + elif 'Stability-Stability' in source: + return 16 + elif 'Thermostability-Thermostability' in source: + return 17 + elif 'tf-h' in source: + return 18 + elif 'tf-m' in source: + return 19 + else: + return 100 + def format_raw(self, sample: pd.core.series.Series, tokenizer) -> dict: """ Format a Parquet example into DNA-LLM format suitable for processing. 
@@ -282,6 +326,8 @@ def format_raw(self, sample: pd.core.series.Series, tokenizer) -> dict: "omic_info_list": omic_info_list, "task": sample.get("task", ""), "label": sample.get("label", ""), + "xsource": self.convert_source_to_id(sample.get("task")), + "task_num": sample.get("task_num"), } # pylint: disable=too-many-branches @@ -366,6 +412,8 @@ def process_sample(self, sample: Dict[str, "labels": torch.LongTensor(labels), "attention_mask": torch.LongTensor(attention_mask), "cal_metric_pos": cal_metric_pos, + "xsource": torch.tensor(sample.get("xsource")), + "task_num": torch.tensor(sample.get("task_num")) } def _encode_sequence(self, seq: str, seq_type: str) -> torch.LongTensor: @@ -416,6 +464,8 @@ def qwen_omics_collate_fn(batch): cal_metric_pos = [sample.get("cal_metric_pos") for sample in batch] omic_info_lists = [sample.get("omic_info_list", []) for sample in batch] omic_ids = [sample.get("omic_ids", None) for sample in batch] + xsource = [sample.get("xsource") for sample in batch] + task_num = [sample.get("task_num") for sample in batch] input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, @@ -428,6 +478,12 @@ def qwen_omics_collate_fn(batch): padding_value=0) omic_ids = (torch.nn.utils.rnn.pad_sequence( omic_ids, batch_first=True, padding_value=1) if omic_ids else None) + # xsource = torch.tensor(xsource, dtype=torch.long) + # task_num = torch.tensor(task_num, dtype=torch.long) + # 从 [tensor(0), tensor(1), tensor(2)] -> tensor([0, 1, 2]) + xsource = torch.stack(xsource) + task_num = torch.stack(task_num) + # Pad omic_info_lists to the same length as omic_ids for i, _ in enumerate(omic_info_lists): @@ -444,6 +500,8 @@ def qwen_omics_collate_fn(batch): "omic_ids": omic_ids, "omic_info_list": omic_info_lists, "cal_metric_pos": cal_metric_pos, + "xsource": xsource, + "task_num": task_num, } diff --git a/src/model/omics_one.py b/src/model/omics_one.py index 90b14e2..1ce5532 100644 --- a/src/model/omics_one.py +++ b/src/model/omics_one.py @@ 
-5,9 +5,8 @@ import torch.nn as nn import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, PreTrainedModel, PretrainedConfig -from transformers.modeling_outputs import CausalLMOutputWithPast from utils.tools import time_count - +from trainer import CausalLMOutputWithPast class OmicsOne(nn.Module): def __init__(self, config): @@ -148,6 +147,9 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + # 添加参数,即使forward不使用,否则会去除该字段 + xsource: Optional[torch.LongTensor] = None, + task_num: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple[torch.Tensor, ...], CausalLMOutputWithPast]: return_dict = (return_dict if return_dict is not None else diff --git a/src/train.py b/src/train.py index 4aa89b7..d673818 100644 --- a/src/train.py +++ b/src/train.py @@ -15,7 +15,8 @@ set_seed, TrainingArguments ) - +from transformers import Qwen3ForCausalLM +import types # pylint: disable=no-name-in-module # pylint: disable=too-many-branches @@ -24,7 +25,7 @@ from torch.utils.data.distributed import DistributedSampler from model import OmicsOne, get_omics_one_config -from trainer import OmicsTrainer +from trainer import OmicsTrainer, CausalLMOutputWithPast, my_inner_training_loop, my_maybe_log_save_evaluate, my_training_step, my_lce_forward from utils import ( init_swanlab_rank_0, pre_train_lora, @@ -34,6 +35,36 @@ get_current_device, ) +import transformers +import torch + +def check_versions(): + """检查关键库的版本""" + required_versions = { + 'transformers': '4.53.0', + 'torch': '2.7.0+cu128' + } + + current_versions = { + 'transformers': transformers.__version__, + 'torch': torch.__version__ + } + + # ANSI颜色代码 + YELLOW = '\033[93m' + RESET = '\033[0m' + + print("=== 版本检查 ===") + for lib, current in current_versions.items(): + required = required_versions.get(lib, 'unknown') + # 用黄色显示required版本 + print(f"{lib}: {current} (required: 
{YELLOW}{required}{RESET})") + + # 可以添加警告但不阻止执行 + if transformers.__version__ != required_versions['transformers']: + print(f"{YELLOW}⚠️ 警告: transformers 版本 {transformers.__version__} 可能与补丁不兼容{RESET}") + + def setup_tokenizers(args): """ Setup tokenizers for both text and DNA models. @@ -61,7 +92,6 @@ def setup_tokenizers(args): return tokenizer, dna_rna_tokenizer, protein_tokenizer - def setup_model_and_optimizer(args, tokenizer): print_rank_0("-------------------init model-------------------------") @@ -469,6 +499,10 @@ def main(): type=bool, default=True, help="If train llm") + parser.add_argument("--compute-domain-losses", + type=bool, + default=False, + help="compute domain losses during training") # Optimizer configuration parser.add_argument("--learning_rate", @@ -558,6 +592,8 @@ def main(): # Add DeepSpeed arguments parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() + if args.compute_domain_losses: + check_versions() # Setup random seed number set_seed(args.seed) @@ -646,6 +682,11 @@ def main(): tokenizer=tokenizer, data_collator=qwen_omics_collate_fn, ) + if args.compute_domain_losses: + trainer.training_step = types.MethodType(my_training_step, trainer) + trainer._maybe_log_save_evaluate = types.MethodType(my_maybe_log_save_evaluate, trainer) + trainer._inner_training_loop = types.MethodType(my_inner_training_loop, trainer) + Qwen3ForCausalLM.forward = my_lce_forward # Start training trainer.train() except Exception as e: diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py index 1d8aa71..be2172b 100644 --- a/src/trainer/__init__.py +++ b/src/trainer/__init__.py @@ -1,5 +1,10 @@ from .omics_trainer import OmicsTrainer +from .domain_loss import CausalLMOutputWithPast, my_inner_training_loop, my_maybe_log_save_evaluate, my_training_step, my_lce_forward __all__ = [ "OmicsTrainer", + "CausalLMOutputWithPast", + "my_inner_training_loop", + "my_maybe_log_save_evaluate", + "my_training_step", "my_lce_forward" ] diff --git 
a/src/trainer/domain_loss.py b/src/trainer/domain_loss.py new file mode 100644 index 0000000..b701a6f --- /dev/null +++ b/src/trainer/domain_loss.py @@ -0,0 +1,886 @@ +import functools, os, shutil +from transformers import Trainer +from transformers.trainer_utils import TrainOutput, speed_metrics +from dataclasses import dataclass +from transformers.utils import ModelOutput +from typing import Optional, List, Dict, Union, Any +import torch +from transformers.cache_utils import Cache +from torch import nn +from transformers.utils import logging, is_sagemaker_mp_enabled, is_torch_xla_available +from transformers.debug_utils import DebugOption +from transformers.integrations.deepspeed import deepspeed_init +from transformers.trainer_callback import TrainerState, ExportableState +from transformers.trainer_pt_utils import get_model_param_count +import time +import contextlib +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from transformers.training_args import OptimizerNames +from transformers.utils import is_accelerate_available +from accelerate.utils import DistributedType +if is_accelerate_available(): + from accelerate.utils import skip_first_batches + + + +logger = logging.get_logger(__name__) +@dataclass +class CausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + loss: Optional[torch.FloatTensor] = None + domain_losses: Optional[List[torch.FloatTensor]] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +def my_inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self.accelerator.free_memory() + self._train_batch_size = batch_size + if self.args.auto_find_batch_size: + if self.state.train_batch_size != self._train_batch_size: + from accelerate.utils import release_memory + + (self.model_wrapped,) = release_memory(self.model_wrapped) + self.model_wrapped = self.model + + # Check for DeepSpeed *after* the initial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) + self.propagate_args_to_deepspeed(True) + self.args.per_device_train_batch_size = original_bs + self.state.train_batch_size = self._train_batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + if self.is_fsdp_xla_v2_enabled: + train_dataloader = tpu_spmd_dataloader(train_dataloader) + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = self.get_total_train_batch_size(args) + + ( + num_train_epochs, + num_update_steps_per_epoch, + num_examples, + num_train_samples, + epoch_based, + len_dataloader, + max_steps, + ) = self.set_initial_training_values(args, train_dataloader, 
total_train_batch_size) + + num_train_tokens = None + if self.args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps) + # If going by epochs, multiply tokens linearly + if len_dataloader is not None and epoch_based: + num_train_tokens *= args.num_train_epochs + # Otherwise since its steps, we just multiply by grad accum + else: + num_train_tokens *= args.gradient_accumulation_steps + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torchrun or torch.distributed.launch (deprecated))." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = ( + is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled + ) + + # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 + is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2) + if is_fsdp2: + delay_optimizer_creation = False + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) + 
self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + + # Compute absolute values for logging, eval, and save if given as ratio + self.state.compute_steps(args, max_steps) + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) + + model = self._wrap_model(self.model_wrapped) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if use_accelerator_prepare and self.is_fsdp_enabled: + # In case of auto_find_batch_size=True + # Remove FSDP wrapping from sub-models. + self.model = unwrap_model(self.model, recursive=True) + + if delay_optimizer_creation: + if use_accelerator_prepare: + # configure fsdp plugin for qlora if any + self._fsdp_qlora_plugin_updates() + if self.accelerator.mixed_precision != "fp8": + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + if delay_optimizer_creation: + self.optimizer = self.accelerator.prepare(self.optimizer) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. 
+ model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # In this case we are in DDP + LOMO, which should be supported + self.optimizer = self.accelerator.prepare(self.optimizer) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model) + ) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + self._load_scaler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." 
+ ) + + # Update the references + for attr in ("model", "optimizer", "lr_scheduler"): + setattr(self.callback_handler, attr, getattr(self, attr)) + self.callback_handler.train_dataloader = train_dataloader + + self.state.init_training_references(self, max_steps, num_train_epochs, trial) + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0, device=args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + grad_norm: Optional[float] = None + learning_rate = None + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) + + for epoch in range(epochs_trained, num_train_epochs): + epoch_dataloader = train_dataloader + if hasattr(epoch_dataloader, "set_epoch"): + epoch_dataloader.set_epoch(epoch) + + # Reset the past mems state at the beginning of each epoch if necessary. 
+ if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_dataloader) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + epoch_iterator = iter(epoch_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches + remainder = steps_in_epoch % args.gradient_accumulation_steps + if remainder == 0: + remainder = args.gradient_accumulation_steps + update_step = -1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + int( + remainder < args.gradient_accumulation_steps + ) + for _ in range(total_updates): + update_step += 1 + num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device) + for i, inputs in enumerate(batch_samples): + step += 1 + do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch + # Since we perform prefetching, we need to manually set sync_gradients + self.accelerator.gradient_state._set_sync_gradients(do_sync_step) + + if self.args.include_num_input_tokens_seen: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. 
To fix this, add " + "a `main_input_name` attribute to the model class you are using." + ) + else: + input_tokens = inputs[main_input_name].numel() + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + # We explicitly want to avoid relying on `accelerator.accumulate` for generation training + context = ( + functools.partial(self.accelerator.no_sync, model=model) + if i != len(batch_samples) - 1 + and self.accelerator.distributed_type != DistributedType.DEEPSPEED + else contextlib.nullcontext + ) + with context(): + # tr_loss_step = self.training_step(model, inputs, num_items_in_batch) + tr_loss_step, domain_loss_dict = self.training_step(model, inputs, num_items_in_batch) + # tr_loss_step, domain_loss_list = self.training_step(model, inputs, num_items_in_batch) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original 
device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) + tr_loss = tr_loss + tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + if do_sync_step: + # Since we perform prefetching, we need to manually set sync_gradients to True + self.accelerator.gradient_state._set_sync_gradients(True) + + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + elif self.use_apex: + from apex import amp + + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + grad_norm_context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + grad_norm_context = implicit_replication + with grad_norm_context(): + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + if ( + is_accelerate_available() + and self.accelerator.distributed_type == DistributedType.DEEPSPEED + ): + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + + # get leaning rate before update + learning_rate = self._get_learning_rate() + + if not self.accelerator.optimizer_step_was_skipped: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + 
steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + tr_loss, + grad_norm, + model, + trial, + epoch, + ignore_keys_for_eval, + start_time, + learning_rate=learning_rate, + extra=domain_loss_dict, + # extra=domain_loss_list, + ) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + # We also need to break out of the nested loop + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + if step < 0: + logger.warning( + "There seems not to be a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + # self._maybe_log_save_evaluate( + # tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate + # ) + self._maybe_log_save_evaluate( + tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate, extra=domain_loss_dict + # tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate, extra=domain_loss_list + ) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_xla_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 
+ xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sure the model has been saved by process 0. + if is_torch_xla_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. 
+ if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + + return TrainOutput(self.state.global_step, train_loss, metrics) + +def my_maybe_log_save_evaluate( + self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None, extra: Dict={} + # self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None, extra: List=[] + ): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + if is_torch_xla_available(): + xm.mark_step() + + # logs: dict[str, float] = {} + logs: Dict[str, float] = dict(extra) # copy so we never mutate the caller's dict or the shared mutable default + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + if grad_norm is not None: + logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm + if learning_rate is not None: + logs["learning_rate"] = learning_rate + else: + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs, start_time) + + # import
pdb + # pdb.set_trace() + # for item in extra: + # logs: dict[str, float] = item + # self.log(logs, start_time) + +def my_training_step( + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + num_items_in_batch: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + model.train() + if hasattr(self.optimizer, "train") and callable(self.optimizer.train): + self.optimizer.train() + + inputs = self._prepare_inputs(inputs) + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + # loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + loss, outputs = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch, return_outputs=True) + + task_num = inputs.get('task_num') + def build_domain_loss_dict(domain_losses, source): + values = [value.item() for value in outputs.domain_losses] + ret = {} + log_all = [] + for s, v in zip(source, values): + log = {} + source_id = s.item() + keymap = { + 0: "antibody_antigen", + 1: "cpd-prom_core", + 2: "CRISPROnTarget", + 3: "emp-H", + 4: "enhancer_activity", + 5: "Fluorescence", + 6: "FunctionEC", + 7: "Isoform", + 8: "MeanRibosomeLoading", + 9: "Modification", + 10: "NoncodingRNAFamily", + 11: "pd-prom_300", + 12: "ProgrammableRNASwitches", + 13: "promoter_enhancer_interaction", + 14: 
"rna_protein_interaction", + 15: "Solubility", + 16: "Stability", + 17: "Thermostability", + 18: "tf-h", + 19: "tf-m", + 100: "Other" + } + if source_id in keymap: + suffix = keymap[source_id] + else: + suffix = 'bad' + key = f'loss_{suffix}' + # 记录所有key对应的loss + log[key] = v + log_all.append(log) + # 相同key记录第一个loss 否则跳过 + if key in ret: + pass + else: + ret[key] = v + return ret, log_all + domain_loss_dict,log_all = build_domain_loss_dict(outputs.domain_losses, inputs['xsource']) + # domain_loss_list = build_domain_loss_dict(outputs.domain_losses, inputs['xsource']) + # import pdb + # pdb.set_trace() + for log,num in zip(log_all, task_num): + for key, value in log.items(): + # logger.info(f"Step {self.state.global_step}: task_num = {num}, domain = {key}, loss = {value}") + print(f"Step {self.state.global_step}: task_num = {num}, domain = {key}, loss = {value}") + + del inputs + del outputs.domain_losses + outputs.domain_losses = None + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_musa_available(): + torch.musa.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(): + torch.mps.empty_cache() + elif is_torch_hpu_available(): + logger.warning( + "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()." 
+ ) + else: + torch.cuda.empty_cache() + + kwargs = {} + + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + from apex import amp + + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss + loss = loss / self.args.gradient_accumulation_steps + domain_loss_dict = {key: value / self.args.gradient_accumulation_steps for key, value in domain_loss_dict.items()} # fix: `domain_losses` was an unbound name here (UnboundLocalError); scale the returned per-domain dict instead + + # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled + # https://github.com/huggingface/transformers/pull/35808 + if self.accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs["scale_wrt_gas"] = False + + self.accelerator.backward(loss, **kwargs) + + # return loss.detach() + return loss.detach(), domain_loss_dict + # return loss.detach(), domain_loss_list + +def my_lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3ForCausalLM + + >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + domain_losses = [] + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + loss = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + temp_logits = self.lm_head(kept_hidden_states) + batch_size = temp_logits.size(0) + for i in range(batch_size): + sample_logits = temp_logits[i].unsqueeze(0) + sample_labels = labels[i].unsqueeze(0) + sample_loss = self.loss_function( + logits=sample_logits, + labels=sample_labels, + vocab_size=self.config.vocab_size, + **kwargs + ) + domain_losses.append(sample_loss.detach()) + else: + logits = 
self.lm_head(kept_hidden_states) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + batch_size = logits.size(0) + for i in range(batch_size): + # 提取第 i 个样本的 logits 和 labels + sample_logits = logits[i] + sample_labels = labels[i] + + # 调用 ForCausalLMLoss 计算单个样本的损失 + sample_loss = self.loss_function( + logits=sample_logits.unsqueeze(0), # 增加 batch 维度 + labels=sample_labels.unsqueeze(0), # 增加 batch 维度 + vocab_size=self.config.vocab_size, + **kwargs + ) + domain_losses.append(sample_loss.detach()) + + return CausalLMOutputWithPast( + loss=loss, + domain_losses=domain_losses, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/trainer/omics_trainer.py b/src/trainer/omics_trainer.py index e5f1502..7930082 100644 --- a/src/trainer/omics_trainer.py +++ b/src/trainer/omics_trainer.py @@ -6,6 +6,9 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from utils.tools import print_rank_0 +from utils.tools import swanlab_log_rank_0 +from torch import nn +from typing import Any, Union, Optional class OmicsTrainer(Trainer): @@ -103,3 +106,6 @@ def save_model(self, output_dir=None, _internal_call=False): print_rank_0(f"Protein projector saved to {protein_projector_path}") else: super().save_model(output_dir, _internal_call) + + + From f4038844305127b40c93333d52950133963b6bcc Mon Sep 17 00:00:00 2001 From: wanggzf Date: Fri, 14 Nov 2025 14:05:37 +0800 Subject: [PATCH 2/5] add_domain_losses --- scripts/train/examples/run_train_1B_domain.sh | 71 +++++++++++++++++++ scripts/train/examples/run_train_8B_z0_b1.sh | 23 +++--- src/dataset/omics_dataset.py | 55 ++++++++++---- src/train.py | 4 +- 4 files changed, 127 insertions(+), 26 deletions(-) create mode 100644 scripts/train/examples/run_train_1B_domain.sh diff --git a/scripts/train/examples/run_train_1B_domain.sh 
b/scripts/train/examples/run_train_1B_domain.sh new file mode 100644 index 0000000..d0c03da --- /dev/null +++ b/scripts/train/examples/run_train_1B_domain.sh @@ -0,0 +1,71 @@ +enable_list="multimodal model.model.embed_tokens model.model.layers model.lm_head" +experiment_name="Qwen3_1.7B_Omics_sft_1014_all_task_test" +output_path="/mnt/shared-storage-user/ai4agr-share/wangzhefan/molly_checkpoint/${experiment_name}" + +export OMP_NUM_THREADS=4 +export MKL_NUM_THREADS=4 + +# export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600 +# export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_TIMEOUT=3600 + +options="--experiment-name $experiment_name \ +--output_dir $output_path \ +--text-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/Qwen3-1.7B \ +--dna-rna-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/nucleotide-transformer-v2-500m-multi-species/ \ +--dna-rna-k-tokens 1024 \ +--protein-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/esm2_t33_650M_UR50D/ \ +--protein-k-tokens 1024 \ +--device cuda \ +--train-mlp \ +--train-llm \ +--train-dataset-path /mnt/shared-storage-user/wangzhefan/multimodel/data/molly/train_all_task_merged_labelled.parquet \ +--eval-dataset-path /mnt/shared-storage-user/wangzhefan/multimodel/data/molly/dev_all_task_merged_labelled.parquet \ +--max-len 3072 \ +--max-src-len 3072 \ +--eval-max-len 3072 \ +--eval-max-src-len 3072 \ +--mode sft \ +--per_device_train_batch_size 4 \ +--per_device_eval_batch_size 4 \ +--read-nums 10240000000 \ +--eval-read-nums 10240000000 \ +--num_train_epochs 5 \ +--learning_rate 3e-5 \ +--bf16 \ +--enable-list $enable_list \ +--save_strategy steps \ +--save_steps 80000 \ +--eval_steps 80000 \ +--eval_strategy steps \ +--logging_strategy steps \ +--logging_steps 1 \ +--save_trainable False \ +--save-total-limit 500 \ +--warmup_ratio 0.1 \ +--early-stopping-patience 1000000000 \ +--gradient-accumulation-steps 1 \ +--save_only_model \ +--attn_impl flash_attention_2 \ 
+--use_liger True \ +--swanlab \ +--swanlab-mode local \ +--swanlab-team BioMLLM_report \ +--swanlab-project BioMLLM \ +--seed 42 \ +--compute-domain-losses True\ +" +# --load_best_model_at_end \ +# --save_safetensors \ +# --greater_is_better \ +# --use-lora +# --load-pretrained \ + +deepspeed \ +--include localhost:0 \ +src/train.py \ +--deepspeed_config src/configs/ds_z0_config.json \ +$options + + +# py-spy dump -p 11499 --locals | head -60 diff --git a/scripts/train/examples/run_train_8B_z0_b1.sh b/scripts/train/examples/run_train_8B_z0_b1.sh index 48fe26e..898981f 100755 --- a/scripts/train/examples/run_train_8B_z0_b1.sh +++ b/scripts/train/examples/run_train_8B_z0_b1.sh @@ -1,6 +1,6 @@ enable_list="multimodal model.model.embed_tokens model.model.layers model.lm_head" experiment_name="Qwen3_8B_Omics_sft_1014_all_task_test" -output_path="/mnt/shared-storage-user/ai4agr-share/wangzhefan/molly_checkpoint/${experiment_name}" +output_path="${experiment_name}" export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 @@ -12,22 +12,22 @@ export MKL_NUM_THREADS=4 options="--experiment-name $experiment_name \ --output_dir $output_path \ --text-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/Qwen3-8B \ ---dna-rna-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/nucleotide-transformer-v2-500m-multi-species/ \ +--dna-rna-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/nucleotide-transformer/ \ --dna-rna-k-tokens 1024 \ --protein-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/esm2_t33_650M_UR50D/ \ --protein-k-tokens 1024 \ --device cuda \ --train-mlp \ --train-llm \ ---train-dataset-path /mnt/shared-storage-user/wangzhefan/multimodel/data/molly/train_all_task_merged_labelled.parquet \ ---eval-dataset-path /mnt/shared-storage-user/wangzhefan/multimodel/data/molly/dev_all_task_merged_labelled.parquet \ +--train-dataset-path 
/mnt/shared-storage-user/ai4agr-share/lijinzhe/data/BioMLLM/train-val-test/train_all_task_standard.parquet \ +--eval-dataset-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/data/BioMLLM/train-val-test/dev_all_task_standard.parquet \ --max-len 3072 \ --max-src-len 3072 \ --eval-max-len 3072 \ --eval-max-src-len 3072 \ --mode sft \ ---per_device_train_batch_size 4 \ ---per_device_eval_batch_size 4 \ +--per_device_train_batch_size 1 \ +--per_device_eval_batch_size 1 \ --read-nums 10240000000 \ --eval-read-nums 10240000000 \ --num_train_epochs 5 \ @@ -39,21 +39,20 @@ options="--experiment-name $experiment_name \ --eval_steps 80000 \ --eval_strategy steps \ --logging_strategy steps \ ---logging_steps 1 \ +--logging_steps 20 \ --save_trainable False \ --save-total-limit 500 \ --warmup_ratio 0.1 \ --early-stopping-patience 1000000000 \ ---gradient-accumulation-steps 1 \ +--gradient-accumulation-steps 2 \ --save_only_model \ ---attn_impl flash_attention_2 \ +--attn_impl flash_attention_3 \ --use_liger True \ --swanlab \ --swanlab-mode local \ --swanlab-team BioMLLM_report \ --swanlab-project BioMLLM \ --seed 42 \ ---compute-domain-losses True \ " # --load_best_model_at_end \ # --save_safetensors \ @@ -62,10 +61,10 @@ options="--experiment-name $experiment_name \ # --load-pretrained \ deepspeed \ ---include localhost:0,1 \ +--include localhost:0,1,2,3 \ src/train.py \ --deepspeed_config src/configs/ds_z0_config.json \ $options -# py-spy dump -p 11499 --locals | head -60 +# py-spy dump -p 11499 --locals | head -60 \ No newline at end of file diff --git a/src/dataset/omics_dataset.py b/src/dataset/omics_dataset.py index d039098..c0c68c6 100644 --- a/src/dataset/omics_dataset.py +++ b/src/dataset/omics_dataset.py @@ -40,6 +40,7 @@ def __init__( dna_rna_tokenizer=None, protein_tokenizer=None, read_nums=None, + compute_domain_losses = False, shuffle=False, seed=42, type=None, @@ -54,6 +55,7 @@ def __init__( dataset_config: Configuration for the dataset. 
dna_rna_tokenizer: Tokenizer for DNA/RNA sequences. read_nums: Maximum number of samples to read. + compute_domain_losses: Calculate domain losses. shuffle: Whether to shuffle the dataset. seed: Random seed for shuffling. type: Dataset type. "Train / Eval" or "Test" @@ -66,6 +68,7 @@ def __init__( self.tokenizer = tokenizer self.dna_rna_tokenizer = dna_rna_tokenizer self.protein_tokenizer = protein_tokenizer + self.compute_domain_losses = compute_domain_losses self.dataset_config = dataset_config self.shuffle = shuffle self.seed = seed @@ -318,7 +321,8 @@ def format_raw(self, sample: pd.core.series.Series, tokenizer) -> dict: "raw_input": input_text, "raw_output": output_text, } - return { + elif self.compute_domain_losses: + return { "input_ids": input_ids, "output_ids": output_ids, "reasoning_token_ids": reasoning_ids, @@ -329,6 +333,15 @@ def format_raw(self, sample: pd.core.series.Series, tokenizer) -> dict: "xsource": self.convert_source_to_id(sample.get("task")), "task_num": sample.get("task_num"), } + return { + "input_ids": input_ids, + "output_ids": output_ids, + "reasoning_token_ids": reasoning_ids, + "omic_ids_list": omic_ids_list, + "omic_info_list": omic_info_list, + "task": sample.get("task", ""), + "label": sample.get("label", ""), + } # pylint: disable=too-many-branches def process_sample(self, sample: Dict[str, @@ -405,7 +418,8 @@ def process_sample(self, sample: Dict[str, labels.extend([-100] * pad_len) attention_mask.extend([0] * pad_len) - return { + if self.compute_domain_losses: + return { "input_ids": torch.LongTensor(input_ids), "omic_ids": torch.stack(sample["omic_ids_list"]), "omic_info_list": sample["omic_info_list"], @@ -415,6 +429,14 @@ def process_sample(self, sample: Dict[str, "xsource": torch.tensor(sample.get("xsource")), "task_num": torch.tensor(sample.get("task_num")) } + return { + "input_ids": torch.LongTensor(input_ids), + "omic_ids": torch.stack(sample["omic_ids_list"]), + "omic_info_list": sample["omic_info_list"], + 
"labels": torch.LongTensor(labels), + "attention_mask": torch.LongTensor(attention_mask), + "cal_metric_pos": cal_metric_pos, + } def _encode_sequence(self, seq: str, seq_type: str) -> torch.LongTensor: """ @@ -446,7 +468,8 @@ def _encode_sequence(self, seq: str, seq_type: str) -> torch.LongTensor: return encoding["input_ids"].squeeze(0) -def qwen_omics_collate_fn(batch): +# def qwen_omics_collate_fn(batch): +def qwen_omics_collate_fn(batch, args): """ Collate function for DataLoader with multimodal DNA batches. Handles variable length DNA sequences and attention masks. @@ -457,15 +480,13 @@ def qwen_omics_collate_fn(batch): Returns: Batched tensors suitable for model input """ - input_ids = [sample["input_ids"] for sample in batch] labels = [sample["labels"] for sample in batch] attention_mask = [sample["attention_mask"] for sample in batch] cal_metric_pos = [sample.get("cal_metric_pos") for sample in batch] omic_info_lists = [sample.get("omic_info_list", []) for sample in batch] omic_ids = [sample.get("omic_ids", None) for sample in batch] - xsource = [sample.get("xsource") for sample in batch] - task_num = [sample.get("task_num") for sample in batch] + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, @@ -478,11 +499,6 @@ def qwen_omics_collate_fn(batch): padding_value=0) omic_ids = (torch.nn.utils.rnn.pad_sequence( omic_ids, batch_first=True, padding_value=1) if omic_ids else None) - # xsource = torch.tensor(xsource, dtype=torch.long) - # task_num = torch.tensor(task_num, dtype=torch.long) - # 从 [tensor(0), tensor(1), tensor(2)] -> tensor([0, 1, 2]) - xsource = torch.stack(xsource) - task_num = torch.stack(task_num) # Pad omic_info_lists to the same length as omic_ids @@ -493,6 +509,21 @@ def qwen_omics_collate_fn(batch): "start": -1 }] * (omic_ids.shape[1] - len(omic_info_lists[i]))) + if args.compute_domain_losses: + xsource = [sample.get("xsource") for sample in batch] + task_num = [sample.get("task_num") for sample in batch] + xsource 
= torch.stack(xsource) + task_num = torch.stack(task_num) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "omic_ids": omic_ids, + "omic_info_list": omic_info_lists, + "cal_metric_pos": cal_metric_pos, + "xsource": xsource, + "task_num": task_num, + } return { "input_ids": input_ids, "labels": labels, @@ -500,8 +531,6 @@ def qwen_omics_collate_fn(batch): "omic_ids": omic_ids, "omic_info_list": omic_info_lists, "cal_metric_pos": cal_metric_pos, - "xsource": xsource, - "task_num": task_num, } diff --git a/src/train.py b/src/train.py index d673818..fd0129b 100644 --- a/src/train.py +++ b/src/train.py @@ -198,6 +198,7 @@ def setup_dataset(args, tokenizer, dna_rna_tokenizer, protein_tokenizer): dna_rna_tokenizer=dna_rna_tokenizer, protein_tokenizer=protein_tokenizer, read_nums=args.read_nums, + compute_domain_losses = args.compute_domain_losses, shuffle=True, seed=args.seed, type="Train", @@ -680,7 +681,8 @@ def main(): train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer, - data_collator=qwen_omics_collate_fn, + # data_collator=qwen_omics_collate_fn, + data_collator=lambda features: qwen_omics_collate_fn(features, args) ) if args.compute_domain_losses: trainer.training_step = types.MethodType(my_training_step, trainer) From b0038e95ecd325d08798733d55f1e732e99bbd26 Mon Sep 17 00:00:00 2001 From: wanggzf Date: Fri, 14 Nov 2025 14:20:53 +0800 Subject: [PATCH 3/5] add_domain_losses --- src/trainer/omics_trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/trainer/omics_trainer.py b/src/trainer/omics_trainer.py index 7930082..e5f1502 100644 --- a/src/trainer/omics_trainer.py +++ b/src/trainer/omics_trainer.py @@ -6,9 +6,6 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from utils.tools import print_rank_0 -from utils.tools import swanlab_log_rank_0 -from torch import nn -from typing import Any, Union, Optional class OmicsTrainer(Trainer): @@ -106,6 +103,3 @@ 
def save_model(self, output_dir=None, _internal_call=False): print_rank_0(f"Protein projector saved to {protein_projector_path}") else: super().save_model(output_dir, _internal_call) - - - From 11684417ab2361cbcd697d97a2ac96f8a638ed84 Mon Sep 17 00:00:00 2001 From: wanggzf Date: Tue, 18 Nov 2025 16:02:41 +0800 Subject: [PATCH 4/5] add_domain_losses --- requirements.txt | 1 + src/dataset/omics_dataset.py | 10 +++++----- src/model/omics_one.py | 2 +- src/train.py | 11 +++-------- src/trainer/domain_loss.py | 24 ++++++++++++------------ 5 files changed, 22 insertions(+), 26 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7f01692..af52024 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ packaging ninja swanlab[dashboard] liger-kernel +colorama diff --git a/src/dataset/omics_dataset.py b/src/dataset/omics_dataset.py index c0c68c6..36a86e8 100644 --- a/src/dataset/omics_dataset.py +++ b/src/dataset/omics_dataset.py @@ -330,7 +330,7 @@ def format_raw(self, sample: pd.core.series.Series, tokenizer) -> dict: "omic_info_list": omic_info_list, "task": sample.get("task", ""), "label": sample.get("label", ""), - "xsource": self.convert_source_to_id(sample.get("task")), + "task_label": self.convert_source_to_id(sample.get("task")), "task_num": sample.get("task_num"), } return { @@ -426,7 +426,7 @@ def process_sample(self, sample: Dict[str, "labels": torch.LongTensor(labels), "attention_mask": torch.LongTensor(attention_mask), "cal_metric_pos": cal_metric_pos, - "xsource": torch.tensor(sample.get("xsource")), + "task_label": torch.tensor(sample.get("task_label")), "task_num": torch.tensor(sample.get("task_num")) } return { @@ -510,9 +510,9 @@ def qwen_omics_collate_fn(batch, args): }] * (omic_ids.shape[1] - len(omic_info_lists[i]))) if args.compute_domain_losses: - xsource = [sample.get("xsource") for sample in batch] + task_label = [sample.get("task_label") for sample in batch] task_num = [sample.get("task_num") for sample in batch] - 
xsource = torch.stack(xsource) + task_label = torch.stack(task_label) task_num = torch.stack(task_num) return { "input_ids": input_ids, @@ -521,7 +521,7 @@ def qwen_omics_collate_fn(batch, args): "omic_ids": omic_ids, "omic_info_list": omic_info_lists, "cal_metric_pos": cal_metric_pos, - "xsource": xsource, + "task_label": task_label, "task_num": task_num, } return { diff --git a/src/model/omics_one.py b/src/model/omics_one.py index 1ce5532..2bad6a2 100644 --- a/src/model/omics_one.py +++ b/src/model/omics_one.py @@ -148,7 +148,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, # 添加参数,即使forward不使用,否则会去除该字段 - xsource: Optional[torch.LongTensor] = None, + task_label: Optional[torch.LongTensor] = None, task_num: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple[torch.Tensor, ...], CausalLMOutputWithPast]: diff --git a/src/train.py b/src/train.py index fd0129b..ed96f6c 100644 --- a/src/train.py +++ b/src/train.py @@ -37,6 +37,7 @@ import transformers import torch +from colorama import Fore, Style def check_versions(): """检查关键库的版本""" @@ -50,19 +51,13 @@ def check_versions(): 'torch': torch.__version__ } - # ANSI颜色代码 - YELLOW = '\033[93m' - RESET = '\033[0m' - print("=== 版本检查 ===") for lib, current in current_versions.items(): required = required_versions.get(lib, 'unknown') - # 用黄色显示required版本 - print(f"{lib}: {current} (required: {YELLOW}{required}{RESET})") + print(f"{lib}: {current} (required: {Fore.YELLOW}{required}{Style.RESET_ALL})") - # 可以添加警告但不阻止执行 if transformers.__version__ != required_versions['transformers']: - print(f"{YELLOW}⚠️ 警告: transformers 版本 {transformers.__version__} 可能与补丁不兼容{RESET}") + print(f"{Fore.YELLOW}⚠️ 警告: transformers 版本 {transformers.__version__} 可能与补丁不兼容{Style.RESET_ALL}") def setup_tokenizers(args): diff --git a/src/trainer/domain_loss.py b/src/trainer/domain_loss.py index b701a6f..76dcbe7 100644 --- a/src/trainer/domain_loss.py +++ b/src/trainer/domain_loss.py @@ -653,19 
+653,19 @@ def build_domain_loss_dict(domain_losses, source): 2: "CRISPROnTarget", 3: "emp-H", 4: "enhancer_activity", - 5: "Fluorescence", - 6: "FunctionEC", - 7: "Isoform", - 8: "MeanRibosomeLoading", - 9: "Modification", - 10: "NoncodingRNAFamily", + 5: "Fluorescence-Fluorescence", + 6: "FunctionEC-FunctionEC", + 7: "Isoform-Isoform", + 8: "MeanRibosomeLoading-MeanRibosomeLoading", + 9: "Modification-Modification", + 10: "NoncodingRNAFamily-NoncodingRNAFamily", 11: "pd-prom_300", - 12: "ProgrammableRNASwitches", + 12: "ProgrammableRNASwitches-ProgrammableRNASwitches", 13: "promoter_enhancer_interaction", 14: "rna_protein_interaction", - 15: "Solubility", - 16: "Stability", - 17: "Thermostability", + 15: "Solubility-Solubility", + 16: "Stability-Stability", + 17: "Thermostability-Thermostability", 18: "tf-h", 19: "tf-m", 100: "Other" @@ -684,8 +684,8 @@ def build_domain_loss_dict(domain_losses, source): else: ret[key] = v return ret, log_all - domain_loss_dict,log_all = build_domain_loss_dict(outputs.domain_losses, inputs['xsource']) - # domain_loss_list = build_domain_loss_dict(outputs.domain_losses, inputs['xsource']) + domain_loss_dict,log_all = build_domain_loss_dict(outputs.domain_losses, inputs['task_label']) + # domain_loss_list = build_domain_loss_dict(outputs.domain_losses, inputs['task_label']) # import pdb # pdb.set_trace() for log,num in zip(log_all, task_num): From 788bc4424af471c1f5db35e7c4d9a07e1c99735a Mon Sep 17 00:00:00 2001 From: wanggzf Date: Tue, 18 Nov 2025 18:27:52 +0800 Subject: [PATCH 5/5] add_domain_losses --- src/dataset/omics_dataset.py | 85 ++++++++++++------------------------ src/train.py | 3 +- 2 files changed, 28 insertions(+), 60 deletions(-) diff --git a/src/dataset/omics_dataset.py b/src/dataset/omics_dataset.py index 36a86e8..05ab0b4 100644 --- a/src/dataset/omics_dataset.py +++ b/src/dataset/omics_dataset.py @@ -40,7 +40,6 @@ def __init__( dna_rna_tokenizer=None, protein_tokenizer=None, read_nums=None, - compute_domain_losses 
= False, shuffle=False, seed=42, type=None, @@ -55,7 +54,6 @@ def __init__( dataset_config: Configuration for the dataset. dna_rna_tokenizer: Tokenizer for DNA/RNA sequences. read_nums: Maximum number of samples to read. - compute_domain_losses: Calculate domain losses. shuffle: Whether to shuffle the dataset. seed: Random seed for shuffling. type: Dataset type. "Train / Eval" or "Test" @@ -68,7 +66,6 @@ def __init__( self.tokenizer = tokenizer self.dna_rna_tokenizer = dna_rna_tokenizer self.protein_tokenizer = protein_tokenizer - self.compute_domain_losses = compute_domain_losses self.dataset_config = dataset_config self.shuffle = shuffle self.seed = seed @@ -321,27 +318,18 @@ def format_raw(self, sample: pd.core.series.Series, tokenizer) -> dict: "raw_input": input_text, "raw_output": output_text, } - elif self.compute_domain_losses: - return { - "input_ids": input_ids, - "output_ids": output_ids, - "reasoning_token_ids": reasoning_ids, - "omic_ids_list": omic_ids_list, - "omic_info_list": omic_info_list, - "task": sample.get("task", ""), - "label": sample.get("label", ""), - "task_label": self.convert_source_to_id(sample.get("task")), - "task_num": sample.get("task_num"), - } + return { - "input_ids": input_ids, - "output_ids": output_ids, - "reasoning_token_ids": reasoning_ids, - "omic_ids_list": omic_ids_list, - "omic_info_list": omic_info_list, - "task": sample.get("task", ""), - "label": sample.get("label", ""), - } + "input_ids": input_ids, + "output_ids": output_ids, + "reasoning_token_ids": reasoning_ids, + "omic_ids_list": omic_ids_list, + "omic_info_list": omic_info_list, + "task": sample.get("task", ""), + "label": sample.get("label", ""), + "task_label": self.convert_source_to_id(sample.get("task")), + "task_num": sample.get("task_num"), + } # pylint: disable=too-many-branches def process_sample(self, sample: Dict[str, @@ -418,25 +406,16 @@ def process_sample(self, sample: Dict[str, labels.extend([-100] * pad_len) attention_mask.extend([0] * pad_len) 
- if self.compute_domain_losses: - return { - "input_ids": torch.LongTensor(input_ids), - "omic_ids": torch.stack(sample["omic_ids_list"]), - "omic_info_list": sample["omic_info_list"], - "labels": torch.LongTensor(labels), - "attention_mask": torch.LongTensor(attention_mask), - "cal_metric_pos": cal_metric_pos, - "task_label": torch.tensor(sample.get("task_label")), - "task_num": torch.tensor(sample.get("task_num")) - } return { - "input_ids": torch.LongTensor(input_ids), - "omic_ids": torch.stack(sample["omic_ids_list"]), - "omic_info_list": sample["omic_info_list"], - "labels": torch.LongTensor(labels), - "attention_mask": torch.LongTensor(attention_mask), - "cal_metric_pos": cal_metric_pos, - } + "input_ids": torch.LongTensor(input_ids), + "omic_ids": torch.stack(sample["omic_ids_list"]), + "omic_info_list": sample["omic_info_list"], + "labels": torch.LongTensor(labels), + "attention_mask": torch.LongTensor(attention_mask), + "cal_metric_pos": cal_metric_pos, + "task_label": torch.tensor(sample.get("task_label")), + "task_num": torch.tensor(sample.get("task_num")) + } def _encode_sequence(self, seq: str, seq_type: str) -> torch.LongTensor: """ @@ -469,7 +448,7 @@ def _encode_sequence(self, seq: str, seq_type: str) -> torch.LongTensor: # def qwen_omics_collate_fn(batch): -def qwen_omics_collate_fn(batch, args): +def qwen_omics_collate_fn(batch): """ Collate function for DataLoader with multimodal DNA batches. Handles variable length DNA sequences and attention masks. 
@@ -486,7 +465,8 @@ def qwen_omics_collate_fn(batch, args): cal_metric_pos = [sample.get("cal_metric_pos") for sample in batch] omic_info_lists = [sample.get("omic_info_list", []) for sample in batch] omic_ids = [sample.get("omic_ids", None) for sample in batch] - + task_label = [sample.get("task_label") for sample in batch] + task_num = [sample.get("task_num") for sample in batch] input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, @@ -499,6 +479,8 @@ def qwen_omics_collate_fn(batch, args): padding_value=0) omic_ids = (torch.nn.utils.rnn.pad_sequence( omic_ids, batch_first=True, padding_value=1) if omic_ids else None) + task_label = torch.stack(task_label) + task_num = torch.stack(task_num) # Pad omic_info_lists to the same length as omic_ids @@ -509,21 +491,6 @@ def qwen_omics_collate_fn(batch, args): "start": -1 }] * (omic_ids.shape[1] - len(omic_info_lists[i]))) - if args.compute_domain_losses: - task_label = [sample.get("task_label") for sample in batch] - task_num = [sample.get("task_num") for sample in batch] - task_label = torch.stack(task_label) - task_num = torch.stack(task_num) - return { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "omic_ids": omic_ids, - "omic_info_list": omic_info_lists, - "cal_metric_pos": cal_metric_pos, - "task_label": task_label, - "task_num": task_num, - } return { "input_ids": input_ids, "labels": labels, @@ -531,6 +498,8 @@ def qwen_omics_collate_fn(batch, args): "omic_ids": omic_ids, "omic_info_list": omic_info_lists, "cal_metric_pos": cal_metric_pos, + "task_label": task_label, + "task_num": task_num, } diff --git a/src/train.py b/src/train.py index ed96f6c..d961d99 100644 --- a/src/train.py +++ b/src/train.py @@ -676,8 +676,7 @@ def main(): train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer, - # data_collator=qwen_omics_collate_fn, - data_collator=lambda features: qwen_omics_collate_fn(features, args) + data_collator=qwen_omics_collate_fn, ) 
if args.compute_domain_losses: trainer.training_step = types.MethodType(my_training_step, trainer)