Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 40 additions & 44 deletions profold2/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,10 @@
from torch.nn import functional as F
from torch.utils.data import WeightedRandomSampler
from torch.utils.data.distributed import DistributedSampler
from einops import rearrange, repeat
from einops import repeat

from profold2.common import residue_constants
from profold2.data import cropper
from profold2.data.parsers import parse_fasta
from profold2.data.padding import pad_sequential, pad_rectangle
from profold2.data.utils import (
compose_pid, decompose_pid, fix_residue_id, fix_atom_id, fix_coord, parse_seq_index,
parse_seq_type, seq_type_dict, seq_index_join, seq_index_split, str_seq_index
)
from profold2.data import cropper, padding, parsers, utils
from profold2.utils import default, env, exists, timing

logger = logging.getLogger(__file__)
Expand Down Expand Up @@ -274,8 +268,8 @@ def _make_label_features(descriptions, attr_dict, task_num=1, defval=1.0):
def _make_label(desc):
label, label_mask = None, None

pid, chain, _ = decompose_pid(desc.split()[0], return_domain=True)
for k in (pid, compose_pid(pid, chain)):
pid, chain, _ = utils.decompose_pid(desc.split()[0], return_domain=True)
for k in (pid, utils.compose_pid(pid, chain)):
if k in attr_dict:
if 'label' in attr_dict[k]:
label = attr_dict[k]['label']
Expand Down Expand Up @@ -312,8 +306,8 @@ def _make_label(desc):


def _make_var_pid(desc):
var_pid, c, _ = decompose_pid(desc.split()[0], return_domain=True)
var_pid = compose_pid(var_pid, c)
var_pid, c, _ = utils.decompose_pid(desc.split()[0], return_domain=True)
var_pid = utils.compose_pid(var_pid, c)
return var_pid


Expand Down Expand Up @@ -442,7 +436,7 @@ def _make_seq_features(
max_seq_len=None
):
residue_index = torch.arange(len(sequence), dtype=torch.int)
residue_index = parse_seq_index(description, sequence, residue_index)
residue_index = utils.parse_seq_index(description, sequence, residue_index)

