From 025a4c246dc33435b96023118f37d9d2bd4cb318 Mon Sep 17 00:00:00 2001
From: angepapa18
Date: Thu, 27 Feb 2025 13:55:48 +0200
Subject: [PATCH] add an option to handle the number of output poses

Add inference_multiple_poses() to src/inference_base.py. It runs the sampler
num_samples times, writes every sampled complex to
<output_dir>/output_pose_<n>.pdb and returns the pose id, energy and PDB path
of each pose.

Add src/inference_multiple_poses.py, a command-line entry point taking the
two PDB paths plus --num_samples and --output_dir, which prints one summary
line per pose.
---
Note: this patch was edited by hand after export, so the post-image blob ids
on the "index" lines no longer match the patched contents.

 src/inference_base.py           | 76 +++++++++++++++++++++++++++++++++
 src/inference_multiple_poses.py | 30 +++++++++++++
 2 files changed, 106 insertions(+)
 create mode 100644 src/inference_multiple_poses.py

diff --git a/src/inference_base.py b/src/inference_base.py
index ecf266c..78518b3 100644
--- a/src/inference_base.py
+++ b/src/inference_base.py
@@ -669,6 +669,82 @@ def inference(in_pdb_1, in_pdb_2):
     return {"energy": min_energy.item()}
 
 
+def inference_multiple_poses(in_pdb_1, in_pdb_2, num_samples=40, output_dir="output_poses"):
+    """Sample ``num_samples`` docked poses and write each complex to a PDB file.
+
+    Args:
+        in_pdb_1: Path to the PDB file loaded as the receptor.
+        in_pdb_2: Path to the PDB file loaded as the ligand.
+        num_samples: Number of independent sampler runs; one pose is written per run.
+        output_dir: Directory for the ``output_pose_<n>.pdb`` files (created if missing).
+
+    Returns:
+        A list with one dict per pose holding ``pose_id`` (1-based), ``energy``
+        and ``pdb_path``.
+    """
+    # Ensure output directory exists
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Set device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Load ESM model
+    esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+    batch_converter = alphabet.get_batch_converter()
+    esm_model = esm_model.to(device).eval()
+
+    # Load score model
+    model = Score_Model.load_from_checkpoint(
+        str(Path("./checkpoints/dips/model_0.ckpt")),
+        map_location=device,
+    )
+    model.to(device).eval()
+
+    # Load PDBs
+    receptor = get_info_from_pdb(in_pdb_1)
+    ligand = get_info_from_pdb(in_pdb_2)
+
+    # Prepare inputs
+    inputs = {"receptor": receptor, "ligand": ligand}
+    batch = get_batch_from_inputs(inputs, batch_converter, esm_model, device)
+    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+
+    # Define parameters
+    num_steps = 40
+    use_clash_force = False
+
+    all_outputs = []
+
+    # Run sampling
+    for pose_id in range(1, num_samples + 1):
+        # Only rot_update, tr_update and outputs are used; the returned positions are discarded.
+        _, _, rot_update, tr_update, outputs = Euler_Maruyama_sampler(
+            model=model,
+            batch=batch.copy(),
+            num_steps=num_steps,
+            device=device,
+            use_clash_force=use_clash_force,
+        )
+
+        lig_aa_coords = modify_aa_coords(
+            ligand["aa_coords"], inputs["ligand"]["bb_coords"], rot_update, tr_update
+        )
+        rec_structure = receptor["structure"]
+        lig_structure = ligand["structure"]
+        lig_structure.coord = lig_aa_coords
+
+        complex_structure = combine_atom_arrays(rec_structure, lig_structure)
+
+        pdb_file = PDBFile()
+        pdb_file.set_structure(complex_structure)
+        output_pdb_path = os.path.join(output_dir, f"output_pose_{pose_id}.pdb")
+        pdb_file.write(output_pdb_path)
+
+        all_outputs.append({"pose_id": pose_id, "energy": outputs["energy"].item(), "pdb_path": output_pdb_path})
+
+    return all_outputs
+
+
 if __name__ == "__main__":
     # Initialize the parser
     parser = argparse.ArgumentParser(description="A description of what your program does")
diff --git a/src/inference_multiple_poses.py b/src/inference_multiple_poses.py
new file mode 100644
index 0000000..72a5df9
--- /dev/null
+++ b/src/inference_multiple_poses.py
@@ -0,0 +1,30 @@
+import argparse
+
+from inference_base import inference_multiple_poses
+
+
+def parse_args():
+    """Parse the command-line arguments of the multi-pose inference script."""
+    parser = argparse.ArgumentParser(description="Process two required PDB files.")
+    parser.add_argument("pdb_1", type=str, help="Path to the first PDB file")
+    parser.add_argument("pdb_2", type=str, help="Path to the second PDB file")
+    parser.add_argument("--num_samples", type=int, default=40, help="Number of output poses/samples, default=40")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output_poses",
+        help="Directory the pose PDB files are written to, default=output_poses",
+    )
+    return parser.parse_args()
+
+
+def main():
+    """Run multi-pose inference and report the energy and path of every pose."""
+    args = parse_args()
+    poses = inference_multiple_poses(args.pdb_1, args.pdb_2, args.num_samples, args.output_dir)
+    for pose in poses:
+        print(f"pose {pose['pose_id']}: energy={pose['energy']} pdb={pose['pdb_path']}")
+
+
+if __name__ == "__main__":
+    main()