1 change: 1 addition & 0 deletions requirements.txt
@@ -28,3 +28,4 @@ packaging
ninja
swanlab[dashboard]
liger-kernel
colorama
71 changes: 71 additions & 0 deletions scripts/train/examples/run_train_1B_domain.sh
@@ -0,0 +1,71 @@
enable_list="multimodal model.model.embed_tokens model.model.layers model.lm_head"
experiment_name="Qwen3_1.7B_Omics_sft_1014_all_task_test"
output_path="/mnt/shared-storage-user/ai4agr-share/wangzhefan/molly_checkpoint/${experiment_name}"

export OMP_NUM_THREADS=4
export MKL_NUM_THREADS=4

# export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
# export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_TIMEOUT=3600

options="--experiment-name $experiment_name \
--output_dir $output_path \
--text-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/Qwen3-1.7B \
--dna-rna-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/nucleotide-transformer-v2-500m-multi-species/ \
--dna-rna-k-tokens 1024 \
--protein-model-path /mnt/shared-storage-user/ai4agr-share/lijinzhe/PreModel/esm2_t33_650M_UR50D/ \
--protein-k-tokens 1024 \
--device cuda \
--train-mlp \
--train-llm \
--train-dataset-path /mnt/shared-storage-user/wangzhefan/multimodel/data/molly/train_all_task_merged_labelled.parquet \
--eval-dataset-path /mnt/shared-storage-user/wangzhefan/multimodel/data/molly/dev_all_task_merged_labelled.parquet \
--max-len 3072 \
--max-src-len 3072 \
--eval-max-len 3072 \
--eval-max-src-len 3072 \
--mode sft \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--read-nums 10240000000 \
--eval-read-nums 10240000000 \
--num_train_epochs 5 \
--learning_rate 3e-5 \
--bf16 \
--enable-list $enable_list \
--save_strategy steps \
--save_steps 80000 \
--eval_steps 80000 \
--eval_strategy steps \
--logging_strategy steps \
--logging_steps 1 \
--save_trainable False \
--save-total-limit 500 \
--warmup_ratio 0.1 \
--early-stopping-patience 1000000000 \
--gradient-accumulation-steps 1 \
--save_only_model \
--attn_impl flash_attention_2 \
--use_liger True \
--swanlab \
--swanlab-mode local \
--swanlab-team BioMLLM_report \
--swanlab-project BioMLLM \
--seed 42 \
--compute-domain-losses True \
"
# --load_best_model_at_end \
# --save_safetensors \
# --greater_is_better \
# --use-lora
# --load-pretrained \

deepspeed \
--include localhost:0 \
src/train.py \
--deepspeed_config src/configs/ds_z0_config.json \
$options


# py-spy dump -p 11499 --locals | head -60
2 changes: 1 addition & 1 deletion scripts/train/examples/run_train_8B_z0_b1.sh
@@ -67,4 +67,4 @@ src/train.py \
$options


# py-spy dump -p 11499 --locals | head -60
# py-spy dump -p 11499 --locals | head -60
88 changes: 72 additions & 16 deletions src/dataset/omics_dataset.py
@@ -169,6 +169,50 @@ def _pretokenize_special_tokens(self):
r"<protein>\s*([ACDEFGHIKLMNPQRSTVWYBXZOU]+)\s*</protein>"),
}

def convert_source_to_id(self, source: str) -> int:
    """Map a task/source string to an integer domain ID (100 = unknown)."""
if 'antibody_antigen' in source:
return 0
elif 'cpd-prom_core' in source:
return 1
elif 'CRISPROnTarget' in source:
return 2
elif 'emp-H' in source:
return 3
elif 'enhancer_activity' in source:
return 4
elif 'Fluorescence-Fluorescence' in source:
return 5
elif 'FunctionEC-FunctionEC' in source:
return 6
elif 'Isoform-Isoform' in source:
return 7
elif 'MeanRibosomeLoading-MeanRibosomeLoading' in source:
return 8
elif 'Modification-Modification' in source:
return 9
elif 'NoncodingRNAFamily-NoncodingRNAFamily' in source:
return 10
elif 'pd-prom_300' in source:
return 11
elif 'ProgrammableRNASwitches-ProgrammableRNASwitches' in source:
return 12
elif 'promoter_enhancer_interaction' in source:
return 13
elif 'rna_protein_interaction' in source:
return 14
elif 'Solubility-Solubility' in source:
return 15
elif 'Stability-Stability' in source:
return 16
elif 'Thermostability-Thermostability' in source:
return 17
elif 'tf-h' in source:
return 18
elif 'tf-m' in source:
return 19
else:
return 100
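
A compact, equivalent way to express the chain above is a table-driven lookup; the sketch below (same substrings and IDs, first match wins) is illustrative and not part of the change:

_SOURCE_TO_ID = {
    "antibody_antigen": 0,
    "cpd-prom_core": 1,
    "CRISPROnTarget": 2,
    "emp-H": 3,
    "enhancer_activity": 4,
    "Fluorescence-Fluorescence": 5,
    "FunctionEC-FunctionEC": 6,
    "Isoform-Isoform": 7,
    "MeanRibosomeLoading-MeanRibosomeLoading": 8,
    "Modification-Modification": 9,
    "NoncodingRNAFamily-NoncodingRNAFamily": 10,
    "pd-prom_300": 11,
    "ProgrammableRNASwitches-ProgrammableRNASwitches": 12,
    "promoter_enhancer_interaction": 13,
    "rna_protein_interaction": 14,
    "Solubility-Solubility": 15,
    "Stability-Stability": 16,
    "Thermostability-Thermostability": 17,
    "tf-h": 18,
    "tf-m": 19,
}

def convert_source_to_id(source: str) -> int:
    # First matching substring wins; 100 marks an unknown source.
    for key, idx in _SOURCE_TO_ID.items():
        if key in source:
            return idx
    return 100
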

def format_raw(self, sample: pd.core.series.Series, tokenizer) -> dict:
"""
Format a Parquet example into DNA-LLM format suitable for processing.
@@ -274,15 +318,18 @@ def format_raw(self, sample: pd.core.series.Series, tokenizer) -> dict:
"raw_input": input_text,
"raw_output": output_text,
}

return {
"input_ids": input_ids,
"output_ids": output_ids,
"reasoning_token_ids": reasoning_ids,
"omic_ids_list": omic_ids_list,
"omic_info_list": omic_info_list,
"task": sample.get("task", ""),
"label": sample.get("label", ""),
}
"input_ids": input_ids,
"output_ids": output_ids,
"reasoning_token_ids": reasoning_ids,
"omic_ids_list": omic_ids_list,
"omic_info_list": omic_info_list,
"task": sample.get("task", ""),
"label": sample.get("label", ""),
"task_label": self.convert_source_to_id(sample.get("task")),
"task_num": sample.get("task_num"),
}

# pylint: disable=too-many-branches
def process_sample(self, sample: Dict[str,
@@ -360,13 +407,15 @@ def process_sample(self, sample: Dict[str,
attention_mask.extend([0] * pad_len)

return {
"input_ids": torch.LongTensor(input_ids),
"omic_ids": torch.stack(sample["omic_ids_list"]),
"omic_info_list": sample["omic_info_list"],
"labels": torch.LongTensor(labels),
"attention_mask": torch.LongTensor(attention_mask),
"cal_metric_pos": cal_metric_pos,
}
"input_ids": torch.LongTensor(input_ids),
"omic_ids": torch.stack(sample["omic_ids_list"]),
"omic_info_list": sample["omic_info_list"],
"labels": torch.LongTensor(labels),
"attention_mask": torch.LongTensor(attention_mask),
"cal_metric_pos": cal_metric_pos,
"task_label": torch.tensor(sample.get("task_label")),
"task_num": torch.tensor(sample.get("task_num"))
}

def _encode_sequence(self, seq: str, seq_type: str) -> torch.LongTensor:
"""
@@ -398,6 +447,7 @@ def _encode_sequence(self, seq: str, seq_type: str) -> torch.LongTensor:
return encoding["input_ids"].squeeze(0)


def qwen_omics_collate_fn(batch):
"""
Collate function for DataLoader with multimodal DNA batches.
@@ -409,13 +459,14 @@ def qwen_omics_collate_fn(batch):
Returns:
Batched tensors suitable for model input
"""

input_ids = [sample["input_ids"] for sample in batch]
labels = [sample["labels"] for sample in batch]
attention_mask = [sample["attention_mask"] for sample in batch]
cal_metric_pos = [sample.get("cal_metric_pos") for sample in batch]
omic_info_lists = [sample.get("omic_info_list", []) for sample in batch]
omic_ids = [sample.get("omic_ids", None) for sample in batch]
task_label = [sample.get("task_label") for sample in batch]
task_num = [sample.get("task_num") for sample in batch]

input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
batch_first=True,
Expand All @@ -428,6 +479,9 @@ def qwen_omics_collate_fn(batch):
padding_value=0)
omic_ids = (torch.nn.utils.rnn.pad_sequence(
omic_ids, batch_first=True, padding_value=1) if omic_ids else None)
task_label = torch.stack(task_label)
task_num = torch.stack(task_num)


# Pad omic_info_lists to the same length as omic_ids
for i, _ in enumerate(omic_info_lists):
@@ -444,6 +498,8 @@ def qwen_omics_collate_fn(batch):
"omic_ids": omic_ids,
"omic_info_list": omic_info_lists,
"cal_metric_pos": cal_metric_pos,
"task_label": task_label,
"task_num": task_num,
}
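
A minimal illustration (with made-up IDs) of the stacking behavior the two new fields rely on; it assumes every sample carries a scalar task_label/task_num tensor:

import torch

# Scalar per-sample tensors, as produced by process_sample()
task_label = [torch.tensor(3), torch.tensor(17), torch.tensor(0)]
batched = torch.stack(task_label)  # tensor([ 3, 17,  0]), shape (batch_size,)
# Unlike pad_sequence (used for input_ids above), torch.stack raises a
# TypeError if any sample is missing the field (sample.get(...) returned None).
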


6 changes: 4 additions & 2 deletions src/model/omics_one.py
@@ -5,9 +5,8 @@
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from utils.tools import time_count

from trainer import CausalLMOutputWithPast

class OmicsOne(nn.Module):

def __init__(self, config):
@@ -148,6 +147,9 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
# Declare these parameters even though forward() does not use them; otherwise the fields would be stripped from the inputs.
task_label: Optional[torch.LongTensor] = None,
task_num: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor, ...], CausalLMOutputWithPast]:
return_dict = (return_dict if return_dict is not None else
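
Background for the comment above: with remove_unused_columns enabled, the HF Trainer keeps only dataset columns whose names appear explicitly in the model's forward signature (a bare **kwargs does not protect them). A rough sketch of that kind of signature-based filtering, with illustrative field names:

import inspect

def forward(self, input_ids=None, attention_mask=None, labels=None,
            task_label=None, task_num=None, **kwargs):
    ...

allowed = set(inspect.signature(forward).parameters)
features = {"input_ids": [1, 2], "task_label": [3], "extra_field": [4]}
kept = {k: v for k, v in features.items() if k in allowed}  # drops "extra_field"
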
43 changes: 40 additions & 3 deletions src/train.py
@@ -15,7 +15,8 @@
set_seed,
TrainingArguments
)

from transformers import Qwen3ForCausalLM
import types

# pylint: disable=no-name-in-module
# pylint: disable=too-many-branches
@@ -24,7 +25,7 @@

from torch.utils.data.distributed import DistributedSampler
from model import OmicsOne, get_omics_one_config
from trainer import OmicsTrainer
from trainer import OmicsTrainer, CausalLMOutputWithPast, my_inner_training_loop, my_maybe_log_save_evaluate, my_training_step, my_lce_forward
from utils import (
init_swanlab_rank_0,
pre_train_lora,
@@ -34,6 +35,31 @@
get_current_device,
)

import transformers
import torch
from colorama import Fore, Style

def check_versions():
"""检查关键库的版本"""
required_versions = {
'transformers': '4.53.0',
'torch': '2.7.0+cu128'
}

current_versions = {
'transformers': transformers.__version__,
'torch': torch.__version__
}

print("=== 版本检查 ===")
for lib, current in current_versions.items():
required = required_versions.get(lib, 'unknown')
print(f"{lib}: {current} (required: {Fore.YELLOW}{required}{Style.RESET_ALL})")

if transformers.__version__ != required_versions['transformers']:
print(f"{Fore.YELLOW}⚠️ 警告: transformers 版本 {transformers.__version__} 可能与补丁不兼容{Style.RESET_ALL}")


def setup_tokenizers(args):
"""
Setup tokenizers for both text and DNA models.
@@ -61,7 +87,6 @@ def setup_tokenizers(args):

return tokenizer, dna_rna_tokenizer, protein_tokenizer


def setup_model_and_optimizer(args, tokenizer):
print_rank_0("-------------------init model-------------------------")

@@ -168,6 +193,7 @@ def setup_dataset(args, tokenizer, dna_rna_tokenizer, protein_tokenizer):
dna_rna_tokenizer=dna_rna_tokenizer,
protein_tokenizer=protein_tokenizer,
read_nums=args.read_nums,
compute_domain_losses=args.compute_domain_losses,
shuffle=True,
seed=args.seed,
type="Train",
@@ -469,6 +495,10 @@ def main():
type=bool,
default=True,
help="If train llm")
parser.add_argument("--compute-domain-losses",
type=bool,
default=False,
help="compute domain losses during training")

# Optimizer configuration
parser.add_argument("--learning_rate",
@@ -558,6 +588,8 @@
# Add DeepSpeed arguments
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
if args.compute_domain_losses:
check_versions()

# Setup random seed number
set_seed(args.seed)
@@ -646,6 +678,11 @@ def main():
tokenizer=tokenizer,
data_collator=qwen_omics_collate_fn,
)
if args.compute_domain_losses:
trainer.training_step = types.MethodType(my_training_step, trainer)
trainer._maybe_log_save_evaluate = types.MethodType(my_maybe_log_save_evaluate, trainer)
trainer._inner_training_loop = types.MethodType(my_inner_training_loop, trainer)
Qwen3ForCausalLM.forward = my_lce_forward
# Start training
trainer.train()
except Exception as e:
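
For reference, a minimal sketch of the two patching styles used above: types.MethodType binds a replacement to a single trainer instance, while assigning to Qwen3ForCausalLM.forward swaps the method for every instance of that class (the names below are illustrative):

import types

class Trainer:
    def training_step(self, inputs):
        return "original"

def my_training_step(self, inputs):
    # Replacement that could, e.g., also accumulate per-domain losses.
    return "patched"

trainer = Trainer()
trainer.training_step = types.MethodType(my_training_step, trainer)
print(trainer.training_step(None))    # "patched"  -- only this instance is affected
print(Trainer().training_step(None))  # "original" -- other instances keep the default
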
5 changes: 5 additions & 0 deletions src/trainer/__init__.py
@@ -1,5 +1,10 @@
from .omics_trainer import OmicsTrainer
from .domain_loss import CausalLMOutputWithPast, my_inner_training_loop, my_maybe_log_save_evaluate, my_training_step, my_lce_forward

__all__ = [
"OmicsTrainer",
"CausalLMOutputWithPast",
"my_inner_training_loop",
"my_maybe_log_save_evaluat",
"my_training_step,my_lce_forward"
]