From 02777473b8ce213eb1decfc75f05ba045f77212a Mon Sep 17 00:00:00 2001 From: chungongyu Date: Wed, 30 Jul 2025 17:05:41 +0800 Subject: [PATCH] feat: add Diffusion Structure Module --- examples/generate.job | 93 ++ install_env.sh | 2 + profold2/command/generator.py | 424 +++++++++ profold2/command/main.py | 3 +- profold2/common/chemical_components.py | 218 +++++ profold2/data/dataset.py | 94 +- profold2/data/utils.py | 41 +- profold2/model/alphafold2.py | 8 +- profold2/model/atom_layout.py | 157 ++++ profold2/model/diffusion.py | 1100 ++++++++++++++++++++++++ profold2/model/head.py | 683 ++++++++++++++- 11 files changed, 2813 insertions(+), 10 deletions(-) create mode 100755 examples/generate.job create mode 100644 profold2/command/generator.py create mode 100644 profold2/common/chemical_components.py create mode 100644 profold2/model/atom_layout.py create mode 100644 profold2/model/diffusion.py diff --git a/examples/generate.job b/examples/generate.job new file mode 100755 index 00000000..8a325d1c --- /dev/null +++ b/examples/generate.job @@ -0,0 +1,93 @@ +#!/bin/bash +#SBATCH --job-name=profold2_generate # identifier for the job listings +#SBATCH --output=generate.log # outputfile + +#SBATCH --nodes=2 # number of nodes you want to use +#SBATCH --gpus=4 # count of GPUs required for the job +#SBATCH --qos=gpugpu # quality of service +#SBATCH --ntasks-per-node=1 # number of tasks to invoke on each node +#SBATCH --gpus-per-task=2 # every process wants one GPU! +#SBATCH --gpu-bind=none # NCCL can't deal with task-binding... + +# check if script is started via SLURM or bash +# if with SLURM: there variable '$SLURM_JOBID' will exist +CWD=$0 +if [ -n "${SLURM_JOBID}" ]; then + CWD=`scontrol show job ${SLURM_JOBID}|awk -F= '/Command=/{print $2}'` +fi +CWD=`realpath -s ${CWD}` +CWD=`dirname ${CWD}` +PWD=`dirname ${CWD}` + +help() { + echo "usage: `basename $0` [-h] -p {slurm,local} -- [pred_opt ...] 
fasta_file [fasta_file ...]" + echo "positional arguments:" + echo " pred_opt generate option." + echo " \`python ${PWD}/main.py generate -h\` for further help." + echo " fasta_file fasta format protein sequence file." + echo "options:" + echo " -h, --help show this help message and exit" + echo " -p PLATFORM, --platform PLATFORM {slurm,local}" + echo " type of platform. (default: slurm)" + exit $1 +} + +platform="slurm" + +ARGS=$(getopt -o "p:h" -l "platform:,help" -- "$@") || help 1 +eval "set -- ${ARGS}" +while true; do + case "$1" in + (-p | --platform) platform="$2"; shift 2;; + (-h | --help) help 0 ;; + (--) shift 1; break;; + (*) help 1; + esac +done + +## get the first node name as master address - customized for vgg slurm +## e.g. master(gnodee[2-5],gnoded1) == gnodee2 +echo "===================================" +echo "CurrentWorkDir=`pwd`" +echo "Platform=${platform}" +if [ x"${platform}" = x"slurm" ]; then + echo "NodeList=${SLURM_NODELIST}" + master_addr=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) + node_opts="" +else + master_addr=${master_addr:-"127.0.0.1"} + node_opts="--nnodes=${nnodes:-1} --node_rank=${node_rank:-0}" +fi +master_port=${master_port:-23456} +echo "MasterAddr=${master_addr}:${master_port}" +echo "===================================" +node_opts="${node_opts} --init_method=tcp://${master_addr}:${master_port}" + +## init virtual environment if needed +conda_home=${conda_home:-"${HOME}/.local/anaconda3"} +. 
${conda_home}/bin/activate profold2 + +exp=${exp:-"150m"} +model_suffix=${model_suffix:-""} + +export AxialAttention_accept_edge_norm=${AxialAttention_accept_edge_norm:-"0"} +export AxialAttention_accept_kernel_fn=${AxialAttention_accept_kernel_fn:-"1"} +# export AxialAttention_accept_kernel_dtype="float16" + +export NCCL_SOCKET_FAMILY=AF_INET +export NCCL_TIMEOUT=3600 + +runner="python" +if [ x"${platform}" = x"slurm" ]; then + runner="srun ${runner}" +fi +${runner} ${PWD}/main.py ${node_opts} generate \ + --prefix=${CWD}/${exp}.pred${model_suffix} \ + \ + --models ${CWD}/${exp}.folding/model.pth${model_suffix} \ + --map_location=cpu \ + --model_recycles=2 \ + --model_shard_size=256 \ + \ + --fasta_fmt=single \ + $* diff --git a/install_env.sh b/install_env.sh index 958c2f4b..029c7514 100644 --- a/install_env.sh +++ b/install_env.sh @@ -64,7 +64,9 @@ conda install -y -c nvidia/label/cuda-${cuda_version} \ conda install -y -c conda-forge \ biopython \ + biotite \ einops \ + rdkit \ tensorboard \ tqdm \ && cleanup diff --git a/profold2/command/generator.py b/profold2/command/generator.py new file mode 100644 index 00000000..0b0cfb3b --- /dev/null +++ b/profold2/command/generator.py @@ -0,0 +1,424 @@ +"""Tools for inference, run + ```bash + $python generator.py -h + ``` + for further help. 
+""" +import os +import contextlib +import functools +import glob +import json +import logging + +import numpy as np +import torch +from torch.utils.data.distributed import DistributedSampler + +# models & data +from profold2.common import protein +from profold2.data import dataset +from profold2.data.dataset import ProteinSequenceDataset +from profold2.data.parsers import parse_fasta +from profold2.data.utils import parse_seq_index, pdb_from_generation, str_seq_index +from profold2.model import accelerator, profiler, snapshot, FeatureBuilder, ReturnValues +from profold2.utils import exists, timing + +from profold2.command import worker + + +@contextlib.contextmanager +def _env_ctx(**kwargs): + old_env = dict(os.environ) + for key, value in kwargs.items(): + if exists(value): + os.environ[key] = str(value) + + yield + + for key in kwargs: + if key not in old_env: + del os.environ[key] + os.environ.update(old_env) + + +def _read_fasta(args): # pylint: disable=redefined-outer-name + def filename_get(fasta_file): + fasta_file = os.path.basename(fasta_file) + pid, _ = os.path.splitext(fasta_file) + return pid + + if exists(args.fasta_file_list): + with open(args.fasta_file_list, 'r') as f: + for fasta_file in filter(lambda x: len(x) > 0, map(lambda x: x.strip(), f)): + args.fasta_files.append(fasta_file) + for fasta_glob in args.fasta_files: + for fasta_file in glob.glob(fasta_glob): + fasta_name = filename_get(fasta_file) + with open(fasta_file, 'r') as f: + fasta_str = f.read() + yield fasta_name, fasta_str + + +def _create_dataloader(xpu, args): # pylint: disable=redefined-outer-name + kwargs = {'pin_memory': True, 'shuffle': False} + if exists(args.data_dir): + if xpu.is_available() and accelerator.world_size(args.nnodes) > 1: + kwargs['num_replicas'] = accelerator.world_size(args.nnodes) + kwargs['rank'] = xpu.rank + return dataset.load( + data_dir=args.data_dir, + data_idx=args.data_idx, + pseudo_linker_prob=1.0 if args.add_pseudo_linker else 0.0, + 
pseudo_linker_shuffle=False, + max_msa_depth=args.max_msa_size, + num_workers=args.num_workers, + **kwargs + ) + + sequences, descriptions, msa = [], [], [] + for fasta_name, fasta_str in _read_fasta(args): + s, d = parse_fasta(fasta_str) + d[0] = f'{fasta_name} {d[0]}' if exists(d[0]) else fasta_name + if args.fasta_fmt == 'single': + sequences += [s] + descriptions += [d] + msa += [[None] * len(s)] + else: + sequences += [s[:1]] + descriptions += [d[:1]] + if len(s) > args.max_msa_size: + s = s[:1] + list( + np.random.choice( + s, size=args.max_msa_size - 1, replace=False + ) if args.max_msa_size > 1 else [] + ) + msa += [[s]] + data = ProteinSequenceDataset(sequences, descriptions, msa=msa) + if xpu.is_available() and accelerator.world_size(args.nnodes) > 1: + kwargs['sampler'] = DistributedSampler( + data, + num_replicas=accelerator.world_size(args.nnodes), + rank=xpu.rank, + shuffle=False + ) + return torch.utils.data.DataLoader( + data, + collate_fn=ProteinSequenceDataset.collate_fn, + num_workers=args.num_workers, + **kwargs + ) + + +def _create_relaxer(use_gpu_relax=False): + from profold2.relax import relax # pylint: disable=import-outside-toplevel + + return relax.AmberRelaxation( + max_iterations=relax.RELAX_MAX_ITERATIONS, + tolerance=relax.RELAX_ENERGY_TOLERANCE, + stiffness=relax.RELAX_STIFFNESS, + exclude_residues=relax.RELAX_EXCLUDE_RESIDUES, + max_outer_iterations=relax.RELAX_MAX_OUTER_ITERATIONS, + use_gpu=use_gpu_relax + ) + + +def _load_models(rank, args): # pylint: disable=redefined-outer-name + def _location_split(model_location): + k = model_location.find('=') + if k != -1: + return model_location.split('=', 1) + model_name = os.path.basename(model_location) + model_name, _ = os.path.splitext(model_name) + return model_name, model_location + + wm = worker.WorkerModel(rank, args) + for i, model_location in enumerate(args.models): + model_name, model_location = _location_split(model_location) + logging.info( + 'Load model [%d/%d] %s from %s', 
i, len(args.models), model_name, model_location + ) + + feats, model = wm.load(model_location) + features = FeatureBuilder(feats).to(wm.device()) + yield model_name, (features, model) + + +def generate(rank, args): # pylint: disable=redefined-outer-name + model_runners = dict(_load_models(rank, args)) + logging.info('Have %d models: %s', len(model_runners), list(model_runners.keys())) + + test_loader = _create_dataloader(rank, args) + amber_relaxer = ( + None if args.no_relaxer else + _create_relaxer(use_gpu_relax=rank.is_available() and not args.no_gpu_relax) + ) + + def timing_callback(timings, key, tic, toc): + timings[key] = toc - tic + + def generate_structure(idx, batch): + assert len(batch['pid']) == 1 + timings = {} + + fasta_name = ','.join(batch['pid']) + with timing( + f'Generating {fasta_name}', + print_fn=logging.info, + callback_fn=functools.partial(timing_callback, timings, 'generate_structure') + ): + logging.debug('Sequence [%d] %s shape : %s', idx, fasta_name, batch['seq'].shape) + if args.fasta_fmt in ('a3m', 'a4m'): + logging.debug('msa shape %s: %s', fasta_name, batch['msa'].shape) + + output_dir = os.path.join(args.prefix, fasta_name) + os.makedirs(output_dir, exist_ok=True) + + unrelaxed_pdbs, relaxed_pdbs = {}, {} + ranking_scores = {} + + def process_structure(model_name, pdb_str, plddt=None): + ranking_scores[model_name] = 0 + if exists(plddt): + ranking_scores[model_name] = plddt.item() + + unrelaxed_pdbs[model_name] = pdb_str + unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb') + with open(unrelaxed_pdb_path, 'w') as f: + f.write(unrelaxed_pdbs[model_name]) + + if exists(amber_relaxer): + # Relax the generation. 
+ with timing( + f'Relax pdb from model {model_name} on {fasta_name}', + print_fn=logging.info, + callback_fn=functools.partial( + timing_callback, timings, f'relax_{model_name}' + ) + ): + retry = 2 + while retry > 0: + retry -= 1 + try: + relaxed_pdb_str, _, _ = amber_relaxer.process( + prot=protein.from_pdb_string(unrelaxed_pdbs[model_name]) + ) + break + except ValueError as e: + logging.error('Relax throw an exception: %s', e) + if retry <= 0: + relaxed_pdb_str = unrelaxed_pdbs[model_name] + logging.error('Using unrelaxed pdb instead.') + # raise e + + relaxed_pdbs[model_name] = relaxed_pdb_str + + # Save the relaxed PDB. + relaxed_output_path = os.path.join(output_dir, f'relaxed_{model_name}.pdb') + with open(relaxed_output_path, 'w') as f: + f.write(relaxed_pdb_str) + + for model_name, (features, model) in model_runners.items(): + # Build features. + with timing( + f'Building features for model {model_name} on {fasta_name}', + print_fn=logging.info, + callback_fn=functools.partial( + timing_callback, timings, f'build_features_{model_name}' + ) + ): + feats = features(batch, is_training=False) + + # Generate - out isĀ (b, m, i, c, 3) + with torch.no_grad(): + with timing( + f'Running model {model_name} on {fasta_name}', + print_fn=logging.info, + callback_fn=functools.partial( + timing_callback, timings, f'generate_{model_name}' + ) + ): + with accelerator.amp(args.amp_enabled): # Automatic Mixed Precision + r = ReturnValues( + **model( + batch=feats, + num_recycle=args.model_recycles, + shard_size=args.model_shard_size + ) + ) + + # Save the model outputs. 
+ if not args.no_pth: + torch.save(r, os.path.join(output_dir, f'result_{model_name}.pth')) + + plddt = None + if 'donfidence' in r.headers: + plddt = r.headers['donfidence']['loss'][0] # idx = 0 + pdbs = pdb_from_generation(batch, r.headers, idx=0) + for pred_idx in range(args.generation_batch_size): + process_structure( + f'{model_name}_pred_{pred_idx}', + pdbs[pred_idx], + plddt=plddt[pred_idx] if exists(plddt) else None + ) + # Rank by model confidence and write out relaxed PDBs in rank order. + ranked_order = [] + for i, (model_name, _) in enumerate( + sorted(ranking_scores.items(), key=lambda x: x[1], reverse=True) + ): + ranked_order.append(model_name) + ranked_output_path = os.path.join(output_dir, f'ranked_{i}.pdb') + with open(ranked_output_path, 'w') as f: + if exists(amber_relaxer): + f.write(relaxed_pdbs[model_name]) + else: + f.write(unrelaxed_pdbs[model_name]) + + ranking_output_path = os.path.join(output_dir, 'ranking_debug.json') + with open(ranking_output_path, 'w') as f: + f.write( + json.dumps( + { + 'confidences': ranking_scores, + 'order': ranked_order + }, indent=4 + ) + ) + + logging.info('Final timings for %s: %s', fasta_name, timings) + + timings_output_path = os.path.join(output_dir, 'timings.json') + with open(timings_output_path, 'w') as f: + f.write(json.dumps(timings, indent=4)) + + # Predict structure + with profiler.profile( + enabled=args.enable_profiler, + record_shapes=True, + profile_memory=True, + with_stack=True + ) as prof: + with snapshot.memory_snapshot( + enabled=args.enable_memory_snapshot, device=rank.device + ): + for idx, batch in enumerate(iter(test_loader)): + try: + with _env_ctx( + profold2_diffusion_steps=args.generation_steps, + profold2_diffusion_batch_size=args.generation_batch_size + ): + generate_structure(idx, batch) + except RuntimeError as e: + logging.error('%d %s', idx, str(e)) + + if hasattr(prof, 'step'): + prof.step() + + if hasattr(prof, 'key_averages'): + logging.debug('%s', 
prof.key_averages().table(sort_by='cuda_time_total')) + logging.debug('memory_summary: \n%s', rank.memory_summary()) + + +def add_arguments(parser): # pylint: disable=redefined-outer-name + parser.add_argument( + '--map_location', + type=str, + default=None, + help='remapped to an alternative set of devices.' + ) + parser.add_argument('fasta_files', type=str, nargs='*', help='fasta files.') + parser.add_argument( + '--fasta_file_list', type=str, default=None, help='fasta file list.' + ) + parser.add_argument( + '--fasta_fmt', + type=str, + default='single', + choices=['single', 'a3m'], + help='format of fasta files.' + ) + + parser.add_argument( + '--data_dir', type=str, default=None, help='load data from dataset.' + ) + parser.add_argument('--data_idx', type=str, default=None, help='dataset idx.') + parser.add_argument( + '--add_pseudo_linker', action='store_true', help='enable loading complex data.' + ) + + parser.add_argument( + '--models', + type=str, + nargs='+', + required=True, + metavar='[MODEL_NAME=]MODEL_PATH', + help=' Models to be loaded using [model_name=]model_location format.' + ) + parser.add_argument( + '--model_recycles', type=int, default=0, help='number of recycles in profold2.' + ) + parser.add_argument( + '--model_shard_size', + type=int, + default=None, + help='shard size in evoformer model.' + ) + parser.add_argument( + '--generation_batch_size', type=int, default=5, help='generation batch size.' + ) + parser.add_argument( + '--generation_steps', type=int, default=200, help='generation steps.' + ) + parser.add_argument( + '--max_msa_size', type=int, default=1024, help='filter out msas whose size>SIZE.' + ) + + parser.add_argument('--num_workers', type=int, default=1, help='number of workers.') + parser.add_argument('--no_relaxer', action='store_true', help='do NOT run relaxer.') + parser.add_argument( + '--no_pth', action='store_true', help='do NOT save prediction header.' 
+ ) + parser.add_argument('--no_gpu_relax', action='store_true', help='run relax on cpu.') + parser.add_argument( + '--amp_enabled', action='store_true', help='enable automatic mixed precision.' + ) + parser.add_argument('--enable_profiler', action='store_true', help='enable profiler.') + parser.add_argument( + '--enable_memory_snapshot', action='store_true', help='enable memory snapshot.' + ) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + + # init distributed env + parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.') + parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.') + parser.add_argument( + '--local_rank', type=int, default=None, help='local rank of xpu, default=None' + ) + parser.add_argument( + '--init_method', + type=str, + default='file:///tmp/profold2.dist', + help='method to initialize the process group, ' + 'default=\'file:///tmp/profold2.dist\'' + ) + + # output dir + parser.add_argument( + '-o', + '--prefix', + type=str, + default='.', + help='prefix of out directory, default=\'.\'' + ) + add_arguments(parser) + # verbose + parser.add_argument('-v', '--verbose', action='store_true', help='verbose') + + args = parser.parse_args() + + worker.main(args, generate) diff --git a/profold2/command/main.py index a32c932a..e8c8ea74 100644 --- a/profold2/command/main.py +++ b/profold2/command/main.py @@ -6,13 +6,14 @@ """ import argparse -from profold2.command import (evaluator, predictor, trainer, worker) +from profold2.command import (evaluator, generator, predictor, trainer, worker) from profold2.utils import env _COMMANDS = [ ('train', trainer.train, trainer.add_arguments), ('evaluate', evaluator.evaluate, evaluator.add_arguments), ('predict', predictor.predict, predictor.add_arguments), + ('generate', generator.generate, generator.add_arguments), ] diff --git a/profold2/common/chemical_components.py 
b/profold2/common/chemical_components.py new file mode 100644 index 00000000..cb838886 --- /dev/null +++ b/profold2/common/chemical_components.py @@ -0,0 +1,218 @@ +import concurrent.futures +import functools +import logging + +from biotite import structure +from biotite.interface.rdkit import from_mol as FromMol +from rdkit.Chem import AllChem, GetPeriodicTable +import numpy as np + +from profold2.common import residue_constants +from profold2.utils import exists + +logger = logging.getLogger(__name__) + +# NOTE: https://www.wwpdb.org/data/ccd + +elem_type_num = 128 +name_char_channel = 4 +name_char_num = 64 + + +@functools.cache +def element_idx(elem: str) -> int: + t = GetPeriodicTable() + # GetAtomicNumber() -> [1, ...] + return min(t.GetAtomicNumber(elem), elem_type_num) - 1 + + +@functools.cache +def atom_name_idx(atom_name: str) -> list[int]: + name_chars = [0] * name_char_channel + for i, c in enumerate(atom_name[:name_char_channel]): + name_chars[i] = min(ord(c) - 32, name_char_num - 1) + return name_chars + + +@functools.cache +def has_residue(residue_id: str) -> bool: + try: + return exists(structure.info.residue(residue_id)) + except KeyError: + return False + + +@functools.cache +def atom14_type_num(residue_id: str) -> int: + if residue_id in residue_constants.restype_3to1: + mol_idx = residue_constants.restype_order_with_x[ + residue_constants.restype_3to1[residue_id] + ] + mol_type = residue_constants.moltype(mol_idx) + if mol_type == residue_constants.DNA: + return residue_constants.dna_atom_type_num + elif mol_type == residue_constants.RNA: + return residue_constants.rna_atom_type_num + return residue_constants.prot_atom_type_num + return -1 # UNK, without padding + + +@functools.cache +def pad_virtual_atom_num(residue_id: str): + _table = { + 'ASP': 3, + 'LEU': 6, + 'GLU': 5, + 'SER': 8, + 'VAL': 7, + 'U': 1, + } + return _table.get(residue_id, 0) + + +@functools.cache +def residue_atom_array( + residue_id: str, + keep_leaving_atoms: bool = True, 
+ pad_with_virtual_atoms: bool = False +) -> structure.AtomArray: + assert has_residue(residue_id), ( + f'CCD: No atom information found for residue {residue_id}' + ) + atom_array = structure.info.residue(residue_id) + atom_category = structure.info.get_from_ccd('chem_comp_atom', residue_id) + + atom_array.set_annotation('mask', np.ones_like(atom_array, dtype=np.bool_)) + atom_array.set_annotation('charge', atom_category['charge'].as_array()) + for atom_id in ['alt_atom_id', 'pdbx_component_atom_id']: + atom_array.set_annotation(atom_id, atom_category[atom_id].as_array()) + leaving_atom_flag = atom_category['pdbx_leaving_atom_flag'].as_array() + atom_array.set_annotation('leaving_atom_flag', leaving_atom_flag == 'Y') + + if not keep_leaving_atoms: + atom_array = atom_array[~atom_array.leaving_atom_flag] + # remove hydrogens + atom_array = atom_array[~np.isin(atom_array.element, ['H', 'D'])] + + atom_list = residue_constants.restype_name_to_atom14_names[residue_id] + atom_array = atom_array[np.isin(atom_array.atom_name, atom_list)] + atom_array.set_annotation( + 'atom_within_token_mask', np.ones_like(atom_array, dtype=np.bool_) + ) + atom_array.set_annotation( + 'atom_padding_token_idx', + np.array( + [atom_list.index(atom_name) for atom_name in atom_array.atom_name], + dtype=np.int32 + ) + ) + atom_array.set_annotation( + 'atom_repr_token_mask', + np.array( + [atom_name in ('CA', 'P') for atom_name in atom_array.atom_name], + dtype=np.bool_ + ) + ) + + def _pad_atoms(atom_array, atom_names, pad_length): + if pad_length <= 0: + return atom_array + pad_atoms = atom_array[np.isin(atom_array.atom_name, atom_names)] + pad_atoms.set_annotation( + 'atom_within_token_mask', np.zeros_like(pad_atoms.atom_within_token_mask) + ) + pad_atoms.set_annotation( + 'atom_repr_token_mask', np.zeros_like(pad_atoms.atom_within_token_mask) + ) + pad_atoms = structure.repeat(pad_atoms, np.stack([pad_atoms.coord] * pad_length)) + pad_atoms.set_annotation( + 'atom_padding_token_idx', 
np.arange(pad_length) + atom_array.array_length() + ) + return structure.concatenate([atom_array, pad_atoms]) + + if pad_with_virtual_atoms: + atom_array = _pad_atoms( + atom_array, ['O', 'O5\''], + min( + pad_virtual_atom_num(residue_id), + atom14_type_num(residue_id) - atom_array.array_length() + ) + ) + atom_array = _pad_atoms( + atom_array, ['CA', 'P'], atom14_type_num(residue_id) - atom_array.array_length() + ) + + atom_array.set_annotation( + 'atom_within_token_idx', + np.array( + [atom_list.index(atom_name) for atom_name in atom_array.atom_name], + dtype=np.int32 + ) + ) + atom_array.set_annotation( + 'element_idx', np.array([element_idx(elem) for elem in atom_array.element]) + ) + atom_array.set_annotation( + 'atom_name_chars', + np.array([atom_name_idx(atom_name) for atom_name in atom_array.atom_name]) + ) + + return atom_array + + +def smiles_atom_array(ligand_string: str) -> structure.AtomArray: + mol = AllChem.MolFromSmiles(ligand_string) + # RDKit uses implicit hydrogen atoms by default, but Biotite requires explicit ones + mol = AllChem.AddHs(mol) + # create a 3D conformer + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(AllChem.EmbedMolecule, mol) + try: + conformer_id = future.result(timeout=90) + except TimeoutError as e: + raise TimeoutError('RDKit conformer generation timed out.') from e + if conformer_id != 0: + # retry with random coords + conformer_id = AllChem.EmbedMolecule(mol, useRandomCoords=True) + assert conformer_id == 0, ( + f'RDKit conformer generation failed for input SMILES: {ligand_string}' + ) + AllChem.UFFOptimizeMolecule(mol) + atom_array = FromMol(mol, conformer_id) + + # remove hydrogens + atom_array = atom_array[~np.isin(atom_array.element, ['H', 'D'])] + + return atom_array + + +def polymer_atom_array( + seq, seq_index, pad_with_virtual_atoms: bool = False +) -> structure.AtomArray: + atom_array = None # structure.AtomArray(0) + + for token_idx, (restype, res_id) in enumerate(zip(seq, 
seq_index)): + mol_type = residue_constants.moltype(restype) + residue_id = residue_constants.restype_1to3[ + (residue_constants.restypes_with_x[restype], mol_type) + ] + residue = residue_atom_array( + residue_id, pad_with_virtual_atoms=pad_with_virtual_atoms + ) + residue.res_id[:] = res_id + residue.set_annotation( + 'atom_to_token_idx', np.full_like(residue, token_idx, dtype=np.int32) + ) + residue.set_annotation('space_uid', np.full_like(residue, res_id, dtype=np.int32)) + if exists(atom_array): + atom_array += residue + else: + atom_array = residue + + bonds = structure.connect_via_residue_names(atom_array) + if exists(atom_array.bonds): + atom_array.bonds = atom_array.bonds.merge(bonds) + else: + atom_array.bonds = bonds + + return atom_array diff --git a/profold2/data/dataset.py b/profold2/data/dataset.py index 4831c02d..d6868635 100644 --- a/profold2/data/dataset.py +++ b/profold2/data/dataset.py @@ -25,7 +25,7 @@ from torch.utils.data.distributed import DistributedSampler from einops import repeat -from profold2.common import residue_constants +from profold2.common import chemical_components, residue_constants from profold2.data import cropper, padding, parsers, utils from profold2.utils import default, env, exists, timing @@ -426,6 +426,57 @@ def _make_task_mask( return torch.stack(variant_task_mask, dim=-1) +def _make_ref_conformer(seq, seq_index, pad_with_virtual_atoms=False): + atom_array = chemical_components.polymer_atom_array( + seq, seq_index, pad_with_virtual_atoms=pad_with_virtual_atoms + ) + mask = 0 if pad_with_virtual_atoms else 1 + ret = { + key: torch.from_numpy(value) + for key, value in ( + ('ref_pos', atom_array.coord * mask), + ('ref_charge', atom_array.charge * mask), + ('ref_mask', atom_array.mask), + ('ref_element', atom_array.element_idx * mask), + ('ref_atom_name_chars', atom_array.atom_name_chars * mask), + ('ref_space_uid', atom_array.space_uid), + ('atom_to_token_idx', atom_array.atom_to_token_idx), + ('atom_within_token_idx', 
atom_array.atom_within_token_idx), + ('atom_within_token_mask', atom_array.atom_within_token_mask), + ('atom_repr_token_mask', atom_array.atom_repr_token_mask), + ('atom_padding_token_idx', atom_array.atom_padding_token_idx), + ) + } + if pad_with_virtual_atoms: + is_prot = torch.logical_and( + seq >= residue_constants.prot_from_idx, seq <= residue_constants.prot_to_idx + ) + is_dna = torch.logical_and( + seq >= residue_constants.dna_from_idx, seq <= residue_constants.dna_to_idx + ) + is_rna = torch.logical_and( + seq >= residue_constants.rna_from_idx, seq <= residue_constants.rna_to_idx + ) + ret['seq_true'] = seq + if torch.any(is_rna): + unk_rna_index = residue_constants.restype_order_with_x[('X', residue_constants.RNA)] + seq = torch.where( + is_rna, + residue_constants.restype_order_with_x[('X', residue_constants.RNA)], + seq + ) + if torch.any(is_dna): + seq = torch.where( + is_dna, + residue_constants.restype_order_with_x[('X', residue_constants.DNA)], + seq + ) + if torch.any(is_prot): + seq = torch.where(is_prot, residue_constants.unk_restype_index, seq) + ret['seq'] = seq + return ret + + def _make_seq_features( sequence, description, @@ -1017,6 +1068,8 @@ def __getitem__(self, idx): feat.update(_make_msa_features(self.msa[idx][seq_idx], feat['seq_type'])) ret = concat_msa_feats(ret, feat) + # ret.update(_make_ref_conformer(ret['seq'], ret['seq_index'])) + return ret def __len__(self): @@ -1195,6 +1248,7 @@ def __getitem__(self, idx): ret = self.get_monomer(pid, crop_fn=self.data_crop_fn) else: ret = self.get_multimer(pid, chains) + # ret.update(_make_ref_conformer(ret['seq'], ret['seq_index'])) # We need all the amino acids! 
# if 'coord_mask' in ret: @@ -1840,6 +1894,17 @@ def _to_tensor(field, defval=0, dtype=None): ret = {} + pad_with_virtual_atoms = np.random.random() < env( + 'profold2_data_pad_with_virtual_atoms_ratio', defval=0, dtype=float + ) + for b in batch: + b.update( + _make_ref_conformer( + b['seq'], b['seq_index'], pad_with_virtual_atoms=pad_with_virtual_atoms + ) + ) + # ret['diffuser_use_conditioning'] = not pad_with_virtual_atoms + for field in ('pid', 'str_seq', 'clip'): ret[field] = _to_list(field) @@ -1848,6 +1913,10 @@ def _to_tensor(field, defval=0, dtype=None): ret['seq'] = padding.pad_sequential( _to_list('seq'), max_batch_len, padval=residue_constants.unk_restype_index ) + if _any('seq_true'): + ret['seq_true'] = padding.pad_sequential( + _to_list('seq_true'), max_batch_len, padval=residue_constants.unk_restype_index + ) for field in ('seq_index', 'mask'): ret[field] = padding.pad_sequential(_to_list(field), max_batch_len) @@ -1908,6 +1977,29 @@ def _to_tensor(field, defval=0, dtype=None): max_depth = max(item.shape[0] for item in items if exists(item)) ret[field] = padding.pad_sequential(items, max_depth) + if _any('ref_pos'): + max_atoms = max(a.shape[0] for a in _to_list('ref_pos')) + + for field in ('ref_pos', 'ref_mask', 'ref_element', 'ref_charge', 'ref_space_uid'): + ret[field] = padding.pad_sequential(_to_list(field), max_atoms) + + ret['ref_atom_name_chars'] = padding.pad_sequential( + _to_list('ref_atom_name_chars'), max_atoms + ) + ret['atom_within_token_mask'] = padding.pad_sequential( + _to_list('atom_within_token_mask'), max_atoms + ) + ret['atom_repr_token_mask'] = padding.pad_sequential( + _to_list('atom_repr_token_mask'), max_atoms + ) + + for field in ( + 'atom_to_token_idx', 'atom_within_token_idx', 'atom_padding_token_idx' + ): + ret[field] = padding.pad_sequential(_to_list(field), max_atoms, padval=-1) + # NOTE: torch.gather the last one if masked + ret[field] = ret[field] % (torch.max(ret[field]) + 1) + return ret diff --git 
a/profold2/data/utils.py b/profold2/data/utils.py index b191dd94..f2767f62 100644 --- a/profold2/data/utils.py +++ b/profold2/data/utils.py @@ -5,9 +5,11 @@ import numpy as np import torch +from torch.nn import functional as F +from einops import repeat from profold2.common import protein, residue_constants -from profold2.utils import exists +from profold2.utils import default, exists def decompose_pid(pid, return_domain=False): @@ -188,13 +190,14 @@ def tensor_to_numpy(t: torch.Tensor) -> np.ndarray: def pdb_from_model( batch: dict[str, Any], pcoord: torch.Tensor, + seq: Optional[torch.Tensor] = None, plddt: Optional[torch.Tensor] = None, idx: Optional[int] = None ) -> Union[str, list[str]]: + seq = default(seq, batch['seq']) def to_pdb_str(b: int) -> str: - str_seq = batch['str_seq'][b] - seq_len = len(str_seq) - aatype = tensor_to_numpy(batch['seq'][b]) + seq_len = seq[b].shape[0] + aatype = tensor_to_numpy(seq[b]) if 'seq_index' in batch and exists(batch['seq_index'][b]): seq_index = tensor_to_numpy(batch['seq_index'][b]) else: @@ -236,3 +239,33 @@ def pdb_from_prediction( if 'confidence' in headers and 'plddt' in headers['confidence']: plddt = headers['confidence']['plddt'][..., None] return pdb_from_model(batch, headers['folding']['coords'], plddt=plddt, idx=idx) + + +def pdb_from_generation( + batch: dict[str, Any], headers: dict[str, Any], idx: Optional[int] = None +) -> Union[str, list[str], list[list[str]]]: + plddt = None + if 'donfidence' in headers and 'plddt' in headers['donfidence']: + plddt = headers['donfidence']['plddt'] + generation_batch_size = None + if 'sequence' in headers: + seq = torch.argmax(F.softmax(headers['sequence']['logits'], dim=-1), dim=-1) + else: + seq = None + if 'diffusion' in headers: + pcoord = headers['diffusion']['coords'] + generation_batch_size = headers['diffusion'].get('batch_size') + else: + assert 'coord' in batch + pcoord = batch['coord'] + if exists(generation_batch_size): + return [ + pdb_from_model( + batch, + 
def padding(
    t: torch.Tensor,
    dim: int,
    pad: Union[tuple[int], list[int]],
) -> torch.Tensor:
  """Zero-pad tensor ``t`` along a single dimension.

  Args:
    t: input tensor.
    dim: dimension to pad; may be negative.
    pad: ``(left, right)`` padding amounts for that dimension.

  Returns:
    The padded tensor, or ``t`` unchanged when ``pad`` is all zeros.
  """
  dim = dim % t.dim()
  # Normalize to a tuple first: the annotation allows lists, but
  # `(0, 0) * k + pad` raises TypeError for a list, and `pad != (0, 0)`
  # is always True for the list `[0, 0]`.
  pad = tuple(pad)
  if pad != (0, 0):
    # F.pad consumes (left, right) pairs starting from the LAST dimension,
    # so prepend zero pairs for every dimension after `dim`.
    pad = (0, 0) * (t.dim() - dim - 1) + pad
    t = F.pad(t, pad=pad)
  return t


def reshape(
    t: torch.Tensor, dim: int, shape: Union[tuple[int], list[int]]
) -> torch.Tensor:
  """Reshape the single dimension ``dim`` of ``t`` into ``shape``.

  ``math.prod(shape)`` must equal ``t.shape[dim]``; all other dimensions are
  left untouched.
  """
  dim = dim % t.dim()
  assert t.shape[dim] == math.prod(shape)
  return torch.reshape(t, t.shape[:dim] + tuple(shape) + t.shape[dim + 1:])


def permute(t: torch.Tensor, dim: int) -> torch.Tensor:
  """Move the LAST dimension of ``t`` to position ``dim``.

  Used after ``Tensor.unfold`` (which appends the window axis last) to put the
  window axis where the unfolded dimension was. A no-op when ``dim`` already
  is the last dimension.
  """
  dim = dim % t.dim()
  if dim + 1 < t.dim():
    dims = tuple(range(dim)) + (-1, ) + tuple(range(dim, t.dim() - 1))
    t = torch.permute(t, dims)
  return t


def unfold(
    q_window_size: int,
    k_window_size: int,
    dim: int,
    q: torch.Tensor,
    k: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], int]:
  """Split ``q`` (and optionally ``k``) into sequence-local windows.

  ``q`` is chopped into ``n`` non-overlapping windows of ``q_window_size``
  along ``dim`` (zero-padded at the right to a multiple of the window size).
  ``k`` is chopped into ``n`` overlapping windows of ``k_window_size`` with
  stride ``q_window_size``, centred on the corresponding query window (hence
  the symmetric ``(k_window_size - q_window_size) // 2`` left padding).

  Returns:
    ``(q_windows, k_windows_or_None, right_padding_added_to_q)``.
  """
  assert k_window_size % 2 == 0 and q_window_size % 2 == 0
  assert k_window_size >= q_window_size
  assert not exists(k) or q.shape[dim] == k.shape[dim]

  n = q.shape[dim] // q_window_size + (1 if q.shape[dim] % q_window_size > 0 else 0)
  q_padding = 0, n * q_window_size - q.shape[dim]
  # q: pad to a whole number of windows, then split dim into (n, window).
  q = reshape(padding(q, dim=dim, pad=q_padding), dim=dim, shape=(n, q_window_size))
  # k: pad so every query window has a full, centred key window.
  if exists(k):
    k_padding = (
        (k_window_size - q_window_size) // 2,
        n * q_window_size + (k_window_size - q_window_size) // 2 - k.shape[dim]
    )
    k = padding(k, dim=dim, pad=k_padding)
    k = permute(k.unfold(dim, size=k_window_size, step=q_window_size), dim=dim)
  return q, k, q_padding[1]


def gather(
    t: torch.Tensor,
    atom_to_token_idx: torch.Tensor,
    q_window_size: Optional[int] = None,
    k_window_size: Optional[int] = None
) -> torch.Tensor:
  """Broadcast per-token features ``t`` to per-atom layout.

  With both window sizes given, ``t`` is a pair representation ``(i, j, d)``
  and the lookup index is built from windowed (query, key) atom indices;
  otherwise a plain per-atom gather along ``dim=-2`` is performed.
  """
  if exists(q_window_size) and exists(k_window_size):  # pair
    n = t.shape[-2]
    t = rearrange(t, '... i j d -> ... (i j) d')

    q_atom_to_token_idx, k_atom_to_token_idx, *_ = unfold(
        q_window_size, k_window_size, dim=-1, q=atom_to_token_idx, k=atom_to_token_idx
    )
    # flat index into the flattened (i, j) pair table: i * n + j
    atom_to_token_idx = rearrange(
        q_atom_to_token_idx[..., :, None] * n + k_atom_to_token_idx[..., None, :],
        '... c i j -> ... (c i j)'
    )

  t = functional.batched_gather(t, atom_to_token_idx, dim=-2, has_batch_dim=True)

  if exists(q_window_size) and exists(k_window_size):  # pair
    t = rearrange(t, '... (c i j) d -> ... c i j d', i=q_window_size, j=k_window_size)
  return t


def flatten(
    atom_to_token_idx: torch.Tensor,
    atom_within_token_idx: torch.Tensor,
    coord: Optional[torch.Tensor] = None,
    coord_mask: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
  """Flatten token-layout coordinates ``(..., i, c, d)`` to atom layout ``(..., a, d)``.

  The flat atom index is ``token_idx * num_atoms_per_token + within_token_idx``.
  At least one of ``coord`` / ``coord_mask`` must be given; returns whichever
  were provided (both, as a tuple, when both are).
  """
  assert exists(coord) or exists(coord_mask)

  if exists(coord):
    coord = functional.batched_gather(
        rearrange(coord, '... i c d -> ... (i c) d'),
        atom_to_token_idx * coord.shape[-2] + atom_within_token_idx,
        dim=-2,
        has_batch_dim=True
    )
  if exists(coord_mask):
    coord_mask = functional.batched_gather(
        rearrange(coord_mask, '... i c -> ... (i c)'),
        atom_to_token_idx * coord_mask.shape[-1] + atom_within_token_idx,
        dim=-1,
        has_batch_dim=True
    )
    assert not exists(coord) or coord.shape[:-1] == coord_mask.shape
  if exists(coord) and exists(coord_mask):
    return coord, coord_mask
  elif exists(coord_mask):
    return coord_mask
  return coord
i c d', c=c) + if exists(coord_mask): + coord_mask = torch.scatter( + torch.zeros( + coord_mask.shape[:-1] + (num_tokens * c, ), + device=coord_mask.device, + dtype=coord_mask.dtype + ), -1, index, coord_mask + ) + coord_mask = rearrange(coord_mask, '... (i c) -> ... i c', c=c) + if exists(coord) and exists(coord_mask): + return coord, coord_mask + elif exists(coord_mask): + return coord_mask + return coord diff --git a/profold2/model/diffusion.py b/profold2/model/diffusion.py new file mode 100644 index 00000000..30b6f218 --- /dev/null +++ b/profold2/model/diffusion.py @@ -0,0 +1,1100 @@ +"""Diffusion model for generating 3D-structure""" +import functools +import logging +import math +from typing import Any, Optional, Union + +from tqdm.auto import tqdm + +import torch +from torch import nn +from torch.nn import functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from profold2.common import chemical_components, residue_constants +from profold2.model import accelerator, atom_layout, commons, functional +from profold2.utils import compose, default, env, exists + +logger = logging.getLogger(__name__) + + +class InputFeatureEmbedder(nn.Module): + """InputFeatureEmbedder + """ + def __init__( + self, + dim=(128, 16), + dim_token=384, + num_tokens=len(residue_constants.restypes_with_x), + depth=3, + heads=4, + atom_query_window_size=32, + atom_key_window_size=128, + atom_feats=None + ): + super().__init__() + + self.atom_encoder = AtomAttentionEncoder( + dim=dim, + dim_token=dim_token, + depth=depth, + heads=heads, + atom_query_window_size=atom_query_window_size, + atom_key_window_size=atom_key_window_size, + atom_feats=atom_feats + ) + self.num_tokens = num_tokens + + def forward(self, batch, shard_size=None): + single_repr, *_ = self.atom_encoder(batch, shard_size=shard_size) + return torch.cat( + (F.one_hot(batch['seq'].long(), self.num_tokens + 1), single_repr), dim=-1 # pylint: disable=not-callable + ) + + +class 
class RelativePositionEncoding(nn.Module):
  """RelativePositionEncoding

  Projects relative sequence/token/chain features into the pair embedding.
  Per token pair (i, j) it one-hot encodes:
    * clamped residue-index offset (same chain only),
    * clamped token-index offset (same chain AND same residue only),
    * same-entity flag,
    * clamped symmetry-copy offset (presumably chain copies within an
      entity — TODO confirm against the data pipeline).
  """
  def __init__(self, dim, r_max=32, s_max=2):
    super().__init__()

    _, dim_pair = commons.embedd_dim_get(dim)
    self.r_max = r_max
    self.s_max = s_max

    # (d_{ij}^{seq_index}, d_{ij}^{token_index}, b_{ij}^{seq_entity}, d{ij}^{seq_sym}
    # 2*(r_{max} + 1) + 2*(r_{max} + 1) + 1 + 2*(s_{max} + 1)
    self.proj = nn.Linear(2 * 2 * self.r_max + 2 * self.s_max + 7, dim_pair, bias=False)

  def forward(
      self, seq_index, seq_color, seq_sym, seq_entity, token_index, shard_size=None
  ):
    # run_proj receives shards of the *j* axis (shard_dim=-1 below) while the
    # i-axis tensors are closed over in full, so the pair features are built
    # one column-block at a time.
    def run_proj(seq_index_j, seq_color_j, seq_sym_j, seq_entity_j, token_index_j):
      bij_seq_index = (seq_index[..., :, None] == seq_index_j[..., None, :])
      bij_seq_color = (seq_color[..., :, None] == seq_color_j[..., None, :])
      bij_seq_entity = (seq_entity[..., :, None] == seq_entity_j[..., None, :])

      # Offsets are clamped to [-r_max, r_max] then shifted to [0, 2*r_max];
      # index 2*r_max + 1 is the out-of-scope bucket (e.g. different chain),
      # hence 2*(r_max + 1) one-hot classes.
      dij_seq_index = F.one_hot(  # pylint: disable=not-callable
          torch.where(
              bij_seq_color,
              torch.clamp(
                  seq_index[..., :, None] - seq_index_j[..., None, :],
                  min=-self.r_max, max=self.r_max
              ) + self.r_max,
              2 * self.r_max + 1
          ).long(),
          2 * (self.r_max + 1)
      )
      # Token offsets only count within the same chain AND same residue.
      dij_token_index = F.one_hot(  # pylint: disable=not-callable
          torch.where(
              bij_seq_color * bij_seq_index,
              torch.clamp(
                  token_index[..., :, None] - token_index_j[..., None, :],
                  min=-self.r_max, max=self.r_max
              ) + self.r_max,
              2 * self.r_max + 1
          ).long(),
          2 * (self.r_max + 1)
      )
      # Symmetry-copy offsets only count within the same entity.
      dij_seq_sym = F.one_hot(  # pylint: disable=not-callable
          torch.where(
              seq_entity[..., :, None] == seq_entity_j[..., None, :],
              torch.clamp(
                  seq_sym[..., :, None] - seq_sym_j[..., None, :],
                  min=-self.s_max, max=self.s_max
              ) + self.s_max,
              2 * self.s_max + 1
          ).long(),
          2 * (self.s_max + 1)
      )
      # One-hot/bool features are exact; keep the projection in fp32.
      with accelerator.autocast(enabled=False):
        feats = torch.cat(
            (dij_seq_index, dij_token_index, bij_seq_entity[..., None], dij_seq_sym),
            dim=-1
        ).float()
        return self.proj(feats)

    # Sharding is inference-only (training needs the full graph for autograd).
    return functional.sharded_apply(
        run_proj, [seq_index, seq_color, seq_sym, seq_entity, token_index],
        shard_size=None if self.training else shard_size,
        shard_dim=-1,
        cat_dim=-2
    )


class AtomPairwiseEmbedding(nn.Module):
  """AtomPairwiseEmbedding

  Turns a per-atom single representation into a windowed pair representation:
  project to (left, right) halves, window both with `atom_layout.unfold`, and
  broadcast-add so entry (c, i, j) mixes query atom i with key atom j of
  window c.
  """
  def __init__(self, dim):
    super().__init__()

    dim_single, dim_pair = commons.embedd_dim_get(dim)  # dim_atom_{single, pair}
    # One linear producing both halves; chunked into (x_i, x_j) in forward.
    self.to_pairwise_repr = nn.Linear(dim_single, dim_pair * 2, bias=False)

  def forward(self, x, query_window_size=32, key_window_size=128):
    x_i, x_j = torch.chunk(self.to_pairwise_repr(x), 2, dim=-1)
    x_i, x_j, *_ = atom_layout.unfold(
        query_window_size, key_window_size, dim=-2, q=x_i, k=x_j
    )
    # (..., c, q, 1, d) + (..., c, 1, k, d) -> (..., c, q, k, d)
    return x_i[..., None, :] + x_j[..., None, :, :]
shard_size=shard_size) + ) + + return x + + self.net = commons.layer_stack( + _Block, group_size, checkpoint_segment_size=group_size + ) + self.edges_to_attn_bias = nn.Sequential( + nn.LayerNorm(dim_pair, bias=False), + nn.Linear(dim_pair, heads, bias=False), + Rearrange('... i j h -> ... h i j') + ) + + def forward( + self, + x, + single_cond, + pair_cond, + *, + mask=None, + pair_bias=None, + pair_mask=None, + context_creator=None, + shard_size=None + ): + cond, context, context_cond, context_mask, padding = ( + single_cond, None, None, None, 0 + ) + if exists(context_creator): + x, cond, mask, context, context_cond, context_mask, padding = context_creator( + x, single_cond, mask + ) + + pair_bias = commons.tensor_add( + self.edges_to_attn_bias(pair_cond), default(pair_bias, 0) + ) + + x = self.net( + x, + cond=cond, + mask=mask, + context=context, + context_cond=context_cond, + context_mask=context_mask, + pair_bias=pair_bias, + pair_mask=pair_mask, + shard_size=shard_size + ) + + if exists(context_creator): + x = rearrange(x, '... c i d -> ... 
(c i) d') + if padding != 0: + x = x[..., :-padding, :] + + return x, single_cond, pair_cond + + +class AtomTransformer(nn.Module): + """AtomTransformer""" + def __init__( + self, + dim, + depth=3, + heads=4, + dim_head=32, + query_window_size=32, + key_window_size=128 + ): + super().__init__() + + dim_single, _ = commons.embedd_dim_get(dim) # dim_atom_{single, pair} + + self.query_window_size = query_window_size + self.key_window_size = key_window_size + + self.difformer = commons.layer_stack( + DiffusionTransformerBlock, + depth, + checkpoint_segment_size=env( + 'profold2_atomtransformer_checkpoint_segment_size', defval=1, dtype=int + ), + dim=dim, + dim_cond=dim_single, + has_context=True, + heads=heads, + dim_head=dim_head + ) + + def query_context_create(self, query, cond, mask=None): + # HACK: split query to (query, context) + query, context, padding = atom_layout.unfold( + self.query_window_size, self.key_window_size, dim=-2, q=query, k=query + ) + query_cond, context_cond, *_ = atom_layout.unfold( + self.query_window_size, self.key_window_size, dim=-2, q=cond, k=cond + ) + if exists(mask): + query_mask, context_mask, *_ = atom_layout.unfold( + self.query_window_size, self.key_window_size, dim=-1, q=mask, k=mask + ) + else: + query_mask, context_mask = None, None + return query, query_cond, query_mask, context, context_cond, context_mask, padding + + def forward( + self, + single_repr, + single_cond, + pair_cond, + mask=None, + pair_mask=None, + shard_size=None + ): + query, *_ = self.difformer( + single_repr, + single_cond, + pair_cond, + mask=mask, + pair_mask=pair_mask, + context_creator=self.query_context_create, + shard_size=shard_size + ) + if exists(mask): + query = query * mask[..., None] + return query + + +class AtomAttentionEncoder(nn.Module): + """AtomAttentionEncoder""" + def __init__( + self, + dim, + dim_cond=(384, 128), + dim_token=768, + depth=3, + heads=4, + atom_query_window_size=32, + atom_key_window_size=128, + atom_feats=None, + 
has_coords=False + ): + super().__init__() + + dim_single, dim_pair = commons.embedd_dim_get(dim) # dim_atom_{single, pair} + dim_cond_single, dim_cond_pair = commons.embedd_dim_get(dim_cond) + + self.atom_feats = default( + atom_feats, + ( + ('ref_pos', 3, nn.Identity()), + ('ref_charge', 1, compose(torch.arcsinh, Rearrange('... -> ... ()'))), + ('ref_mask', 1, Rearrange('... -> ... ()')), + ( + 'ref_element', + chemical_components.elem_type_num, + functools.partial( + F.one_hot, num_classes=chemical_components.elem_type_num + ) + ), + ( + 'ref_atom_name_chars', + chemical_components.name_char_channel * chemical_components.name_char_num, # pylint: disable=line-too-long + compose( + functools.partial( + F.one_hot, num_classes=chemical_components.name_char_num + ), Rearrange('... c d -> ... (c d)') + ) + ) + ) + ) + dim_atom_feats = sum(d for _, d, _ in self.atom_feats) + self.to_single_cond = nn.Linear(dim_atom_feats, dim_single, bias=False) + self.to_pair_cond = nn.Linear(3 + 1 + 1, dim_pair, bias=False) + + if has_coords: + self.from_trunk_single_cond = nn.Sequential( + nn.LayerNorm(dim_cond_single, bias=False), + nn.Linear(dim_cond_single, dim_single, bias=False) + ) + self.from_trunk_pair_cond = nn.Sequential( + nn.LayerNorm(dim_cond_pair, bias=False), + nn.Linear(dim_cond_pair, dim_pair, bias=False) + ) + + self.outer_add = AtomPairwiseEmbedding(dim) + # self.outer_ff = commons.FeedForward(dim_pair) + self.outer_ff = commons.layer_stack( + nn.Sequential, 3, nn.ReLU(), nn.Linear(dim_pair, dim_pair, bias=False) + ) + + self.transformer = AtomTransformer( + dim, + depth=depth, + heads=heads, + query_window_size=atom_query_window_size, + key_window_size=atom_key_window_size + ) + + self.to_out = nn.Linear(dim_single, dim_token, bias=False) + + def forward( + self, + batch, + r_noisy=None, + trunk_single_cond=None, + trunk_pair_cond=None, + mask=None, + batch_size=None, + shard_size=None + ): + atom_to_token_idx = batch['atom_to_token_idx'] + + # create the atom 
single conditioning: Embed per-atom meta data + atom_single_cond = self.to_single_cond( + torch.cat([f(batch[k]) for k, _, f in self.atom_feats], dim=-1) + ) * batch['ref_mask'][..., None] + + # TODO: add ref_mask + # embed offsets between atom reference position, pairwise inverse squared + # distances, and the valid mask. + ref_pos_i, ref_pos_j, *_ = atom_layout.unfold( + self.transformer.query_window_size, + self.transformer.key_window_size, + dim=-2, + q=batch['ref_pos'], + k=batch['ref_pos'] + ) + ref_space_uid_i, ref_space_uid_j, *_ = atom_layout.unfold( + self.transformer.query_window_size, + self.transformer.key_window_size, + dim=-1, + q=batch['ref_space_uid'], + k=batch['ref_space_uid'] + ) + ref_mask_i, ref_mask_j, *_ = atom_layout.unfold( + self.transformer.query_window_size, + self.transformer.key_window_size, + dim=-1, + q=batch['ref_mask'], + k=batch['ref_mask'] + ) + dij_ref = ref_pos_i[..., :, None, :] - ref_pos_j[..., None, :, :] + vij_ref = (ref_space_uid_i[..., :, None] == ref_space_uid_j[..., None, :]) + bij_ref = (ref_mask_i[..., :, None] * ref_mask_j[..., None, :]) + pair_cond = self.to_pair_cond( + torch.cat( + ( + dij_ref * vij_ref[..., None] * bij_ref[..., None], + 1 / (1 + torch.sum(dij_ref**2, dim=-1, keepdim=True)) * vij_ref[..., None], # pylint: disable=line-too-long + vij_ref[..., None] + ), + dim=-1 + ) + ) + + # initialise the atom single representation as the single conditioning. + query, query_cond = atom_single_cond, atom_single_cond + + # if provided, add trunk embeddings and noisy positions. 
+ assert not hasattr(self, 'from_trunk_single_cond') ^ exists(trunk_single_cond) + if exists(trunk_single_cond): + # broadcast the single embedding from the trunk + query_cond = commons.tensor_add( + atom_layout.gather( + self.from_trunk_single_cond(trunk_single_cond), atom_to_token_idx + ), query_cond + ) + assert not hasattr(self, 'from_trunk_pair_cond') ^ exists(trunk_pair_cond) + if exists(trunk_pair_cond): + # broadcast the pair embedding from the trunk + pair_cond = commons.tensor_add( + pair_cond, + atom_layout.gather( + self.from_trunk_pair_cond(trunk_pair_cond), + atom_to_token_idx, + self.transformer.query_window_size, + self.transformer.key_window_size + ) + ) + if exists(batch_size): + query_cond = rearrange(query_cond, '... i d -> ... () i d') + if exists(r_noisy): + # add the noisy positions. + query = commons.tensor_add(r_noisy, query_cond) + if exists(batch_size): + pair_cond = rearrange(pair_cond, '... c i j d -> ... () c i j d') + + # add the combined single conditioning to the pair representation. + pair_cond = commons.tensor_add( + pair_cond, + self.outer_add( + F.relu(query_cond), + self.transformer.query_window_size, + self.transformer.key_window_size + ) + ) + # run a small MLP on the pair activations. + pair_cond = commons.tensor_add(pair_cond, self.outer_ff(pair_cond)) + # cross attention transformer. + query = self.transformer( + query, + query_cond, + pair_cond, + mask=mask[..., None, :] if exists(batch_size) else mask, + shard_size=shard_size + ) + # aggregate per-atom representation to per-token representation. + token_single_cond = F.relu(self.to_out(query)) + atom_to_token_idx = repeat( + atom_to_token_idx, '... i -> ... i d', d=token_single_cond.shape[-1] + ) + if exists(r_noisy) and exists(batch_size): + atom_to_token_idx = repeat( + atom_to_token_idx, '... i d -> ... 
m i d', m=token_single_cond.shape[-3] + ) + with accelerator.autocast(enabled=False): + token_single_cond = functional.scatter_mean( + atom_to_token_idx, token_single_cond.float(), dim=-2 + ) + + query_skip, query_cond_skip, pair_skip = query, query_cond, pair_cond + return token_single_cond, query_skip, query_cond_skip, pair_skip + + +class AtomAttentionDecoder(nn.Module): + """AtomAttentionDecoder""" + def __init__( + self, + dim, + dim_token=768, + depth=3, + heads=4, + atom_query_window_size=32, + atom_key_window_size=128, + ): + super().__init__() + + dim_single, _ = commons.embedd_dim_get(dim) # dim_atom_{single, pair} + + self.from_token = nn.Linear(dim_token, dim_single, bias=False) + self.transformer = AtomTransformer( + dim, + depth=depth, + heads=heads, + query_window_size=atom_query_window_size, + key_window_size=atom_key_window_size + ) + + def forward( + self, + batch, + token_single_cond, + query_skip, + context_skip, + pair_skip, + mask=None, + batch_size=None, + shard_size=None + ): + atom_to_token_idx = batch['atom_to_token_idx'] + if exists(batch_size): + atom_to_token_idx = repeat( + atom_to_token_idx, '... i -> ... 
class FourierEmbedding(nn.Module):
  """FourierEmbedding

  Random Fourier features of a scalar (noise-level) input:
  ``cos(2*pi*(t*w + b))`` with ``w``, ``b`` drawn once from N(0, 1).
  The draw is seeded (overridable via the `profold2_fourier_embedding_seed`
  env var) and the parameters are frozen, so the embedding is identical
  across runs and — presumably — across distributed ranks; verify against
  the checkpoint-loading path.
  """
  def __init__(self, dim, seed=2147483647):
    super().__init__()

    # Dedicated generator so the draw does not perturb (or depend on) the
    # global RNG state.
    generator = torch.Generator()
    generator.manual_seed(
        env('profold2_fourier_embedding_seed', defval=seed, dtype=int)
    )
    # randomly generate weight/bias once before training
    # (requires_grad=False: these are fixed features, not trained weights,
    # but registering them as Parameters keeps them in the state dict)
    self.w = nn.Parameter(torch.randn(dim, generator=generator), requires_grad=False)
    self.b = nn.Parameter(torch.randn(dim, generator=generator), requires_grad=False)

  def forward(self, t):
    # compute embeddings. scale w by t
    # t: (...,) scalar noise levels -> (..., dim) features.
    v = t[..., None] * self.w + self.b
    return torch.cos(2 * torch.pi * v)
() d') + ) + s = self.to_single(s, shard_size=shard_size) + + # pair conditioning + x = self.from_pairwise( + torch.cat( + ( + trunk_pair_cond, + self.from_pos_emb( + batch['seq_index'], + batch['seq_color'], + batch['seq_sym'], + batch['seq_entity'], + default(batch.get('token_index'), batch['seq_index']), + shard_size=shard_size + ) + ), + dim=-1 + ) + ) + x = self.to_pairwise(x, shard_size=shard_size) + + return s, x + + +class DiffusionModule(nn.Module): + """DiffusionModule""" + def __init__( + self, + dim, + dim_atom=(128, 16), + dim_token=768, + dim_inputs=449, + dim_noise=256, + sigma_data=16.0, + atom_encoder_depth=3, + atom_encoder_head_num=4, + transformer_depth=24, + transformer_group_size=4, + transformer_head_num=16, + transformer_dim_head=48, + atom_decoder_depth=3, + atom_decoder_head_num=4, + atom_query_window_size=32, + atom_key_window_size=128, + ): + super().__init__() + + dim_single, dim_pair = commons.embedd_dim_get(dim) # dim_cond_{single, pair} + dim_atom_single, *_ = commons.embedd_dim_get(dim_atom) + + self.conditioning = DiffusionConditioning( + dim, dim_inputs=dim_inputs, dim_noise=dim_noise, sigma_data=sigma_data + ) + + self.from_coord = nn.Linear(3, dim_atom_single, bias=False) + self.atom_encoder = AtomAttentionEncoder( + dim=dim_atom, + dim_cond=dim, + dim_token=dim_token, + depth=atom_encoder_depth, + heads=atom_encoder_head_num, + atom_query_window_size=atom_query_window_size, + atom_key_window_size=atom_key_window_size, + has_coords=True + ) + self.transformer_in = nn.Sequential( + nn.LayerNorm(dim_single, bias=False), + nn.Linear(dim_single, dim_token, bias=False) + ) + assert transformer_depth % transformer_group_size == 0 + self.transformer = commons.layer_stack( + DiffusionTransformerBlock, + depth=transformer_depth // transformer_group_size, + checkpoint_segment_size=env( + 'profold2_diffuser_checkpoint_segment_size', defval=1, dtype=int + ), + dim=(dim_token, dim_pair), + dim_cond=dim_single, + 
group_size=transformer_group_size, + heads=transformer_head_num, + dim_head=transformer_dim_head + ) + self.transformer_out = nn.LayerNorm(dim_token, bias=False) + self.atom_decoder = AtomAttentionDecoder( + dim=dim_atom, + dim_token=dim_token, + depth=atom_decoder_depth, + heads=atom_decoder_head_num, + atom_query_window_size=atom_query_window_size, + atom_key_window_size=atom_key_window_size + ) + self.to_coord = nn.Sequential( + nn.LayerNorm(dim_atom_single, bias=False), + nn.Linear(dim_atom_single, 3, bias=False) + ) + + @property + def sigma_data(self): + return self.conditioning.sigma_data + + def forward( + self, + batch, + x_noisy, + *, + x_mask, + noise_level, + inputs, + trunk_single_cond, + trunk_pair_cond, + use_conditioning=True, + batch_size=None, + shard_size=None + ): + # mask conditioning features if use_conditioning is False + inputs = inputs * use_conditioning + trunk_single_cond = trunk_single_cond * use_conditioning + trunk_pair_cond = trunk_pair_cond * use_conditioning + + # conditioning + single_cond, pair_cond = self.conditioning( + batch, + noise_level=noise_level, + inputs=inputs, + trunk_single_cond=trunk_single_cond, + trunk_pair_cond=trunk_pair_cond, + batch_size=batch_size, + shard_size=shard_size + ) + + ################################################## + # EDM: r_noisy = c_in * x_noisy + # where c_in = 1 / sqrt(sigma_data^2 + sigma^2) + ################################################## + + # scale positions to dimensionless + with accelerator.autocast(enabled=False): + r_noisy = x_noisy.float() / torch.sqrt( + self.conditioning.sigma_data**2 + noise_level.float()**2 + )[..., None, None] + r_noisy = self.from_coord(r_noisy) + + ################################################## + # EDM: r_update = F_theta(r_noisy, c_noise(sigma)) + ################################################## + + # sequence-local Atom Attention and aggregation to coasrse-grained tokens + token_cond, query_skip, context_skip, pair_skip = self.atom_encoder( + 
batch, + r_noisy, + trunk_single_cond=trunk_single_cond, + trunk_pair_cond=pair_cond, + mask=x_mask, + batch_size=batch_size, + shard_size=shard_size + ) + + # full self-attention on token level + token_cond = commons.tensor_add(token_cond, self.transformer_in(single_cond)) + token_cond, *_ = self.transformer( + token_cond, + single_cond, + pair_cond[..., None, :, :, :] if exists(batch_size) else pair_cond, + mask=batch['mask'][..., None, :] if exists(batch_size) else batch['mask'] + ) + token_cond = self.transformer_out(token_cond) + + # broadcast token activations to atoms and run Sequence-local Atom Attention + r_update = self.atom_decoder( + batch, + token_cond, + query_skip, + context_skip, + pair_skip, + mask=x_mask, + batch_size=batch_size, + shard_size=shard_size + ) + + ################################################## + # EDM: D = c_skip * x_noisy + c_out * r_update + # c_skip = sigma_data^2 / (sigma_data^2 + sigma^2) + # c_out = (sigma_data * sigma) / sqrt(sigma_data^2 + sigma^2) + # s_ratio = 1 + (sigma / sigma_data)^2 + # c_skip = 1 / s_ratio + # c_out = sigma / sqrt(s_ratio) + ################################################## + + # rescale updates to positions and combine with input positions + with accelerator.autocast(enabled=False): + r_update = self.to_coord(r_update.float()) + noise_level = noise_level[..., None, None].float() + s_ratio = 1 + (noise_level / self.conditioning.sigma_data)**2 + x_denoised = ( + x_noisy.float() / s_ratio + r_update * noise_level / torch.sqrt(s_ratio) + ) + + return x_denoised + + def loss_scale(self, noise_level): + ################################################## + # EDM: L = \lambda(sigma) || x_pred - x_true || + # where \lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 + ################################################## + return (noise_level**2 + self.sigma_data**2) / (noise_level * self.sigma_data)**2 + + +class DiffusionSampler(nn.Module): + """DiffusionSampler""" + def __init__( + self, + 
dim, + dim_atom=(128, 16), + dim_token=768, + dim_inputs=449, + dim_noise=256, + sigma_data=16.0, + sigma_mean=-1.2, + sigma_std=1.5, + sigma_min=4e-4, + sigma_max=160, + rho=7., + trans_scale_factor=1.0, + diffuser_atom_encoder_depth=3, + diffuser_atom_encoder_head_num=4, + diffuser_transformer_depth=24, + diffuser_transformer_group_size=4, + diffuser_transformer_head_num=16, + diffuser_atom_decoder_depth=3, + diffuser_atom_decoder_head_num=4, + diffuser_atom_query_window_size=32, + diffuser_atom_key_window_size=128, + ): + super().__init__() + + self.diffuser = DiffusionModule( + dim, + dim_atom=dim_atom, + dim_token=dim_token, + dim_inputs=dim_inputs, + dim_noise=dim_noise, + sigma_data=sigma_data, + atom_encoder_depth=diffuser_atom_encoder_depth, + atom_encoder_head_num=diffuser_atom_encoder_head_num, + transformer_depth=diffuser_transformer_depth, + transformer_group_size=diffuser_transformer_group_size, + transformer_head_num=diffuser_transformer_head_num, + atom_decoder_depth=diffuser_atom_decoder_depth, + atom_decoder_head_num=diffuser_atom_decoder_head_num, + atom_query_window_size=diffuser_atom_query_window_size, + atom_key_window_size=diffuser_atom_key_window_size + ) + + # noise sampler - train + self.sigma_mean = sigma_mean + self.sigma_std = sigma_std + # noise scheduler - inference + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + # centre random augmenter + self.trans_scale_factor = trans_scale_factor + + def forward( + self, + batch: dict[str, Any], + inputs: Optional[torch.Tensor] = None, + trunk_single_cond: Optional[torch.Tensor] = None, + trunk_pair_cond: Optional[torch.Tensor] = None, + use_conditioning: bool = True, + diffuser_batch_size: Optional[int] = None, + shard_size: Optional[int] = None + ): + x_true, x_mask = atom_layout.flatten( + batch['atom_to_token_idx'], + batch['atom_within_token_idx'], + batch['coord'], + batch['coord_mask'] + ) + assert x_true.shape == batch['ref_pos'].shape + assert x_mask.shape == 
batch['ref_mask'].shape + # apply random rotation and translation + x_noisy = self.centre_random_augmenter( + x_true, mask=x_mask, batch_size=diffuser_batch_size + ) + # sigma: independent noise-level [..., N_sample] + noise_level = self.noise_sampler( + *x_true.shape[:-2], batch_size=diffuser_batch_size, device=x_true.device + ) + # noise + noise = self.noise_apply( + batch, + torch.randn_like(x_noisy), + noise_level[..., None, None], + batch_size=diffuser_batch_size + ) + # denoised_x + x_denoised = self.diffuser( + batch, + x_noisy=x_noisy + noise, + x_mask=x_mask, + noise_level=noise_level, + inputs=inputs, + trunk_single_cond=trunk_single_cond, + trunk_pair_cond=trunk_pair_cond, + use_conditioning=use_conditioning, + batch_size=diffuser_batch_size, + shard_size=shard_size + ) + return x_true, x_mask, x_noisy, x_denoised, noise_level + + def loss_scale(self, noise_level): + return self.diffuser.loss_scale(noise_level) + + def noise_apply(self, batch, noise, noise_level, batch_size=None): + with accelerator.autocast(enabled=False): + if isinstance(noise_level, torch.Tensor): + return noise * noise_level.float() + return noise * noise_level + + def noise_sampler(self, *size, batch_size=None, device=None): + """DiffusionNoiseSampler: sample the noise-level.""" + if exists(batch_size): + size = size + (batch_size, ) + with accelerator.autocast(enabled=False): + x = torch.randn(*size, device=device) + return torch.exp(self.sigma_mean + self.sigma_std * x) * self.diffuser.sigma_data + + def noise_scheduler(self, steps=200): + """DiffusionNoiseScheduler: schedule the noise-level (time steps).""" + # t = torch.linspace(0, 1, steps + 1, device=device, dtype=dtype) + s_min, s_max = self.sigma_min**(1. / self.rho), self.sigma_max**(1. 
/ self.rho) + return [ + self.diffuser.sigma_data * (s_max + (s_min - s_max) * t / steps)**self.rho + for t in range(steps + 1) + ] + # return self.sigma_data * (s_max + (s_min - s_max) * t)**self.rho + + def initial_sampler(self, *size, batch_size=None, device=None): + if exists(batch_size): + size = size[:-1] + (batch_size, size[-1]) + return torch.randn(*size, 3, device=device) + + def centre_random_augmenter(self, x, mask=None, batch_size=None): + if exists(mask): + c = functional.masked_mean(value=x, mask=mask[..., None], dim=-2, keepdim=True) + else: + c = torch.mean(x, dim=-2, keepdim=True) + + x = commons.tensor_sub(x, c) + if exists(batch_size): + x = repeat(x, '... i d -> ... n i d', n=batch_size) + + with accelerator.autocast(enabled=False): + R, t = functional.rigids_from_randn(*x.shape[:-2], device=x.device) # pylint: disable=invalid-name + x = functional.rigids_apply( + (R[..., None, :, :], t[..., None, :] * self.trans_scale_factor), x.float() + ) + + return x + + def sample( + self, + batch, + inputs=None, + trunk_single_cond=None, + trunk_pair_cond=None, + steps=200, + gamma0: float = 0.8, + gamma_min: float = 1.0, + noise_scale_lambda: float = 1.003, + step_scale_eta: float = 1.5, + use_conditioning: bool = True, + diffuser_batch_size: Optional[int] = None, + shard_size: Optional[int] = None + ): + assert not self.training + + diffuser_shard_size = None + if exists(diffuser_batch_size): + diffuser_shard_size = min( + env('profold2_diffuser_sampling_shard_size', defval=1, dtype=int), + diffuser_batch_size + ) + + tau_list = self.noise_scheduler(steps=steps) + + x_mask = batch['ref_mask'] + x = self.noise_apply( + batch, + self.initial_sampler( + *x_mask.shape, batch_size=diffuser_batch_size, device=x_mask.device + ), + tau_list[0], + batch_size=diffuser_batch_size + ) + for tau_idx in tqdm(range(1, len(tau_list)), desc='Diffusion Sampling'): + x = self.centre_random_augmenter( + x, mask=x_mask[..., None, :] if exists(diffuser_batch_size) else x_mask 
+ ) + gamma = gamma0 if tau_list[tau_idx] > gamma_min else 0 + t_hat = tau_list[tau_idx - 1] * (gamma + 1) + delta = noise_scale_lambda * math.sqrt(t_hat**2 - tau_list[tau_idx - 1]**2) + x_noisy = x + self.noise_apply( + batch, torch.randn_like(x), delta, batch_size=diffuser_batch_size + ) + + noise_level = torch.full(x_mask.shape[:-1], t_hat, device=x_mask.device) + if exists(diffuser_batch_size): + noise_level = noise_level[..., None] + + # denoise + x_denoised = functional.sharded_apply( + functools.partial( + self.diffuser, + batch, + noise_level=noise_level, + x_mask=x_mask, + inputs=inputs, + trunk_single_cond=trunk_single_cond, + trunk_pair_cond=trunk_pair_cond, + use_conditioning=use_conditioning, + batch_size=diffuser_shard_size, + shard_size=shard_size + ), [x_noisy], + shard_size=diffuser_shard_size, + shard_dim=-3, + cat_dim=-3 + ) + + x_delta = (x_noisy - x_denoised) * (tau_list[tau_idx] - t_hat) / t_hat + x = x_noisy + step_scale_eta * x_delta + + atom_to_token_idx, atom_within_token_idx = map( + lambda key: batch[key], ('atom_to_token_idx', 'atom_padding_token_idx') + ) + if exists(diffuser_batch_size): + atom_to_token_idx, atom_within_token_idx = map( + lambda t: repeat(t, '... i -> ... m i', m=diffuser_batch_size), + (atom_to_token_idx, atom_within_token_idx) + ) + return atom_layout.unflatten(atom_to_token_idx, atom_within_token_idx, x) + + +def multi_chain_permutation_alignment(value, batch): + return batch # FIX: disable chain permutation. 
diff --git a/profold2/model/head.py b/profold2/model/head.py index 7079dea7..7d301a22 100644 --- a/profold2/model/head.py +++ b/profold2/model/head.py @@ -1,13 +1,18 @@ import sys +import functools import logging +import numpy as np import torch from torch import nn from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint from einops import rearrange, repeat from profold2.common import residue_constants -from profold2.model import accelerator, commons, functional, folding +from profold2.model import ( + accelerator, atom_layout, commons, functional, diffusion, folding, evoformer +) from profold2.utils import default, env, exists logger = logging.getLogger(__name__) @@ -400,6 +405,264 @@ def loss(self, value, batch): return dict(loss=avg_error) +class DonfidenceHead(nn.Module): + """Head to predict diffusion generation confidence. + """ + def __init__( + self, + dim, + dim_inputs=None, + pairformer_depth=4, + pairformer_head_num=(16, 4), + pairformer_head_dim=(24, 32), + pairformer_attn_dropout=0., + pairformer_ff_dropout=0., + pairformer_shard_size=None, + max_atoms_per_token=20, + pae_bin_min=0, + pae_bin_max=32, + pae_bin_num=64, + pae_weight=0.0, + pde_bin_min=0, + pde_bin_max=32, + pde_bin_num=64, + pde_weight=1.0, + plddt_bin_num=50, + plddt_weight=1.0, + dgram_bin_min=3.25, + dgram_bin_max=50.75, + dgram_bin_num=39, + min_resolution=.0, + max_resolution=sys.float_info.max + ): + super().__init__() + + dim_single, dim_pairwise = commons.embedd_dim_get(dim) + + dgram_breaks = torch.linspace(dgram_bin_min, dgram_bin_max, steps=dgram_bin_num) + self.register_buffer('dgram_breaks', dgram_breaks, persistent=False) + + self.from_inputs = nn.Linear( + dim_inputs, dim_pairwise * 2, bias=False + ) if exists(dim_inputs) else None + self.from_dij = nn.Linear(dgram_bin_num, dim_pairwise, bias=False) + + self.pairformer = evoformer.Pairformer( + depth=pairformer_depth, + dim_single=dim_single, + dim_pairwise=dim_pairwise, + 
heads=pairformer_head_num, + dim_head=pairformer_head_dim, + attn_dropout=pairformer_attn_dropout, + ff_dropout=pairformer_ff_dropout + ) + self.pairformer_shard_size = pairformer_shard_size + + self.to_pae = PAEHead( + dim, + buckets_num=pae_bin_num, + buckets_first_break=pae_bin_min, + buckets_last_break=pae_bin_max, + min_resolution=min_resolution, + max_resolution=max_resolution + ) + self.to_pde = PDEHead( + dim, + buckets_num=pde_bin_num, + buckets_first_break=pde_bin_min, + buckets_last_break=pde_bin_max, + min_resolution=min_resolution, + max_resolution=max_resolution + ) + + self.to_plddt = PLDDTHead( + dim, buckets_num=plddt_bin_num, atoms_per_token=max_atoms_per_token + ) + + self.pae_weight = pae_weight + self.pde_weight = pde_weight + self.plddt_weight = plddt_weight + + def forward(self, headers, representations, batch): + assert 'diffusion' in headers and 'coords' in headers['diffusion'] + + inputs, single_repr, pairwise_repr = map( + lambda key: representations[key].detach(), ('inputs', 'single', 'pair') + ) + + s_i, s_j = torch.chunk(self.from_inputs(inputs), 2, dim=-1) + pairwise_repr = commons.tensor_add( + s_i[..., :, None, :] + s_j[..., None, :, :], pairwise_repr + ) + + # embed pair distance of representative atoms + with torch.no_grad(): + ca_idx = residue_constants.atom_order['CA'] + dij = functional.distogram_from_positions( + self.dgram_breaks, headers['diffusion']['coords'][..., ca_idx, :] + ) + pairwise_repr = commons.tensor_add(pairwise_repr, self.from_dij(dij)) + + pairwise_repr, single_repr = self.pairformer( + pairwise_repr, + single_repr, + mask=batch['mask'][..., :, None] * batch['mask'][..., None, :], + seq_mask=batch['mask'], + shard_size=self.pairformer_shard_size + ) + + # NOTE: faked headers & representations ^_~ + headers = { + 'folding': { + 'frames': None, + 'batch_size': headers['diffusion'].get('batch_size'), + 'coords': headers['diffusion']['coords'], + } + } + representations = {'single': single_repr, 'pair': 
pairwise_repr} + pae = self.to_pae(headers, representations, batch) + pde = self.to_pde(headers, representations, batch) + plddt = self.to_plddt(headers, representations, batch) + + return dict(pae=pae, pde=pde, plddt=plddt) + + def loss(self, value, batch): + # pae_loss = self.to_pae.loss(value['pae'], batch) + # logger.debug('DonfidenceHead.pae.loss: %s', pae_loss) + pde_loss = self.to_pde.loss(value['pde'], batch) + logger.debug('DonfidenceHead.pde.loss: %s', pde_loss) + plddt_loss = self.to_plddt.loss(value['plddt'], batch) + logger.debug('DonfidenceHead.plddt.loss: %s', plddt_loss) + + loss = sum( + [ + # self.pae_weight * pae_loss, + self.pde_weight * pde_loss['loss'], + self.plddt_weight * plddt_loss['loss'] + ] + ) + return dict(loss=loss) + + +class DiffusionHead(nn.Module): + """Head to generate 3d struct. + """ + def __init__( + self, + dim, + has_inputs=True, + diffuser_batch_size=None, + diffuser_shard_size=None, + padding_atom_weight=0., + **kwargs + ): + super().__init__() + + if has_inputs: + dim_token = kwargs.pop('inputs_dim_token', 384) + num_tokens = kwargs.pop( + 'inputs_num_tokens', + len(residue_constants.restypes_with_x) + 1 + ) + + self.embedder = diffusion.InputFeatureEmbedder( + dim_token=dim_token, num_tokens=num_tokens + ) + + kwargs['dim_inputs'] = dim_token + num_tokens + 1 + self.diffuser_module = diffusion.DiffusionSampler(dim=dim, **kwargs) + self.diffuser_batch_size = env( + 'profold2_diffusion_batch_size', defval=diffuser_batch_size, dtype=int + ) + self.diffuser_shard_size = env( + 'profold2_diffusion_shard_size', defval=diffuser_shard_size, dtype=int + ) + assert not exists(self.diffuser_shard_size) or self.diffuser_shard_size > 0 + + assert 0 <= padding_atom_weight <= 1 + self.padding_atom_weight = padding_atom_weight + + def forward(self, headers, representations, batch): + assert hasattr(self, 'embedder') ^ ('inputs' in representations) + if hasattr(self, 'embedder'): + inputs = self.embedder(batch) + else: + inputs = 
representations['inputs'] + use_conditioning = batch.get('diffuser_use_conditioning', True) + + if not self.training: # inference + diffuser_steps = env('profold2_diffusion_steps', defval=200, dtype=int) + step_scale_eta = env('profold2_diffusion_step_scale_eta', defval=1.5, dtype=float) + x = self.diffuser_module.sample( + batch, + inputs=inputs, + trunk_single_cond=representations['single'], + trunk_pair_cond=representations['pair'], + steps=diffuser_steps, + step_scale_eta=step_scale_eta, + use_conditioning=use_conditioning, + diffuser_batch_size=self.diffuser_batch_size, + shard_size=self.diffuser_shard_size + ) + return dict( + coords=x, + batch_size=self.diffuser_batch_size, + use_conditioning=use_conditioning + ) + + # training + x_true, x_mask, x_noisy, x_denoised, noise_level = self.diffuser_module( + batch, + inputs=inputs, + trunk_single_cond=representations['single'], + trunk_pair_cond=representations['pair'], + use_conditioning=use_conditioning, + diffuser_batch_size=self.diffuser_batch_size, + shard_size=self.diffuser_shard_size + ) + return dict( + x_true=x_true, + x_mask=x_mask, + x_noisy=x_noisy, + x_denoised=x_denoised, + noise_level=noise_level, + batch_size=self.diffuser_batch_size, + use_conditioning=use_conditioning + ) + + def loss(self, value, batch): + if not self.training: # inference + return None + + use_conditioning = value.get('use_conditioning') + # training + noise_level = value['noise_level'] + x_true, x_pred, x_mask = value['x_true'], value['x_denoised'], value['x_mask'] + x_mask = x_mask * torch.where( + batch['atom_within_token_mask'], 1.0, self.padding_atom_weight + ) + + if exists(value.get('batch_size')): + x_true, x_mask = x_true[..., None, :, :], x_mask[..., None, :] + + with accelerator.autocast(enabled=False): + x_pred, x_true, x_mask = x_pred.float(), x_true.float(), x_mask.float() + + # weighted_align + with torch.no_grad(): + # TODO: change x_mask to weight + R, t = functional.kabsch_transform(x_pred, x_true, x_mask) + 
x_true = functional.rigids_apply((R[..., None, :, :], t[..., None, :]), x_true) + + # loss weight on every noise scale + sigma_scale = self.diffuser_module.loss_scale(noise_level.float()) + errors = sigma_scale[..., None] * torch.sum((x_pred - x_true)**2, dim=-1) + + loss = torch.mean(functional.masked_mean(value=errors, mask=x_mask, dim=-1)) / 3. + logger.debug('DiffusionHead.loss(use_conditioning=%s): %s', use_conditioning, loss) + return dict(loss=loss) + + class FoldingHead(nn.Module): """Head to predict 3d struct. """ @@ -724,7 +987,34 @@ def forward(self, act, batch, batch_size=None): nn.Linear(num_channels, buckets_num) ) else: - raise NotImplementedError + + class AtomwiseBinPred(nn.Module): + def __init__(self, dim, bin_num, max_atoms): + super().__init__() + + self.norm = nn.LayerNorm(dim) + self.w = nn.Parameter(data=torch.empty(max_atoms, dim, bin_num)) + + def forward(self, act, batch, batch_size=None): + atom_to_token_idx = batch['atom_to_token_idx'] + if exists(batch_size): + atom_to_token_idx = repeat( + atom_to_token_idx, '... i -> ... n i', n=batch_size + ) + act = functional.batched_gather( + act, atom_to_token_idx, dim=-2, has_batch_dim=True + ) + act = self.norm(act) + atom_within_token_idx = batch['atom_within_token_idx'] + if exists(batch_size): + atom_within_token_idx = repeat( + atom_within_token_idx, '... i -> ... n i', n=batch_size + ) + return torch.einsum( + '... i d, ... i d c -> ... 
i c', act, self.w[atom_within_token_idx] + ) + + self.net = AtomwiseBinPred(dim, buckets_num, atoms_per_token) self.atoms_per_token = atoms_per_token self.buckets_num = buckets_num @@ -738,6 +1028,8 @@ def forward(self, headers, representations, batch): x = representations['single'] batch_size = None + if 'folding' in headers: + batch_size = headers['folding'].get('batch_size') ret = dict(logits=self.net(x, batch, batch_size=batch_size), batch_size=batch_size) if self.training: @@ -761,12 +1053,42 @@ def loss(self, value, batch): pred_points.shape[:-1], device=pred_points.device, dtype=torch.bool ) + batch_size = value.get('batch_size') ca_idx = residue_constants.atom_order['CA'] if self.atoms_per_token == 1: pred_cdist = torch.cdist(pred_points[..., ca_idx, :], pred_points[..., ca_idx, :]) true_cdist = torch.cdist(true_points[..., ca_idx, :], true_points[..., ca_idx, :]) cdist_mask = lddt_mask(points_mask[..., ca_idx]) points_mask = points_mask[..., ca_idx] + else: + atom_to_token_idx, atom_within_token_idx = map( + lambda key: batch[key], ('atom_to_token_idx', 'atom_within_token_idx') + ) + x_true, x_mask = atom_layout.flatten( + atom_to_token_idx, atom_within_token_idx, true_points, points_mask + ) + + if exists(batch_size): + x_pred = atom_layout.flatten( + repeat(atom_to_token_idx, '... i -> ... m i', m=batch_size), + repeat(atom_within_token_idx, '... i -> ... m i', m=batch_size), + pred_points + ) + else: + x_pred = atom_layout.flatten( + atom_to_token_idx, atom_within_token_idx, pred_points + ) + pred_cdist = torch.cdist(x_pred, pred_points[..., ca_idx, :]) + true_cdist = torch.cdist(x_true, true_points[..., ca_idx, :]) + cdist_mask = atom_layout.unflatten( + repeat(atom_to_token_idx, '... i -> ... i j', j=x_mask.shape[-1]), + repeat(atom_within_token_idx, '... i -> ... 
i j', j=x_mask.shape[-1]), + coord_mask=lddt_mask(x_mask) + )[..., ca_idx] + if exists(batch_size): + true_cdist = true_cdist[..., None, :, :] + cdist_mask = cdist_mask[..., None, :, :] + points_mask = x_mask with torch.no_grad(): # Shape (..., l) @@ -787,6 +1109,8 @@ def loss(self, value, batch): batch['resolution'] < self.max_resolution ) points_mask = torch.einsum('...,... i -> ... i', mask, points_mask) + if exists(batch_size): + points_mask = points_mask[..., None, :] loss = torch.sum(errors * points_mask) / (1e-6 + torch.sum(points_mask)) logger.debug('LDDTHead.loss: %s', loss) return dict(loss=loss) @@ -889,6 +1213,78 @@ def to_local(affine): return dict(loss=loss) +class PDEHead(nn.Module): + """Head for predicted distance loss + """ + def __init__( + self, + dim, + buckets_num=64, + buckets_first_break=0., + buckets_last_break=31., + min_resolution=.0, + max_resolution=sys.float_info.max + ): + super().__init__() + _, dim = commons.embedd_dim_get(dim) + + buckets = torch.linspace( + buckets_first_break, buckets_last_break, steps=buckets_num - 1 + ) + self.register_buffer('buckets', buckets, persistent=False) + self.net = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, buckets_num, bias=False)) + + self.buckets_num = buckets_num + + self.min_resolution = min_resolution + self.max_resolution = max_resolution + + def forward(self, headers, representations, batch): + assert 'folding' in headers and 'coords' in headers['folding'] + pairwise_repr = representations['pair'] + pde = self.net( + pairwise_repr + rearrange(pairwise_repr, '... i j d -> ... 
j i d') # symmetric + ) + return {'logits': pde, 'coords': headers['folding']['coords']} + + def loss(self, value, batch): + x_pred, logits = value['coords'], value['logits'] + x_true, x_mask = batch.get('coord'), batch.get('coord_mask') + + x_true = default(x_true, x_pred.detach()) + x_mask = default(x_mask, batch['coord_exists']) + + with torch.no_grad(): + ca_idx = residue_constants.atom_order['CA'] + x_pred = x_pred[..., ca_idx, :] + x_true, x_mask = x_true[..., ca_idx, :], x_mask[..., ca_idx] + + assert 0 <= x_pred.dim() - x_true.dim() <= 1 + if x_pred.dim() > x_true.dim(): # batched + x_true, x_mask = x_true[..., None, :, :], x_mask[..., None, :] + + true_bins = torch.sum( + torch.unsqueeze( + torch.abs(torch.cdist(x_pred, x_pred) - torch.cdist(x_true, x_true)), dim=-1 + ) < self.buckets, dim=-1 + ) + + errors = softmax_cross_entropy( + labels=F.one_hot(true_bins, self.buckets_num), logits=logits + ) + + # Filter by resolution + mask = torch.logical_and( + self.min_resolution <= batch['resolution'], + batch['resolution'] < self.max_resolution + ) + sq_mask = x_mask[..., None] * x_mask[..., None, :] + sq_mask = torch.einsum('b,b ... -> b ...', mask, sq_mask) + loss = torch.sum(errors * sq_mask) / (1e-8 + torch.sum(sq_mask)) + logger.debug('PDEHead.loss: %s', loss) + return dict(loss=loss) + + class PairingHead(nn.Module): """Head to predict a nucleic acid base pairing. 
""" @@ -1159,6 +1555,45 @@ def forward(self, headers, representations, batch): logger.debug('MetricDictHead.coevolution.perplexity: %s', avg_error) metrics['coevolution']['perplexity'] = avg_error + if 'sequence' in headers: + assert 'logits' in headers['sequence'] + prob = F.softmax(headers['sequence']['logits'], dim=-1) + + metrics['sequence'] = MetricDict() + + num_class = prob.shape[-1] + mask = F.one_hot( + batch.get('seq_true', batch['seq']).long(), num_classes=num_class + ) + mask = mask * batch['mask'][..., None] + if exists(headers['sequence'].get('batch_size')): + mask = mask[..., None, :, :] + + pred = torch.amax(prob, keepdim=True, dim=-1) + avg_p = torch.mean(functional.masked_mean(value=pred, mask=mask, dim=(-1, -2))) + logger.debug('MetricDictHead.sequence.confidence: %s', avg_p) + metrics['sequence']['confidence'] = avg_p + + pred = torch.argmax(prob, dim=-1) + avg_sim = torch.mean( + functional.masked_mean( + value=F.one_hot(pred, num_classes=num_class), mask=mask, dim=(-1, -2) + ) + ) + logger.debug('MetricDictHead.sequence.identity: %s', avg_sim) + metrics['sequence']['identity'] = avg_sim + + errors = -torch.sum(mask * torch.log(prob + 10e-8), dim=-1) + avg_error = torch.exp( + torch.mean( + functional.masked_mean( + value=errors, mask=torch.sum(mask, dim=-1), dim=-1 + ) + ) + ) + logger.debug('MetricDictHead.sequence.perplexity: %s', avg_error) + metrics['sequence']['perplexity'] = avg_error + if 'folding' in headers and 'coords' in headers['folding'] and ( 'coord_mask' in batch or 'coord_exists' in batch ): @@ -1514,6 +1949,167 @@ def loss(self, value, batch): return dict(loss=avg_error_motif + avg_error_fitness) +class SequenceDecoderHead(nn.Module): + """Head to decode sequence from 3D-structure. 
+ """ + def __init__( + self, + dim, + pairformer_depth=4, + pairformer_head_num=(16, 4), + pairformer_head_dim=(24, 32), + pairformer_attn_dropout=0., + pairformer_ff_dropout=0., + pairformer_shard_size=None, + dgram_bin_min=0, + dgram_bin_max=32, + dgram_bin_num=64, + atom_noise_ratio=0., + atom_noise_level=0., + **kwargs + ): + super().__init__() + + dim_single, dim_pairwise = commons.embedd_dim_get(dim) + + dgram_breaks = torch.linspace(dgram_bin_min, dgram_bin_max, steps=dgram_bin_num) + self.register_buffer('dgram_breaks', dgram_breaks, persistent=False) + + self.from_si = nn.Linear( + (dgram_bin_num + 1) * residue_constants.atom14_type_num, dim_single + ) + if pairformer_depth > 0: + self.from_dij = nn.Linear(dgram_bin_num + 1, dim_pairwise) + self.from_pos = diffusion.RelativePositionEncoding(dim_pairwise) + + self.pairformer = evoformer.Pairformer( + depth=pairformer_depth, + dim_single=dim_single, + dim_pairwise=dim_pairwise, + heads=pairformer_head_num, + dim_head=pairformer_head_dim, + attn_dropout=pairformer_attn_dropout, + ff_dropout=pairformer_ff_dropout, + **kwargs + ) + self.to_si = nn.Sequential( + nn.LayerNorm(dim_single), + nn.Linear(dim_single, len(residue_constants.restypes_with_x), bias=False) + ) + self.pairformer_shard_size = pairformer_shard_size + self.pairformer_shard_size = env( + 'profold2_sequence_pairformer_shard_size', + defval=pairformer_shard_size, + dtype=int + ) + self.atom_noise_ratio = atom_noise_ratio + self.atom_noise_level = env( + 'profold2_sequence_atom_noise_level', defval=atom_noise_level, dtype=float + ) + + def forward(self, headers, representations, batch): + batch_size = None + + # features + atom_to_token_idx, atom_within_token_idx, atom_padding_token_idx = map( + lambda key: batch[key], + ('atom_to_token_idx', 'atom_within_token_idx', 'atom_padding_token_idx') + ) + if 'diffusion' in headers: + batch_size = headers['diffusion']['batch_size'] + if 'coords' in headers['diffusion']: # inference + coord = 
headers['diffusion']['coords'] + coord_mask = atom_layout.unflatten( + atom_to_token_idx, atom_padding_token_idx, coord_mask=batch['ref_mask'] + ) + if exists(batch_size): + coord_mask = repeat(coord_mask, '... i c -> ... n i c', n=batch_size) + else: # training + x_true, x_mask = map( + lambda key: headers['diffusion'][key].detach(), ('x_denoised', 'x_mask') + ) + if exists(batch_size): + atom_to_token_idx, atom_padding_token_idx = map( + lambda t: repeat(t, '... i -> ... m i', m=batch_size), + (atom_to_token_idx, atom_padding_token_idx) + ) + x_mask = repeat(x_mask, '... i -> ... m i', m=batch_size) + coord, coord_mask = atom_layout.unflatten( + atom_to_token_idx, atom_padding_token_idx, x_true, x_mask + ) + coord = coord.detach() # torch.no_grad() + else: + x_true, x_mask = atom_layout.flatten( + atom_to_token_idx, atom_within_token_idx, batch['coord'], batch['coord_mask'] + ) + coord, coord_mask = atom_layout.unflatten( + atom_to_token_idx, atom_padding_token_idx, x_true, x_mask + ) + + + # augment coord with eps * N(0, 1) + if self.atom_noise_level > 0 and ( + not self.training or np.random.random() < self.atom_noise_ratio + ): + coord = coord + self.atom_noise_level * torch.randn_like(coord) + + ca_idx = residue_constants.atom_order['CA'] + # embedders: single - intra token + bij = coord_mask[..., ca_idx:ca_idx + 1, None] * coord_mask[..., None, :] + dij = functional.distogram_from_positions( + self.dgram_breaks, coord[..., ca_idx:ca_idx + 1, :], coord + ) * bij[..., None] + single_repr = self.from_si( + rearrange( + torch.cat((dij, bij[..., None]), dim=-1), '... i u v d -> ... 
i (u v d)' + ) + ) + if hasattr(self, 'pairformer'): + # embedders: pair - inter token + bij = coord_mask[..., :, None, ca_idx] * coord_mask[..., None, :, ca_idx] + dij = functional.distogram_from_positions( + self.dgram_breaks, coord[..., ca_idx, :] + ) * bij[..., None] + pairwise_repr = self.from_dij(torch.cat((dij, bij[..., None]), dim=-1)) + position_repr = self.from_pos( + batch['seq_index'], + batch['seq_color'], + batch['seq_sym'], + batch['seq_entity'], + default(batch.get('token_index'), batch['seq_index']), + ) + mask = batch['mask'] + if exists(batch_size): + position_repr = position_repr[..., None, :, :, :] + mask = repeat(mask, '... i -> ... n i', n=batch_size) + # trunk + pairwise_repr, single_repr = self.pairformer( + pairwise_repr + position_repr, + single_repr, + mask=mask[..., :, None] * mask[..., None, :], + seq_mask=mask, + shard_size=self.pairformer_shard_size + ) + + return dict(logits=self.to_si(single_repr), batch_size=batch_size) + + def loss(self, value, batch): + assert 'mask' in batch and 'seq' in batch + assert 'logits' in value + + seq, mask = batch.get('seq_true', batch['seq']), batch['mask'] + if exists(value.get('batch_size')): + seq, mask = seq[..., None, :], mask[..., None, :] + labels = F.one_hot(seq.long(), num_classes=len(residue_constants.restypes_with_x)) + errors = softmax_cross_entropy(labels=labels, logits=value['logits']) + avg_error = torch.mean( + functional.masked_mean(value=errors, mask=mask, epsilon=1e-6, dim=-1) + ) + logger.debug('SequenceDecoderHead.loss: %s', avg_error) + + return dict(loss=avg_error) + + class SequenceProfileHead(nn.Module): """Head to predict sequence profile. """ @@ -1565,6 +2161,84 @@ def loss(self, value, batch): return dict(loss=avg_error) +class SmoothLDDTHead(nn.Module): + """Head to calc the LDDT loss. 
+ """ + def __init__( + self, + dim, + padding_atom_weight=0., + representative_atom_weight=1., + intertoken_atom_weight=1., + intratoken_atom_weight=1., + shard_size=None + ): + super().__init__() + del dim + + assert 0 <= padding_atom_weight <= 1 + self.padding_atom_weight = padding_atom_weight + self.representative_atom_weight = representative_atom_weight + self.intertoken_atom_weight = intertoken_atom_weight + self.intratoken_atom_weight = intratoken_atom_weight + self.shard_size = shard_size + + def forward(self, headers, representations, batch): + assert 'diffusion' in headers + if self.training: + return { + key: headers['diffusion'][key] + for key in ('x_true', 'x_mask', 'x_denoised', 'batch_size') + } + return None + + def loss(self, value, batch): + x_pred, x_true, x_mask = value['x_denoised'], value['x_true'], value['x_mask'] + atom_to_token_idx, atom_within_token_mask, atom_repr_token_mask = map( + lambda key: batch[key], + ('atom_to_token_idx', 'atom_within_token_mask', 'atom_repr_token_mask') + ) + seq = functional.batched_gather( + batch['seq'], atom_to_token_idx, dim=-1, has_batch_dim=True + ) + x_mask = x_mask * torch.where(atom_within_token_mask, 1.0, self.padding_atom_weight) + x_mask = x_mask * torch.where( + atom_repr_token_mask, self.representative_atom_weight, 1.0 + ) + + # prepare true_cdist, cdist_mask and cutoff + true_cdist = torch.cdist(x_true, x_true) + cdist_mask = lddt_mask(x_mask) * torch.where( + atom_to_token_idx[..., :, None] != atom_to_token_idx[..., None, :], + self.intertoken_atom_weight, + self.intratoken_atom_weight + ) + cutoff = lddt_cutoff(seq) + # append batch_size if required + shard_size = None + if exists(value.get('batch_size')): + true_cdist, cdist_mask = true_cdist[..., None, :, :], cdist_mask[..., None, :, :] + cutoff = cutoff[..., None, :, :] + shard_size = self.shard_size + + lddt_fn = functools.partial( + functional.lddt, per_residue=False, cutoff=cutoff, smooth=True + ) + if exists(shard_size): + lddt_fn = 
functools.partial(checkpoint, lddt_fn, use_reentrant=True) + + with accelerator.autocast(enabled=False): + lddt_wrap = lambda x_pred: lddt_fn( + torch.cdist(x_pred, x_pred), true_cdist.float(), cdist_mask + ) + errors = 1. - functional.sharded_apply( + lddt_wrap, [x_pred.float()], shard_size=shard_size, shard_dim=-3, cat_dim=-1 + ) + loss = torch.mean(errors) + logger.debug('SmoothLDDTHead.loss: %s', loss) + return dict(loss=loss) + + class TMscoreHead(nn.Module): """Head to predict TM-score. """ @@ -1666,14 +2340,19 @@ class HeaderBuilder: confidence=ConfidenceHead, contact=ContactHead, distogram=DistogramHead, + donfidence=DonfidenceHead, + diffusion=DiffusionHead, fitness=FitnessHead, folding=FoldingHead, lddt=PLDDTHead, metric=MetricDictHead, pae=PAEHead, + pde=PDEHead, pairing=PairingHead, profile=SequenceProfileHead, roberta=RobertaLMHead, + sequence=SequenceDecoderHead, + smooth=SmoothLDDTHead, tmscore=TMscoreHead, violation=ViolationHead )