From 9daaf37693d2db59440c884b67446b59311a1e26 Mon Sep 17 00:00:00 2001 From: chungongyu Date: Tue, 8 Jul 2025 13:23:04 +0800 Subject: [PATCH] feat: write ranking affinity score to csv file --- main.py | 45 ++++++++++++++++++++++++++++++++++++++++++--- task.py | 1 + 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 85a9424..03210e9 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ import click import sqlitedict +import numpy as np from profold2.data.dataset import ProteinStructureDataset from profold2.data.parsers import parse_fasta, parse_a3m @@ -615,6 +616,7 @@ def attr_update_weight_and_task(**args): ) @click.option("--mask", type=str, default="-", hidden=True) @click.option("--task_def", type=str, default=json.dumps(task.make_def()), hidden=True) +@click.option("--task_pid_prefix", type=str, default="tcr_pmhc_", hidden=True) @click.option( "--chunksize", type=int, @@ -642,7 +644,23 @@ def predict(**args): max_var_depth=None ) + def _parse_result(a3m_string): + _, descriptions = parse_fasta(a3m_string) + for fields in map(lambda x: x.split("\t"), descriptions): + pid, fields = fields[0], fields[1:] + if pid.startswith(args.task_pid_prefix): + pred = [None] * task.task_num + for field in fields: + i = field.find(":") + if i != -1: + if field[:i] == "Elo_score": + pred = json.loads(field[i + 1:]) + break + yield pid, pred + for model in args.model: + pred_dict = defaultdict(list) + for ref_pkl in glob.glob(os.path.join(args.ref_pkl, f"{model}_*.pkl")): pdb_id = os.path.basename(ref_pkl) assert pdb_id.startswith(f"{model}_") @@ -663,11 +681,32 @@ def predict(**args): ) with io.StringIO(a3m_string) as a3m_file: setattr(args, "a3m_file", [a3m_file]) - with open( - os.path.join(args.output_dir, f"{model}_{pid}.a3m"), "w" - ) as output_file: + + output_file_path = os.path.join(args.output_dir, f"{model}_{pid}.a3m") + with open(output_file_path, "w") as output_file: setattr(args, "output_file", output_file) energy.main(args) + with open(output_file_path, "r") as output_file: + a3m_string = output_file.read() + for pid, pred in _parse_result(a3m_string): + pred_dict[pid].append(pred) + + with open(os.path.join(args.output_dir, f"{model}_pred.csv"), "w") as f: + writer = csv.DictWriter(f, fieldnames=["id", "chains"] + task.task_name_list) + writer.writeheader() + for pid, pred_list in pred_dict.items(): + chain_list, *_ = data.chain_list[pid] # FIX: data.get_chain_list(protein_id) + assert chain_list, (pid, pid in data.chain_list, len(data.chain_list)) + _, pred_mask = task.make_label(0, chain_list) + + pred_list, pred_mask = np.asarray(pred_list), np.asarray(pred_mask) + pred_list = np.sum(pred_list * pred_mask[None], axis=0) / pred_list.shape[0] + + row = {"id": pid, "chains": "_".join(chain_list)} + for idx, (pred, mask) in enumerate(zip(pred_list, pred_mask)): + if mask: + row[task.task_name_list[idx]] = pred + writer.writerow(row) if __name__ == "__main__": diff --git a/task.py b/task.py index d633af6..c117e9a 100644 --- a/task.py +++ b/task.py @@ -9,6 +9,7 @@ task_num = len(task_mapping.keys()) +task_name_list = ["pMHC", "pTCR", "TCR_pMHC"] def make_def(): task_def = defaultdict(list)