diff --git a/src/inference_base.py b/src/inference_base.py index ecf266c..3df7592 100644 --- a/src/inference_base.py +++ b/src/inference_base.py @@ -139,18 +139,19 @@ def get_native(pdb_path): valid_atoms_mask = np.zeros(len(structure), dtype=bool) # Iterate over unique residues - for res_id in set(structure.res_id): - # Create a mask for atoms in the current residue - residue_mask = (structure.res_id == res_id) - residue_atoms = structure[residue_mask] + for chain_id in set(structure.chain_id): + for res_id in set(structure.res_id): + # Create a mask for atoms in the current residue + residue_mask = ((structure.res_id == res_id) & (structure.chain_id == chain_id)) + residue_atoms = structure[residue_mask] - # Get atom names for this residue - residue_atom_names = set(residue_atoms.atom_name) - - # Check if all backbone atoms are present - if backbone_atoms.issubset(residue_atom_names): - # If backbone atoms are present, mark this residue as valid - valid_atoms_mask[residue_mask] = True + # Get atom names for this residue + residue_atom_names = set(residue_atoms.atom_name) + + # Check if all backbone atoms are present + if backbone_atoms.issubset(residue_atom_names): + # If backbone atoms are present, mark this residue as valid + valid_atoms_mask[residue_mask] = True # Apply the mask to filter the structure filtered_structure = structure[valid_atoms_mask]