Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions examples/predict.job
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ if [ x"${platform}" = x"slurm" ]; then
node_opts=""
else
master_addr=${master_addr:-"127.0.0.1"}
node_opts="--nnodes=${nnodes:-1} --node_rank=${node_rank:-0}"
node_opts="+nnodes=${nnodes:-1} +node_rank=${node_rank:-0}"
fi
master_port=${master_port:-23456}
echo "MasterAddr=${master_addr}:${master_port}"
echo "==================================="
node_opts="${node_opts} --init_method=tcp://${master_addr}:${master_port}"
node_opts="${node_opts} +init_method=tcp://${master_addr}:${master_port}"

## init virtual environment if needed
conda_home=${conda_home:-"${HOME}/.local/anaconda3"}
Expand All @@ -81,13 +81,13 @@ runner="python"
if [ x"${platform}" = x"slurm" ]; then
runner="srun ${runner}"
fi
${runner} ${PWD}/main.py ${node_opts} predict \
--prefix=${CWD}/${exp}.pred${model_suffix} \
${runner} ${PWD}/main.py predict \
${node_opts} \
+prefix=${CWD}/${exp}.pred${model_suffix} \
\
--models ${CWD}/${exp}.folding/model.pth${model_suffix} \
--map_location=cpu \
--model_recycles=2 \
--model_shard_size=256 \
+models=[${CWD}/${exp}.folding/model.pth${model_suffix}] \
+model_recycles=2 \
+model_shard_size=256 \
\
--fasta_fmt=single \
+fasta_fmt=single \
$*
1 change: 1 addition & 0 deletions install_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ conda install -y -c nvidia/label/cuda-${cuda_version} \
conda install -y -c conda-forge \
biopython \
einops \
hydra-core \
tensorboard \
tqdm \
&& cleanup
53 changes: 28 additions & 25 deletions profold2/command/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
for further help.
"""
import os
from dataclasses import dataclass, make_dataclass
import logging
import pickle

Expand All @@ -22,7 +23,12 @@
from profold2.command import worker


def evaluate(rank, args): # pylint: disable=redefined-outer-name
@dataclass
class Args(worker.Args):
pass


def run(rank, args): # pylint: disable=redefined-outer-name
wm = worker.WorkerModel(rank, args)
feats, model = wm.load(args.model)
features = FeatureBuilder(feats).to(wm.device())
Expand Down Expand Up @@ -323,34 +329,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name

if __name__ == '__main__':
import argparse
import hydra

parser = argparse.ArgumentParser()

# init distributed env
parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.')
parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.')
parser.add_argument(
'--local_rank', type=int, default=None, help='local rank of xpu, default=None'
)
parser.add_argument(
'--init_method',
type=str,
default='file:///tmp/profold2.dist',
help='method to initialize the process group, '
'default=\'file:///tmp/profold2.dist\''
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

# output dir
parser.add_argument('-c', '--config', type=str, default=None, help='config file.')
parser.add_argument(
'-o',
'--prefix',
type=str,
default='.',
help='prefix of out directory, default=\'.\''
'overrides',
nargs='*',
metavar='KEY=VAL',
help='override configs, see: https://hydra.cc'
)
add_arguments(parser)
parser.add_argument('-v', '--verbose', action='store_true', help='verbose')

args = parser.parse_args()

worker.main(args, evaluate)
config_dir, config_name = os.path.split(
os.path.abspath(args.config)
) if exists(args.config) else (os.getcwd(), None)

with hydra.initialize_config_dir(
version_base=None, config_dir=config_dir, job_name=__file__
):
worker.main(
make_dataclass('t', [], namespace={
'Args': Args,
'run': run
}), hydra.compose(config_name, args.overrides)
)
69 changes: 25 additions & 44 deletions profold2/command/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,75 +4,56 @@
```
for further help.
"""
import os
import argparse

import hydra

from profold2.command import (evaluator, predictor, trainer, worker)
from profold2.utils import env
from profold2.utils import env, exists

_COMMANDS = [
('train', trainer.train, trainer.add_arguments),
('evaluate', evaluator.evaluate, evaluator.add_arguments),
('predict', predictor.predict, predictor.add_arguments),
]
_COMMANDS = [('train', trainer), ('evaluate', evaluator), ('predict', predictor)]


def create_args():
formatter_class = argparse.ArgumentDefaultsHelpFormatter
parser = argparse.ArgumentParser(formatter_class=formatter_class)