sequence = sequence[:max_seq_len]
residue_index = residue_index[:max_seq_len]
Expand Down Expand Up @@ -583,7 +577,7 @@ def _make_pdb_features(
if icode and icode != ' ':
continue

residue_id = fix_residue_id(aa.get_resname())
residue_id = utils.fix_residue_id(aa.get_resname())

if not exists(int_resseq_start):
int_resseq_start = int_resseq
Expand Down Expand Up @@ -611,7 +605,7 @@ def _make_pdb_features(
] # pylint: disable=line-too-long
for atom in aa.get_atoms():
try:
atom14idx = res_atom14_list.index(fix_atom_id(residue_id, atom.id))
atom14idx = res_atom14_list.index(utils.fix_atom_id(residue_id, atom.id))
coord = np.asarray(atom.get_coord())
if np.any(np.isnan(coord)):
continue
Expand All @@ -621,7 +615,9 @@ def _make_pdb_features(
label_mask[atom14idx] = True
except ValueError as e:
logger.debug(e)
labels, bfactors = fix_coord(residue_id, labels, label_mask, bfactors=bfactors)
labels, bfactors = utils.fix_coord(
residue_id, labels, label_mask, bfactors=bfactors
)
coord_list.append(labels)
coord_mask_list.append(label_mask)
bfactor_list.append(bfactors)
Expand Down Expand Up @@ -782,7 +778,7 @@ def _protein_crop_fn(protein, clip):
def _protein_crop_fmt(clip):
assert exists(clip), clip
if 'd' in clip:
clip['d'] = str_seq_index(torch.as_tensor(clip['d']))
clip['d'] = utils.str_seq_index(torch.as_tensor(clip['d']))
return clip


Expand Down Expand Up @@ -1003,7 +999,7 @@ def __getitem__(self, idx):

if 'pid' not in ret:
ret['pid'] = desc.split()[0]
seq_type = seq_type_dict[parse_seq_type(desc)]
seq_type = utils.seq_type_dict[utils.parse_seq_type(desc)]
feat = _make_seq_features(
input_sequence, desc, seq_color=seq_idx + 1, seq_type=seq_type
)
Expand Down Expand Up @@ -1250,10 +1246,10 @@ def _setattr(self, **kwargs_new):

def get_monomer(self, pid, seq_color=1, seq_entity=None, seq_sym=None, crop_fn=None):
# CATH format pid
pid, chain, domains = decompose_pid(pid, return_domain=True)
pid, chain, domains = utils.decompose_pid(pid, return_domain=True)
if exists(domains):
domains = list(seq_index_split(domains))
pid = compose_pid(pid, chain)
domains = list(utils.seq_index_split(domains))
pid = utils.compose_pid(pid, chain)

pkey = self.mapping[pid] if pid in self.mapping else pid
with FileSystem(self.data_dir) as fs:
Expand Down Expand Up @@ -1303,7 +1299,7 @@ def get_monomer(self, pid, seq_color=1, seq_entity=None, seq_sym=None, crop_fn=N

if exists(domains):
# CATH update pid
ret['pid'] = compose_pid(pid, None, seq_index_join(domains))
ret['pid'] = utils.compose_pid(pid, None, utils.seq_index_join(domains))

if ret.get('msa_idx', 0) != 0:
ret = _msa_as_seq(ret, ret['msa_idx'], str_key='msa')
Expand All @@ -1320,7 +1316,7 @@ def get_monomer(self, pid, seq_color=1, seq_entity=None, seq_sym=None, crop_fn=N

def _multimer_yield_cluster(self, var_pid):
for var_pid in set(self.cluster.get(var_pid, []) + [var_pid]):
var_pid, c = decompose_pid(var_pid)
var_pid, c = utils.decompose_pid(var_pid)
if self.has_chain(var_pid, c):
yield var_pid, c

Expand All @@ -1331,7 +1327,7 @@ def _multimer_build_chain_list(self, protein_id, var_list):
def _is_aligned(k, chain_list):
if k != protein_id and k in self.attr_list:
for c, *_ in chain_list:
x = self.get_chain_list(compose_pid(k, c))
x = self.get_chain_list(utils.compose_pid(k, c))
# FIX: some chains may be removed from chain.idx
if exists(x) and len(set(x) & set(chain_list)) == len(x):
return True
Expand All @@ -1351,7 +1347,7 @@ def _is_aligned(k, chain_list):
def get_multimer(self, protein_id, chains):
assert len(chains) > 1

pid, selected_chain = decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking
pid, selected_chain = utils.decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking
assert selected_chain in chains
# task_definition
if pid in self.attr_list and 'task_def' in self.attr_list[pid]:
Expand All @@ -1374,7 +1370,7 @@ def get_multimer(self, protein_id, chains):
var_as_seq_prob=0,
attr_list=None
):
feat = self.get_monomer(compose_pid(pid, chain), seq_color=idx + 1)
feat = self.get_monomer(utils.compose_pid(pid, chain), seq_color=idx + 1)
# fix seq_entity
assert 'str_seq' in feat
if feat['str_seq'] not in seq_entity_map:
Expand Down Expand Up @@ -1418,7 +1414,7 @@ def get_multimer(self, protein_id, chains):
feat['del_var'][var_idx]
)

ret['pid'] = compose_pid(pid, ','.join(chains))
ret['pid'] = utils.compose_pid(pid, ','.join(chains))
if self.feat_flags & FEAT_VAR and 'var' in ret and ret['var']:
assert 'length' in ret
var_dict = ret['var']
Expand Down Expand Up @@ -1473,7 +1469,7 @@ def get_multimer(self, protein_id, chains):
for idx, chain in enumerate(chains):
n = ret['length'][idx]
for c, *_ in chain_list:
cluster_id = compose_pid(var_pid, c)
cluster_id = utils.compose_pid(var_pid, c)
if cluster_id not in var_dict:
cluster_id = self.mapping.get(cluster_id, cluster_id)
hit_seq, hit_mask, target_chain, hit_str, hit_del = var_dict[cluster_id]
Expand Down Expand Up @@ -1560,7 +1556,7 @@ def get_multimer(self, protein_id, chains):
return ret

def get_chain_list(self, protein_id):
pid, chain = decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking
pid, chain = utils.decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking
if pid in self.chain_list:
chain_group = self.chain_list[pid]
# for g in chain_group:
Expand All @@ -1580,7 +1576,7 @@ def has_chain(self, pid, chain):
return False

def get_resolution(self, protein_id):
pid, _ = decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking
pid, _ = utils.decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking
# HACK: for rna or dna rebuild dataset
if pid.startswith('rna-') or pid.startswith('dna-'):
pid = pid[4:]
Expand Down Expand Up @@ -1647,7 +1643,7 @@ def get_var_features_new(
variant_path = f'{self.var_dir}/{protein_id}/msas/{protein_id}.a3m'
if fs.exists(variant_path):
with fs.open(variant_path) as f:
sequences, descriptions = parse_fasta(fs.textise(f.read()))
sequences, descriptions = parsers.parse_fasta(fs.textise(f.read()))
assert len(sequences) == len(descriptions)

if self.attr_list: # filter with attr_list, NOTE: keep the 1st one alway.
Expand Down Expand Up @@ -1710,12 +1706,12 @@ def get_structure_label_npz(
if fs.exists(fasta_file):
with fs.open(fasta_file) as f:
input_fasta_str = fs.textise(f.read())
input_seqs, input_descs = parse_fasta(input_fasta_str)
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(f'More than one input sequence found in {fasta_file}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
input_type = seq_type_dict[parse_seq_type(input_description)]
input_type = utils.seq_type_dict[utils.parse_seq_type(input_description)]

ret.update(
_make_seq_features(
Expand Down Expand Up @@ -1849,31 +1845,31 @@ def _to_tensor(field, defval=0, dtype=None):

max_batch_len = max(len(s) for s in ret['str_seq'])

ret['seq'] = pad_sequential(
ret['seq'] = padding.pad_sequential(
_to_list('seq'), max_batch_len, padval=residue_constants.unk_restype_index
)
for field in ('seq_index', 'mask'):
ret[field] = pad_sequential(_to_list(field), max_batch_len)
ret[field] = padding.pad_sequential(_to_list(field), max_batch_len)

for field in ('seq_color', 'seq_entity', 'seq_sym'):
if _any(field):
ret[field] = pad_sequential(_to_list(field), max_batch_len)
ret[field] = padding.pad_sequential(_to_list(field), max_batch_len)

if _any('resolu'):
ret['resolution'] = _to_tensor('resolu', -1.0)

if _any('sta'):
for field in ('sta', 'sta_type_mask'):
ret[field] = pad_sequential(_to_list(field), max_batch_len)
ret[field] = padding.pad_sequential(_to_list(field), max_batch_len)

feat_flags = default(feat_flags, FEAT_ALL)
if feat_flags & FEAT_PDB and _any('coord'):
# required
for field in ('coord', 'coord_mask'):
ret[field] = pad_sequential(_to_list(field), max_batch_len)
ret[field] = padding.pad_sequential(_to_list(field), max_batch_len)
# optional
if _any('coord_plddt'):
ret['coord_plddt'] = pad_sequential(
ret['coord_plddt'] = padding.pad_sequential(
_to_list('coord_plddt'), max_batch_len, padval=1.0
)

Expand All @@ -1888,29 +1884,29 @@ def _to_tensor(field, defval=0, dtype=None):
ret['str_msa'] = _to_list('str_msa')
for field in ('msa_idx', 'num_msa'):
ret[field] = _to_tensor(field, dtype=torch.int)
ret['msa'] = pad_rectangle(
ret['msa'] = padding.pad_rectangle(
_to_list('msa'),
max_batch_len,
padval=residue_constants.HHBLITS_AA_TO_ID[('-', residue_constants.PROT)]
)
for field in ('msa_mask', 'del_msa'):
ret[field] = pad_rectangle(_to_list(field), max_batch_len)
ret[field] = padding.pad_rectangle(_to_list(field), max_batch_len)

if feat_flags & FEAT_VAR and _any('variant'):
ret['variant_pid'] = _to_list('variant_pid')
for field in ('var_idx', 'num_var'):
ret[field] = _to_tensor(field, dtype=torch.int)
ret['variant'] = pad_rectangle(
ret['variant'] = padding.pad_rectangle(
_to_list('variant'),
max_batch_len,
padval=residue_constants.HHBLITS_AA_TO_ID[('-', residue_constants.PROT)]
)
for field in ('variant_mask', 'variant_task_mask'):
ret[field] = pad_rectangle(_to_list(field), max_batch_len)
ret[field] = padding.pad_rectangle(_to_list(field), max_batch_len)
for field in ('variant_label', 'variant_label_mask'):
items = _to_list(field)
max_depth = max(item.shape[0] for item in items if exists(item))
ret[field] = pad_sequential(items, max_depth)
ret[field] = padding.pad_sequential(items, max_depth)

return ret

Expand Down