From e218a68edbb8305394f23272733ac74d3a8fcaba Mon Sep 17 00:00:00 2001 From: chungongyu Date: Tue, 24 Feb 2026 20:10:14 +0800 Subject: [PATCH] refactor: code reformat --- profold2/data/dataset.py | 84 +++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 44 deletions(-) diff --git a/profold2/data/dataset.py b/profold2/data/dataset.py index ec55b449..4831c02d 100644 --- a/profold2/data/dataset.py +++ b/profold2/data/dataset.py @@ -23,16 +23,10 @@ from torch.nn import functional as F from torch.utils.data import WeightedRandomSampler from torch.utils.data.distributed import DistributedSampler -from einops import rearrange, repeat +from einops import repeat from profold2.common import residue_constants -from profold2.data import cropper -from profold2.data.parsers import parse_fasta -from profold2.data.padding import pad_sequential, pad_rectangle -from profold2.data.utils import ( - compose_pid, decompose_pid, fix_residue_id, fix_atom_id, fix_coord, parse_seq_index, - parse_seq_type, seq_type_dict, seq_index_join, seq_index_split, str_seq_index -) +from profold2.data import cropper, padding, parsers, utils from profold2.utils import default, env, exists, timing logger = logging.getLogger(__file__) @@ -274,8 +268,8 @@ def _make_label_features(descriptions, attr_dict, task_num=1, defval=1.0): def _make_label(desc): label, label_mask = None, None - pid, chain, _ = decompose_pid(desc.split()[0], return_domain=True) - for k in (pid, compose_pid(pid, chain)): + pid, chain, _ = utils.decompose_pid(desc.split()[0], return_domain=True) + for k in (pid, utils.compose_pid(pid, chain)): if k in attr_dict: if 'label' in attr_dict[k]: label = attr_dict[k]['label'] @@ -312,8 +306,8 @@ def _make_label(desc): def _make_var_pid(desc): - var_pid, c, _ = decompose_pid(desc.split()[0], return_domain=True) - var_pid = compose_pid(var_pid, c) + var_pid, c, _ = utils.decompose_pid(desc.split()[0], return_domain=True) + var_pid = utils.compose_pid(var_pid, c) return var_pid @@ -442,7 +436,7 @@ def _make_seq_features( max_seq_len=None ): residue_index = torch.arange(len(sequence), dtype=torch.int) - residue_index = parse_seq_index(description, sequence, residue_index) + residue_index = utils.parse_seq_index(description, sequence, residue_index) sequence = sequence[:max_seq_len] residue_index = residue_index[:max_seq_len] @@ -583,7 +577,7 @@ def _make_pdb_features( if icode and icode != ' ': continue - residue_id = fix_residue_id(aa.get_resname()) + residue_id = utils.fix_residue_id(aa.get_resname()) if not exists(int_resseq_start): int_resseq_start = int_resseq @@ -611,7 +605,7 @@ def _make_pdb_features( ] # pylint: disable=line-too-long for atom in aa.get_atoms(): try: - atom14idx = res_atom14_list.index(fix_atom_id(residue_id, atom.id)) + atom14idx = res_atom14_list.index(utils.fix_atom_id(residue_id, atom.id)) coord = np.asarray(atom.get_coord()) if np.any(np.isnan(coord)): continue @@ -621,7 +615,9 @@ def _make_pdb_features( label_mask[atom14idx] = True except ValueError as e: logger.debug(e) - labels, bfactors = fix_coord(residue_id, labels, label_mask, bfactors=bfactors) + labels, bfactors = utils.fix_coord( + residue_id, labels, label_mask, bfactors=bfactors + ) coord_list.append(labels) coord_mask_list.append(label_mask) bfactor_list.append(bfactors) @@ -782,7 +778,7 @@ def _protein_crop_fn(protein, clip): def _protein_crop_fmt(clip): assert exists(clip), clip if 'd' in clip: - clip['d'] = str_seq_index(torch.as_tensor(clip['d'])) + clip['d'] = utils.str_seq_index(torch.as_tensor(clip['d'])) return clip @@ -1003,7 +999,7 @@ def __getitem__(self, idx): if 'pid' not in ret: ret['pid'] = desc.split()[0] - seq_type = seq_type_dict[parse_seq_type(desc)] + seq_type = utils.seq_type_dict[utils.parse_seq_type(desc)] feat = _make_seq_features( input_sequence, desc, seq_color=seq_idx + 1, seq_type=seq_type ) @@ -1250,10 +1246,10 @@ def _setattr(self, **kwargs_new): def get_monomer(self, pid, seq_color=1, seq_entity=None, seq_sym=None, crop_fn=None): # CATH format pid - pid, chain, domains = decompose_pid(pid, return_domain=True) + pid, chain, domains = utils.decompose_pid(pid, return_domain=True) if exists(domains): - domains = list(seq_index_split(domains)) - pid = compose_pid(pid, chain) + domains = list(utils.seq_index_split(domains)) + pid = utils.compose_pid(pid, chain) pkey = self.mapping[pid] if pid in self.mapping else pid with FileSystem(self.data_dir) as fs: @@ -1303,7 +1299,7 @@ def get_monomer(self, pid, seq_color=1, seq_entity=None, seq_sym=None, crop_fn=N if exists(domains): # CATH update pid - ret['pid'] = compose_pid(pid, None, seq_index_join(domains)) + ret['pid'] = utils.compose_pid(pid, None, utils.seq_index_join(domains)) if ret.get('msa_idx', 0) != 0: ret = _msa_as_seq(ret, ret['msa_idx'], str_key='msa') @@ -1320,7 +1316,7 @@ def get_monomer(self, pid, seq_color=1, seq_entity=None, seq_sym=None, crop_fn=N def _multimer_yield_cluster(self, var_pid): for var_pid in set(self.cluster.get(var_pid, []) + [var_pid]): - var_pid, c = decompose_pid(var_pid) + var_pid, c = utils.decompose_pid(var_pid) if self.has_chain(var_pid, c): yield var_pid, c @@ -1331,7 +1327,7 @@ def _multimer_build_chain_list(self, protein_id, var_list): def _is_aligned(k, chain_list): if k != protein_id and k in self.attr_list: for c, *_ in chain_list: - x = self.get_chain_list(compose_pid(k, c)) + x = self.get_chain_list(utils.compose_pid(k, c)) # FIX: some chains may be removed from chain.idx if exists(x) and len(set(x) & set(chain_list)) == len(x): return True @@ -1351,7 +1347,7 @@ def _is_aligned(k, chain_list): def get_multimer(self, protein_id, chains): assert len(chains) > 1 - pid, selected_chain = decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking + pid, selected_chain = utils.decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking assert selected_chain in chains # task_definition if pid in self.attr_list and 'task_def' in self.attr_list[pid]: @@ -1374,7 +1370,7 @@ def get_multimer(self, protein_id, chains): var_as_seq_prob=0, attr_list=None ): - feat = self.get_monomer(compose_pid(pid, chain), seq_color=idx + 1) + feat = self.get_monomer(utils.compose_pid(pid, chain), seq_color=idx + 1) # fix seq_entity assert 'str_seq' in feat if feat['str_seq'] not in seq_entity_map: @@ -1418,7 +1414,7 @@ def get_multimer(self, protein_id, chains): feat['del_var'][var_idx] ) - ret['pid'] = compose_pid(pid, ','.join(chains)) + ret['pid'] = utils.compose_pid(pid, ','.join(chains)) if self.feat_flags & FEAT_VAR and 'var' in ret and ret['var']: assert 'length' in ret var_dict = ret['var'] @@ -1473,7 +1469,7 @@ def get_multimer(self, protein_id, chains): for idx, chain in enumerate(chains): n = ret['length'][idx] for c, *_ in chain_list: - cluster_id = compose_pid(var_pid, c) + cluster_id = utils.compose_pid(var_pid, c) if cluster_id not in var_dict: cluster_id = self.mapping.get(cluster_id, cluster_id) hit_seq, hit_mask, target_chain, hit_str, hit_del = var_dict[cluster_id] @@ -1560,7 +1556,7 @@ def get_multimer(self, protein_id, chains): return ret def get_chain_list(self, protein_id): - pid, chain = decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking + pid, chain = utils.decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking if pid in self.chain_list: chain_group = self.chain_list[pid] # for g in chain_group: @@ -1580,7 +1576,7 @@ def has_chain(self, pid, chain): return False def get_resolution(self, protein_id): - pid, _ = decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking + pid, _ = utils.decompose_pid(protein_id) # pylint: disable=unbalanced-tuple-unpacking # HACK: for rna or dna rebuild dataset if pid.startswith('rna-') or pid.startswith('dna-'): pid = pid[4:] @@ -1647,7 +1643,7 @@ def get_var_features_new( variant_path = f'{self.var_dir}/{protein_id}/msas/{protein_id}.a3m' if fs.exists(variant_path): with fs.open(variant_path) as f: - sequences, descriptions = parse_fasta(fs.textise(f.read())) + sequences, descriptions = parsers.parse_fasta(fs.textise(f.read())) assert len(sequences) == len(descriptions) if self.attr_list: # filter with attr_list, NOTE: keep the 1st one alway. @@ -1710,12 +1706,12 @@ def get_structure_label_npz( if fs.exists(fasta_file): with fs.open(fasta_file) as f: input_fasta_str = fs.textise(f.read()) - input_seqs, input_descs = parse_fasta(input_fasta_str) + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) if len(input_seqs) != 1: raise ValueError(f'More than one input sequence found in {fasta_file}.') input_sequence = input_seqs[0] input_description = input_descs[0] - input_type = seq_type_dict[parse_seq_type(input_description)] + input_type = utils.seq_type_dict[utils.parse_seq_type(input_description)] ret.update( _make_seq_features( @@ -1849,31 +1845,31 @@ def _to_tensor(field, defval=0, dtype=None): max_batch_len = max(len(s) for s in ret['str_seq']) - ret['seq'] = pad_sequential( + ret['seq'] = padding.pad_sequential( _to_list('seq'), max_batch_len, padval=residue_constants.unk_restype_index ) for field in ('seq_index', 'mask'): - ret[field] = pad_sequential(_to_list(field), max_batch_len) + ret[field] = padding.pad_sequential(_to_list(field), max_batch_len) for field in ('seq_color', 'seq_entity', 'seq_sym'): if _any(field): - ret[field] = pad_sequential(_to_list(field), max_batch_len) + ret[field] = padding.pad_sequential(_to_list(field), max_batch_len) if _any('resolu'): ret['resolution'] = _to_tensor('resolu', -1.0) if _any('sta'): for field in ('sta', 'sta_type_mask'): - ret[field] = pad_sequential(_to_list(field), max_batch_len) + ret[field] = padding.pad_sequential(_to_list(field), max_batch_len) feat_flags = default(feat_flags, FEAT_ALL) if feat_flags & FEAT_PDB and _any('coord'): # required for field in ('coord', 'coord_mask'): - ret[field] = pad_sequential(_to_list(field), max_batch_len) + ret[field] = padding.pad_sequential(_to_list(field), max_batch_len) # optional if _any('coord_plddt'): - ret['coord_plddt'] = pad_sequential( + ret['coord_plddt'] = padding.pad_sequential( _to_list('coord_plddt'), max_batch_len, padval=1.0 ) @@ -1888,29 +1884,29 @@ def _to_tensor(field, defval=0, dtype=None): ret['str_msa'] = _to_list('str_msa') for field in ('msa_idx', 'num_msa'): ret[field] = _to_tensor(field, dtype=torch.int) - ret['msa'] = pad_rectangle( + ret['msa'] = padding.pad_rectangle( _to_list('msa'), max_batch_len, padval=residue_constants.HHBLITS_AA_TO_ID[('-', residue_constants.PROT)] ) for field in ('msa_mask', 'del_msa'): - ret[field] = pad_rectangle(_to_list(field), max_batch_len) + ret[field] = padding.pad_rectangle(_to_list(field), max_batch_len) if feat_flags & FEAT_VAR and _any('variant'): ret['variant_pid'] = _to_list('variant_pid') for field in ('var_idx', 'num_var'): ret[field] = _to_tensor(field, dtype=torch.int) - ret['variant'] = pad_rectangle( + ret['variant'] = padding.pad_rectangle( _to_list('variant'), max_batch_len, padval=residue_constants.HHBLITS_AA_TO_ID[('-', residue_constants.PROT)] ) for field in ('variant_mask', 'variant_task_mask'): - ret[field] = pad_rectangle(_to_list(field), max_batch_len) + ret[field] = padding.pad_rectangle(_to_list(field), max_batch_len) for field in ('variant_label', 'variant_label_mask'): items = _to_list(field) max_depth = max(item.shape[0] for item in items if exists(item)) - ret[field] = pad_sequential(items, max_depth) + ret[field] = padding.pad_sequential(items, max_depth) return ret