diff --git a/main.py b/main.py index 03210e9..c5e4388 100644 --- a/main.py +++ b/main.py @@ -671,9 +671,11 @@ def _parse_result(a3m_string): if args.verbose: print(f"predict affinity ranking score with {model}:{ref_pkl}") + # model params setattr(args, "model_file", ref_pkl) setattr(args, "model_ckpt", os.path.join(args.ref_pkl, f"{model}_model.pth")) + # prepare variants feat = data.get_multimer(compose_pid(pid, "P"), chains) assert len(feat["str_var"]) == len(feat["variant_pid"]) a3m_string = "\n".join( @@ -685,12 +687,13 @@ def _parse_result(a3m_string): output_file_path = os.path.join(args.output_dir, f"{model}_{pid}.a3m") with open(output_file_path, "w") as output_file: setattr(args, "output_file", output_file) - energy.main(args) + energy.main(args) # calc the Elo-score with open(output_file_path, "r") as output_file: a3m_string = output_file.read() for pid, pred in _parse_result(a3m_string): pred_dict[pid].append(pred) + # write results to csv with open(os.path.join(args.output_dir, f"{model}_pred.csv"), "w") as f: writer = csv.DictWriter(f, fieldnames=["id", "chains"] + task.task_name_list) writer.writeheader() diff --git a/predict.sh b/predict.sh index fdf581c..eff6da9 100755 --- a/predict.sh +++ b/predict.sh @@ -134,6 +134,7 @@ done python ${CWD}/main.py peptide_align \ --output_dir ${output_dir}/a3m \ --target_db ${output_dir}/tcr_pmhc_P.fa \ + --target_db ${CWD}/data/tcr_pmhc_db_P.fa \ --verbose \ ${CWD}/data/tcr_pmhc_db/fasta/*_P.fasta \