diff --git a/examples/predict.job b/examples/predict.job index f2f4be3c..091142cd 100755 --- a/examples/predict.job +++ b/examples/predict.job @@ -56,12 +56,12 @@ if [ x"${platform}" = x"slurm" ]; then node_opts="" else master_addr=${master_addr:-"127.0.0.1"} - node_opts="--nnodes=${nnodes:-1} --node_rank=${node_rank:-0}" + node_opts="+nnodes=${nnodes:-1} +node_rank=${node_rank:-0}" fi master_port=${master_port:-23456} echo "MasterAddr=${master_addr}:${master_port}" echo "===================================" -node_opts="${node_opts} --init_method=tcp://${master_addr}:${master_port}" +node_opts="${node_opts} +init_method=tcp://${master_addr}:${master_port}" ## init virtual environment if needed conda_home=${conda_home:-"${HOME}/.local/anaconda3"} @@ -81,13 +81,13 @@ runner="python" if [ x"${platform}" = x"slurm" ]; then runner="srun ${runner}" fi -${runner} ${PWD}/main.py ${node_opts} predict \ - --prefix=${CWD}/${exp}.pred${model_suffix} \ +${runner} ${PWD}/main.py predict \ + ${node_opts} \ + +prefix=${CWD}/${exp}.pred${model_suffix} \ \ - --models ${CWD}/${exp}.folding/model.pth${model_suffix} \ - --map_location=cpu \ - --model_recycles=2 \ - --model_shard_size=256 \ + +models=[${CWD}/${exp}.folding/model.pth${model_suffix}] \ + +model_recycles=2 \ + +model_shard_size=256 \ \ - --fasta_fmt=single \ + +fasta_fmt=single \ $* diff --git a/install_env.sh b/install_env.sh index 958c2f4b..ba8e5dc1 100644 --- a/install_env.sh +++ b/install_env.sh @@ -65,6 +65,7 @@ conda install -y -c nvidia/label/cuda-${cuda_version} \ conda install -y -c conda-forge \ biopython \ einops \ + hydra-core \ tensorboard \ tqdm \ && cleanup diff --git a/profold2/command/evaluator.py b/profold2/command/evaluator.py index 5153d48a..1c9f7481 100644 --- a/profold2/command/evaluator.py +++ b/profold2/command/evaluator.py @@ -5,6 +5,7 @@ for further help. """ import os +from dataclasses import dataclass, make_dataclass import logging import pickle @@ -22,7 +23,12 @@ from profold2.command import worker -def evaluate(rank, args): # pylint: disable=redefined-outer-name +@dataclass +class Args(worker.Args): + pass + + +def run(rank, args): # pylint: disable=redefined-outer-name wm = worker.WorkerModel(rank, args) feats, model = wm.load(args.model) features = FeatureBuilder(feats).to(wm.device()) @@ -323,34 +329,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name if __name__ == '__main__': import argparse + import hydra - parser = argparse.ArgumentParser() - - # init distributed env - parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.') - parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.') - parser.add_argument( - '--local_rank', type=int, default=None, help='local rank of xpu, default=None' - ) - parser.add_argument( - '--init_method', - type=str, - default='file:///tmp/profold2.dist', - help='method to initialize the process group, ' - 'default=\'file:///tmp/profold2.dist\'' + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - # output dir + parser.add_argument('-c', '--config', type=str, default=None, help='config file.') parser.add_argument( - '-o', - '--prefix', - type=str, - default='.', - help='prefix of out directory, default=\'.\'' + 'overrides', + nargs='*', + metavar='KEY=VAL', + help='override configs, see: https://hydra.cc' ) - add_arguments(parser) - parser.add_argument('-v', '--verbose', action='store_true', help='verbose') args = parser.parse_args() - - worker.main(args, evaluate) + config_dir, config_name = os.path.split( + os.path.abspath(args.config) + ) if exists(args.config) else (os.getcwd(), None) + + with hydra.initialize_config_dir( + version_base=None, config_dir=config_dir, job_name=__file__ + ): + worker.main( + make_dataclass('t', [], namespace={ + 'Args': Args, + 'run': run + }), hydra.compose(config_name, args.overrides) + ) diff --git a/profold2/command/main.py b/profold2/command/main.py index a32c932a..b8ea1854 100644 --- a/profold2/command/main.py +++ b/profold2/command/main.py @@ -4,75 +4,56 @@ ``` for further help. """ +import os import argparse +import hydra + from profold2.command import (evaluator, predictor, trainer, worker) -from profold2.utils import env +from profold2.utils import env, exists -_COMMANDS = [ - ('train', trainer.train, trainer.add_arguments), - ('evaluate', evaluator.evaluate, evaluator.add_arguments), - ('predict', predictor.predict, predictor.add_arguments), -] +_COMMANDS = [('train', trainer), ('evaluate', evaluator), ('predict', predictor)] def create_args(): formatter_class = argparse.ArgumentDefaultsHelpFormatter parser = argparse.ArgumentParser(formatter_class=formatter_class) - # distributed args - parser.add_argument( - '--nnodes', - type=int, - default=env('SLURM_NNODES', defval=None, dtype=int), - help='number of nodes.' - ) - parser.add_argument( - '--node_rank', - type=int, - default=env('SLURM_NODEID', defval=0, dtype=int), - help='rank of the node.' - ) - parser.add_argument( - '--local_rank', - type=int, - default=int(env('LOCAL_RANK', defval=0, dtype=int)), - help='local rank of xpu.' - ) - parser.add_argument( - '--init_method', - type=str, - default=None, - help='method to initialize the process group.' - ) - # command args subparsers = parser.add_subparsers(dest='command', required=True) - for cmd, _, add_arguments in _COMMANDS: + for cmd, _ in _COMMANDS: cmd_parser = subparsers.add_parser(cmd, formatter_class=formatter_class) - - # output dir cmd_parser.add_argument( - '-o', '--prefix', type=str, default='.', help='prefix of out directory.' + '-c', '--config', type=str, default=None, help='config file.' + ) + cmd_parser.add_argument( + 'overrides', + nargs='*', + metavar='KEY=VAL', + help='override configs, see: https://hydra.cc' ) - add_arguments(cmd_parser) - # verbose - cmd_parser.add_argument('-v', '--verbose', action='store_true', help='verbose') return parser.parse_args() -def create_fn(args): # pylint: disable=redefined-outer-name - for cmd, fn, _ in _COMMANDS: +def create_task(args): # pylint: disable=redefined-outer-name + for cmd, task in _COMMANDS: if cmd == args.command: - return fn + return task return None def main(): args = create_args() - work_fn = create_fn(args) - worker.main(args, work_fn) + config_dir, config_name = os.path.split( + os.path.abspath(args.config) + ) if exists(args.config) else (os.getcwd(), None) + + with hydra.initialize_config_dir( + version_base=None, config_dir=config_dir, job_name=args.command + ): + task = create_task(args) + worker.main(task, hydra.compose(config_name, args.overrides)) if __name__ == '__main__': diff --git a/profold2/command/predictor.py b/profold2/command/predictor.py index eb69f21f..f69182eb 100644 --- a/profold2/command/predictor.py +++ b/profold2/command/predictor.py @@ -5,10 +5,12 @@ for further help. """ import os +from dataclasses import dataclass, field, make_dataclass import functools import glob import json import logging +from typing import Optional import numpy as np import torch @@ -128,7 +130,33 @@ def _location_split(model_location): yield model_name, (features, model) -def predict(rank, args): # pylint: disable=redefined-outer-name +@dataclass +class Args(worker.Args): + models: list[str] = field(default_factory=list) # models to be loaded + # using[model_name=model_location] + # format + model_recycles: int = 0 # number of recycles + model_shard_size: Optional[int] = 0 # shard size in the evoformer model + map_location: str = 'cpu' # remapped to an alternative set of devices + + no_relaxer: bool = False # do NOT run relaxer + no_gpu_relax: bool = False # force to run relax on cpu + no_pth: bool = False # do NOT dump prediction headers + + data_dir: Optional[str] = None # dataset dir + data_idx: Optional[str] = None # dataset idx + add_pseudo_linker: bool = False # enable loading complex data + + fasta_files: list[str] = field(default_factory=list) # fasta files + fasta_file_list: Optional[str] = None # file listing fasta files by line + fasta_fmt: str = 'single' # single or a3m + + num_workers: int = 1 # number of workers + + max_msa_size: int = 1024 + + +def run(rank, args): # pylint: disable=redefined-outer-name model_runners = dict(_load_models(rank, args)) logging.info('Have %d models: %s', len(model_runners), list(model_runners.keys())) @@ -360,35 +388,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name if __name__ == '__main__': import argparse + import hydra - parser = argparse.ArgumentParser() - - # init distributed env - parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.') - parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.') - parser.add_argument( - '--local_rank', type=int, default=None, help='local rank of xpu, default=None' - ) - parser.add_argument( - '--init_method', - type=str, - default='file:///tmp/profold2.dist', - help='method to initialize the process group, ' - 'default=\'file:///tmp/profold2.dist\'' + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - # output dir + parser.add_argument('-c', '--config', type=str, default=None, help='config file.') parser.add_argument( - '-o', - '--prefix', - type=str, - default='.', - help='prefix of out directory, default=\'.\'' + 'overrides', + nargs='*', + metavar='KEY=VAL', + help='override configs, see: https://hydra.cc' ) - add_arguments(parser) - # verbose - parser.add_argument('-v', '--verbose', action='store_true', help='verbose') args = parser.parse_args() - - worker.main(args, predict) + config_dir, config_name = os.path.split( + os.path.abspath(args.config) + ) if exists(args.config) else (os.getcwd(), None) + + with hydra.initialize_config_dir( + version_base=None, config_dir=config_dir, job_name=__file__ + ): + worker.main( + make_dataclass('t', [], namespace={ + 'Args': Args, + 'run': run + }), hydra.compose(config_name, args.overrides) + ) diff --git a/profold2/command/trainer.py b/profold2/command/trainer.py index 83e0ae4b..d33aa042 100644 --- a/profold2/command/trainer.py +++ b/profold2/command/trainer.py @@ -7,10 +7,12 @@ import os import contextlib import copy +from dataclasses import dataclass, make_dataclass import json import logging import random import re +from typing import Any, Optional from urllib.parse import urlparse, parse_qsl import numpy as np @@ -66,13 +68,86 @@ def no_sync_ctx(cond, module): yield +@dataclass +class Args(worker.Args): + model_dim: tuple[int, int, int] = (384, 256, 128) + model_num_tokens: int = len(residue_constants.restypes_with_x) + model_evoformer_depth: int = 48 + model_evoformer_head_num: int = 8 + model_evoformer_head_dim: int = 32 + model_evoformer_accept_msa_attn: bool = True + model_evoformer_accept_frame_attn: bool = False + model_evoformer_accept_frame_update: bool = False + model_dropout: tuple[float, float] = (0.15, 0.25) + model_shard_size: Optional[int] = None + model_recycles: int = 2 # number of recycles in model + model_recycling_pos: bool = False + model_recycling_frames: bool = False + model_features: Optional[str] = None + model_headers: Optional[str] = None + model_params_requires_grad: Optional[str] = None + model_params_requires_hook: Optional[str] = None + model_params_optim_option: Optional[str] = None + + train_data: str = 'train.zip' # train dataset dir + train_idx: Optional[str] = None # train name idx + train_chain: Optional[str] = None # train dataset chain idx + train_attr: Optional[str] = None # train dataset attr idx + train_data_weights: Optional[str] = None + train_crop_probability: float = 0.0 + train_pseudo_linker_prob: float = 0.0 + train_msa_as_seq_prob: float = 0.0 + train_msa_as_seq_topn: Optional[int] = None + train_msa_as_seq_clustering: bool = False + train_msa_as_seq_min_alr: Optional[float] = None + train_msa_as_seq_min_ident: Optional[float] = None + + tuning_data: Optional[str] = None # tuning dataset dir + tuning_idx: Optional[str] = None # tuning name idx + tuning_chain: Optional[str] = None # tuning dataset chain idx + tuning_attr: Optional[str] = None # tuning dataset attr idx + tuning_data_weights: Optional[str] = None + + eval_data: Optional[str] = None # eval dataset dir + eval_idx: Optional[str] = None # eval name idx + eval_chain: Optional[str] = None # eval dataset chain idx + eval_attr: Optional[str] = None # eval dataset attr idx + + min_protein_len: int = 50 + max_protein_len: int = 1024 + min_crop_len: int = 80 + max_crop_len: int = 256 + crop_algorithm: str = 'auto' + crop_probability: float = 0.0 + data_rm_mask_prob: float = 0.0 + max_msa_size: int = 1024 + max_var_size: int = 8192 + + num_var_task: int = 1 + + num_batches: int = 100000 # number of batches + batch_size: int = 1 # batch size + num_workers: int = 1 # num of workers + prefetch_factor: int = 2 # num of batches loaded in advance by each worker + gradient_accumulate_every: int = 16 # accumulate grads every k times + learning_rate: float = 1e-3 # learning rate + amp_enabled: bool = False # enable automatic mixed precision + random_seed: Optional[int] = None # random seed + checkpoint_every: int = 100 + checkpoint_max_to_keep: int = 5 # the maximum number of checkpoints to keep + + wandb_enabled: bool = False # enable wandb for experient tracking + wandb_project: str = 'profold2' # wandb project name + wandb_mode: str = 'online' + + def preprocess(args): # pylint: disable=redefined-outer-name assert args.model_evoformer_accept_msa_attn or args.model_evoformer_accept_frame_attn # pylint: disable=line-too-long if args.checkpoint_every > 0: os.makedirs(os.path.join(args.prefix, 'checkpoints'), exist_ok=True) -def train(rank, args): # pylint: disable=redefined-outer-name +def run(rank, args): # pylint: disable=redefined-outer-name from torch.utils.tensorboard import SummaryWriter # pylint: disable=import-outside-toplevel random.seed(args.random_seed) @@ -483,408 +558,33 @@ def batch_with_coords(batch): wm.save(os.path.join(args.prefix, 'model.pth'), feats, model) -setattr(train, 'preprocess', preprocess) - - -def add_arguments(parser): # pylint: disable=redefined-outer-name - parser.add_argument( - '-t', '--train_data', type=str, default='train.zip', help='train dataset dir.' - ) - parser.add_argument('--train_idx', type=str, default=None, help='train dataset idx.') - parser.add_argument( - '--train_chain', type=str, default=None, help='train dataset chain idx.' - ) - parser.add_argument( - '--train_attr', type=str, default=None, help='train dataset attr idx.' - ) - parser.add_argument( - '--train_data_weights', - type=str, - default=None, - help='sample train data by weights.' - ) - parser.add_argument( - '-n', '--num_batches', type=int, default=100000, help='number of batches.' - ) - parser.add_argument( - '-e', '--eval_data', type=str, default=None, help='eval dataset dir.' - ) - parser.add_argument('--eval_idx', type=str, default=None, help='eval dataset idx.') - parser.add_argument( - '--eval_chain', type=str, default=None, help='eval dataset chain idx.' - ) - parser.add_argument( - '--eval_attr', type=str, default=None, help='eval dataset attr idx.' - ) - parser.add_argument('--tuning_data', type=str, default=None, help='eval dataset dir.') - parser.add_argument( - '--tuning_idx', type=str, default=None, help='tuning dataset idx.' - ) - parser.add_argument( - '--tuning_chain', type=str, default=None, help='tuning dataset chain idx.' - ) - parser.add_argument( - '--tuning_attr', type=str, default=None, help='tuning dataset attr idx.' - ) - parser.add_argument( - '--tuning_data_weights', - type=str, - default=None, - help='sample tuning data by weights.' - ) - parser.add_argument( - '--tuning_with_coords', action='store_true', help='use `coord` when tuning.' - ) - parser.add_argument( - '--min_protein_len', - type=int, - default=50, - help='filter out proteins whose lengthMIN_CROP_LEN.' - ) - parser.add_argument( - '--train_pseudo_linker_prob', - type=float, - default=0.0, - help='enable loading complex data.' - ) - parser.add_argument( - '--data_rm_mask_prob', - type=float, - default=0.0, - help='remove masked amino acid with probability DATA_RM_MASK_PROB.' - ) - parser.add_argument( - '--train_msa_as_seq_prob', - type=float, - default=0.0, - help='take msa_{i} as sequence with probability DATA_MSA_AS_SEQ_PROB.' - ) - parser.add_argument( - '--train_msa_as_seq_topn', - type=int, - default=None, - help='take msa_{i} as sequence belongs to DATA_MSA_AS_SEQ_TOPN.' - ) - parser.add_argument( - '--train_msa_as_seq_clustering', - action='store_true', - help='take msa_{i} as sequence sampling from clusters.' - ) - parser.add_argument( - '--train_msa_as_seq_min_alr', - type=float, - default=None, - help='take msa_{i} as sequence with alr <= DATA_MSA_AS_SEQ_MIN_ALR.' - ) - parser.add_argument( - '--train_msa_as_seq_min_ident', - type=float, - default=None, - help='take msa_{i} as sequence with ident <= DATA_MSA_AS_SEQ_MIN_IDENT.' - ) - parser.add_argument( - '--tuning_pseudo_linker_prob', - type=float, - default=0.0, - help='enable loading complex data.' - ) - parser.add_argument( - '--tuning_crop_probability', - type=float, - default=0.0, - help='crop protein with probability CROP_PROBABILITY when it\'s ' - 'length>MIN_CROP_LEN.' - ) - parser.add_argument( - '--tuning_msa_as_seq_prob', - type=float, - default=0.0, - help='take msa_{i} as sequence with probability DATA_MSA_AS_SEQ_PROB.' - ) - parser.add_argument( - '--tuning_msa_as_seq_topn', - type=int, - default=None, - help='take msa_{i} as sequence belongs to DATA_MSA_AS_SEQ_TOPN.' - ) - parser.add_argument( - '--tuning_msa_as_seq_clustering', - action='store_true', - help='take msa_{i} as sequence sampling from clusters.' - ) - parser.add_argument( - '--tuning_msa_as_seq_min_alr', - type=float, - default=None, - help='take msa_{i} as sequence with alr <= DATA_MSA_AS_SEQ_MIN_ALR.' - ) - parser.add_argument( - '--tuning_msa_as_seq_min_ident', - type=float, - default=None, - help='take msa_{i} as sequence with ident <= DATA_MSA_AS_SEQ_MIN_IDENT.' - ) - parser.add_argument('--random_seed', type=int, default=None, help='random seed.') - - parser.add_argument( - '--checkpoint_max_to_keep', - type=int, - default=5, - help='the maximum number of checkpoints to keep.' - ) - parser.add_argument( - '--checkpoint_every', - type=int, - default=100, - help='save a checkpoint every K times.' - ) - parser.add_argument( - '--tuning_every', type=int, default=10, help='tuning model every K times.' - ) - parser.add_argument( - '--eval_every', type=int, default=100, help='eval model every K times.' - ) - parser.add_argument( - '--gradient_accumulate_every', - type=int, - default=16, - help='accumulate grads every k times.' - ) - parser.add_argument('-b', '--batch_size', type=int, default=1, help='batch size') - parser.add_argument('--num_workers', type=int, default=1, help='number of workers.') - parser.add_argument( - '--prefetch_factor', - type=int, - default=2, - help='number of batches loaded in advance by each worker.' - ) - parser.add_argument( - '-l', '--learning_rate', type=float, default='1e-3', help='learning rate.' - ) - parser.add_argument( - '--lr_scheduler', - type=str, - default=optim.SchedulerType.CONSTANT.value, - choices=[m.value for m in optim.SchedulerType], - help='lr scheduler.' - ) - parser.add_argument( - '--lr_scheduler_warmup_steps', - type=int, - default=None, - help='num of warmup steps for lr scheduler.' - ) - parser.add_argument( - '--lr_scheduler_training_steps', - type=int, - default=None, - help='num of training steps for applying lr scheduler.' - ) - parser.add_argument( - '--lr_scheduler_eta_min', - type=float, - default=0.0, - help='eta_min for applying lr scheduler.' - ) - - parser.add_argument( - '--model_features', - type=str, - default='model_features_main.json', - help='json format features of model.' - ) - parser.add_argument( - '--model_headers', - type=str, - default='model_headers_main.json', - help='json format headers of model.' - ) - parser.add_argument( - '--model_recycles', type=int, default=2, help='number of recycles in model.' - ) - parser.add_argument( - '--model_recycling_frames', action='store_true', help='enable frame recycling.' - ) - parser.add_argument( - '--model_recycling_pos', - action='store_true', - help='enable backbone atom position recycling.' - ) - parser.add_argument( - '--model_dim', - type=int, - nargs=3, - default=(384, 256, 128), - help='dimension of model.' - ) - parser.add_argument( - '--model_num_tokens', - type=int, - default=len(residue_constants.restypes_with_x), - help='number of tokens in the model.' - ) - parser.add_argument( - '--model_evoformer_depth', - type=int, - default=1, - help='depth of evoformer in model.' - ) - parser.add_argument( - '--model_evoformer_head_num', - type=int, - default=48, - help='number of heads in evoformer model.' - ) - parser.add_argument( - '--model_evoformer_head_dim', - type=int, - default=32, - help='dimensions of each head in evoformer model.' - ) - parser.add_argument( - '--model_shard_size', - type=int, - default=None, - help='shard size in evoformer model.' - ) - parser.add_argument( - '--model_dropout', - type=float, - nargs=2, - default=(0.15, 0.25), - help='dropout of evoformer(single & pair) in model.' - ) - parser.add_argument( - '--model_evoformer_accept_msa_attn', - action='store_true', - help='enable MSATransformer in evoformer.' - ) - parser.add_argument( - '--model_evoformer_accept_frame_attn', - action='store_true', - help='enable FrameTransformer in evoformer.' - ) - parser.add_argument( - '--model_evoformer_accept_frame_update', - action='store_true', - help='enable FrameUpdater in evoformer.' - ) - parser.add_argument( - '--model_params_requires_grad', - type=str, - default=None, - help='learn partial parameters only.' - ) - parser.add_argument( - '--model_params_requires_hook', - type=str, - default=None, - help='hook partial parameters.' - ) - parser.add_argument( - '--model_params_optim_option', - type=str, - default=None, - help='optimizer arguments accepted by partial parameters only.' - ) - - parser.add_argument( - '--wandb_enabled', - action='store_true', - help='Enable wandb for experient tracking' - ) - parser.add_argument( - '--wandb_project', type=str, default='profold2', help='Wandb project name' - ) - parser.add_argument('--wandb_dir', type=str, default=None, help='Wandb dir') - parser.add_argument('--wandb_name', type=str, default=None, help='Wandb name name') - parser.add_argument( - '--wandb_mode', - type=str, - default='online', - choices=['online', 'offline', 'disabled'], - help='Wandb mode' - ) - parser.add_argument( - '--amp_enabled', action='store_true', help='enable automatic mixed precision.' - ) - - if __name__ == '__main__': import argparse + import hydra - parser = argparse.ArgumentParser() - - # init distributed env - parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.') - parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.') - parser.add_argument( - '--local_rank', type=int, default=None, help='local rank of xpu, default=None' - ) - parser.add_argument( - '--init_method', - type=str, - default='file:///tmp/profold2.dist', - help='method to initialize the process group, ' - 'default=\'file:///tmp/profold2.dist\'' + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - # output dir + parser.add_argument('-c', '--config', type=str, default=None, help='config file.') parser.add_argument( - '-o', - '--prefix', - type=str, - default='.', - help='prefix of out directory, default=\'.\'' + 'overrides', + nargs='*', + metavar='KEY=VAL', + help='override configs, see: https://hydra.cc' ) - add_arguments(parser) - # verbose - parser.add_argument('-v', '--verbose', action='store_true', help='verbose') args = parser.parse_args() + config_dir, config_name = os.path.split( + os.path.abspath(args.config) + ) if exists(args.config) else (os.getcwd(), None) - worker.main(args, train) + with hydra.initialize_config_dir( + version_base=None, config_dir=config_dir, job_name=__file__ + ): + worker.main( + make_dataclass('t', [], namespace={ + 'Args': Args, + 'run': run + }), hydra.compose(config_name, args.overrides) + ) diff --git a/profold2/command/worker.py b/profold2/command/worker.py index f17e6555..459a7c50 100644 --- a/profold2/command/worker.py +++ b/profold2/command/worker.py @@ -1,11 +1,15 @@ """Wrap distibuted env """ import os +from dataclasses import dataclass import functools import logging from logging.handlers import QueueHandler, QueueListener import re import resource +from typing import Optional + +from omegaconf import OmegaConf import torch import torch.multiprocessing as mp @@ -52,7 +56,7 @@ def filter(self, record): class _WorkerLogging(object): """Initialize distibuted logger """ - def __init__(self, work_fn, args): # pylint: disable=redefined-outer-name + def __init__(self, task, args): # pylint: disable=redefined-outer-name # logging os.makedirs(args.prefix, exist_ok=True) @@ -61,7 +65,7 @@ def __init__(self, work_fn, args): # pylint: disable=redefined-outer-name logging.StreamHandler(), logging.FileHandler( os.path.join( - args.prefix, f'{work_fn.__name__}_{args.node_rank}{local_rank}.log' + args.prefix, f'{task.__name__}_{args.node_rank}{local_rank}.log' ) ) ] @@ -225,14 +229,33 @@ def __call__(self, rank, args): # pylint: disable=redefined-outer-name xpu.destroy_process_group() -def main(args, fn): # pylint: disable=redefined-outer-name +@dataclass +class Args: + nnodes: Optional[int] = env('SLURM_NODEID', defval=None, dtype=int) # number of nodes + node_rank: int = env('SLURM_NODEID', defval=0, dtype=int) # rank of the node + local_rank: int = env('LOCAL_RANK', defval=0, dtype=int) # local rank of xpu + init_method: Optional[str] = None # method to initialize the process group + + prefix: str = '.' # prefix of output directory + + amp_enabled: bool = False # enable automatic mixed precision + enable_profiler: bool = False # enable profiler + enable_memory_snapshot: bool = False # enable memory snapshot + + verbose: bool = False # verbose + + +def main(task, args): # pylint: disable=redefined-outer-name + args = OmegaConf.to_object( + OmegaConf.merge(OmegaConf.structured(task.Args), args) + ) if exists(args.nnodes): mp.set_start_method('spawn', force=True) #-------------- # setup logging #-------------- - work_log = _WorkerLogging(fn, args) + work_log = _WorkerLogging(task, args) work_log.start() logging.info('-----------------') @@ -242,10 +265,10 @@ def main(args, fn): # pylint: disable=redefined-outer-name #-------------- # run fn with args #-------------- - if hasattr(fn, 'preprocess'): - fn.preprocess(args) + if hasattr(task, 'preprocess'): + task.preprocess(args) - work_fn = WorkerFunction(fn, work_log.queue) + work_fn = WorkerFunction(task.run, work_log.queue) if exists(args.nnodes): #mp.set_start_method('spawn', force=True) mp.spawn( @@ -257,8 +280,8 @@ def main(args, fn): # pylint: disable=redefined-outer-name else: work_fn(args.local_rank, args) - if hasattr(fn, 'postprocess'): - fn.postprocess(args) + if hasattr(task, 'postprocess'): + task.postprocess(args) logging.info('-----------------') logging.info('Resources(myself): %s', resource.getrusage(resource.RUSAGE_SELF))