# distributed args
parser.add_argument(
'--nnodes',
type=int,
default=env('SLURM_NNODES', defval=None, dtype=int),
help='number of nodes.'
)
parser.add_argument(
'--node_rank',
type=int,
default=env('SLURM_NODEID', defval=0, dtype=int),
help='rank of the node.'
)
parser.add_argument(
'--local_rank',
type=int,
default=int(env('LOCAL_RANK', defval=0, dtype=int)),
help='local rank of xpu.'
)
parser.add_argument(
'--init_method',
type=str,
default=None,
help='method to initialize the process group.'
)

# command args
subparsers = parser.add_subparsers(dest='command', required=True)
for cmd, _, add_arguments in _COMMANDS:
for cmd, _ in _COMMANDS:
cmd_parser = subparsers.add_parser(cmd, formatter_class=formatter_class)

# output dir
cmd_parser.add_argument(
'-o', '--prefix', type=str, default='.', help='prefix of out directory.'
'-c', '--config', type=str, default=None, help='config file.'
)
cmd_parser.add_argument(
'overrides',
nargs='*',
metavar='KEY=VAL',
help='override configs, see: https://hydra.cc'
)
add_arguments(cmd_parser)
# verbose
cmd_parser.add_argument('-v', '--verbose', action='store_true', help='verbose')

return parser.parse_args()


def create_fn(args): # pylint: disable=redefined-outer-name
for cmd, fn, _ in _COMMANDS:
def create_task(args): # pylint: disable=redefined-outer-name
for cmd, task in _COMMANDS:
if cmd == args.command:
return fn
return task
return None


def main():
args = create_args()
work_fn = create_fn(args)
worker.main(args, work_fn)
config_dir, config_name = os.path.split(
os.path.abspath(args.config)
) if exists(args.config) else (os.getcwd(), None)

with hydra.initialize_config_dir(
version_base=None, config_dir=config_dir, job_name=args.command
):
task = create_task(args)
worker.main(task, hydra.compose(config_name, args.overrides))


if __name__ == '__main__':
Expand Down
76 changes: 50 additions & 26 deletions profold2/command/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
for further help.
"""
import os
from dataclasses import dataclass, field, make_dataclass
import functools
import glob
import json
import logging
from typing import Optional

import numpy as np
import torch
Expand Down Expand Up @@ -128,7 +130,33 @@ def _location_split(model_location):
yield model_name, (features, model)


def predict(rank, args): # pylint: disable=redefined-outer-name
@dataclass
class Args(worker.Args):
models: list[str] = field(default_factory=list) # models to be loaded
# using[model_name=model_location]
# format
model_recycles: int = 0 # number of recycles
model_shard_size: Optional[int] = 0 # shard size in the evoformer model
map_location: str = 'cpu' # remapped to an alternative set of devices

no_relaxer: bool = False # do NOT run relaxer
no_gpu_relax: bool = False # force to run relax on cpu
no_pth: bool = False # do NOT dump prediction headers

data_dir: Optional[str] = None # dataset dir
data_idx: Optional[str] = None # dataset idx
add_pseudo_linker: bool = False # enable loading complex data

fasta_files: list[str] = field(default_factory=list) # fasta files
fasta_file_list: Optional[str] = None # file listing fasta files by line
fasta_fmt: str = 'single' # single or a3m

num_workers: int = 1 # number of workers

max_msa_size: int = 1024


def run(rank, args): # pylint: disable=redefined-outer-name
model_runners = dict(_load_models(rank, args))
logging.info('Have %d models: %s', len(model_runners), list(model_runners.keys()))

Expand Down Expand Up @@ -360,35 +388,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name

if __name__ == '__main__':
import argparse
import hydra

parser = argparse.ArgumentParser()

# init distributed env
parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.')
parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.')
parser.add_argument(
'--local_rank', type=int, default=None, help='local rank of xpu, default=None'
)
parser.add_argument(
'--init_method',
type=str,
default='file:///tmp/profold2.dist',
help='method to initialize the process group, '
'default=\'file:///tmp/profold2.dist\''
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

# output dir
parser.add_argument('-c', '--config', type=str, default=None, help='config file.')
parser.add_argument(
'-o',
'--prefix',
type=str,
default='.',
help='prefix of out directory, default=\'.\''
'overrides',
nargs='*',
metavar='KEY=VAL',
help='override configs, see: https://hydra.cc'
)
add_arguments(parser)
# verbose
parser.add_argument('-v', '--verbose', action='store_true', help='verbose')

args = parser.parse_args()

worker.main(args, predict)
config_dir, config_name = os.path.split(
os.path.abspath(args.config)
) if exists(args.config) else (os.getcwd(), None)

with hydra.initialize_config_dir(
version_base=None, config_dir=config_dir, job_name=__file__
):
worker.main(
make_dataclass('t', [], namespace={
'Args': Args,
'run': run
}), hydra.compose(config_name, args.overrides)
)
Loading