bindcraft.py: 132 changes (96 additions, 36 deletions)
@@ -1,8 +1,27 @@
####################################
###################### BindCraft Run
####################################

# MODIFICATION FOR PARALLEL EXECUTION
# Minimal changes from v1.5.1 to allow safe execution as multiple concurrent processes,
# using either a job scheduler (such as SLURM) or local background jobs.
#
# 1. File locking prevents data corruption when writing to shared CSV files.
# 2. A central counter file (_progress_counter.txt) tracks the total number of
#    accepted designs, allowing all processes to stop once the target is reached.
#
# Requires the FileLock library:
#   conda install -c conda-forge filelock
#
# Example SLURM usage:
#   #SBATCH --array=0-31
#
# Example local background jobs:
#   for i in {1..8}; do nohup python -u ./bindcraft.py --settings './settings_target/PDL1.json' --filters './settings_filters/default_filters.json' --advanced './settings_advanced/default_4stage_multimer.json' > output_${i}.log 2> error_${i}.log & done
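#
# A complete SLURM submission script might look like the minimal sketch below.
# The GPU request and the environment name are placeholders for your cluster
# setup, not values defined by this repository:
#
#   #!/bin/bash
#   #SBATCH --job-name=bindcraft
#   #SBATCH --array=0-31
#   #SBATCH --gres=gpu:1
#   source activate BindCraft   # placeholder environment name
#   python -u ./bindcraft.py --settings './settings_target/PDL1.json' \
#       --filters './settings_filters/default_filters.json' \
#       --advanced './settings_advanced/default_4stage_multimer.json'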

### Import dependencies
from functions import *
from filelock import FileLock

# Check if JAX-capable GPU is available, otherwise exit
check_jax_gpu()
@@ -48,10 +67,28 @@
final_csv = os.path.join(target_settings["design_path"], 'final_design_stats.csv')
failure_csv = os.path.join(target_settings["design_path"], 'failure_csv.csv')

create_dataframe(trajectory_csv, trajectory_labels)
create_dataframe(mpnn_csv, design_labels)
create_dataframe(final_csv, final_labels)
generate_filter_pass_csv(failure_csv, args.filters)
# Define paths for the progress counter and all shared CSV files and their locks
progress_counter_file = os.path.join(target_settings["design_path"], '_progress_counter.txt')
progress_lock = FileLock(progress_counter_file + ".lock")

trajectory_csv_lock = FileLock(trajectory_csv + ".lock")
mpnn_csv_lock = FileLock(mpnn_csv + ".lock")
final_csv_lock = FileLock(final_csv + ".lock")
failure_csv_lock = FileLock(failure_csv + ".lock")
finalization_lock = FileLock(os.path.join(target_settings["design_path"], "_finalization.lock"))
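# Each FileLock is an advisory, cross-process lock backed by a "<name>.lock"
# file on disk: `with lock:` blocks until the lock is acquired and releases it
# when the block exits, even on exceptions. This is what makes the shared CSV
# access below safe across independent workers; a threading.Lock would not
# help here, since the workers are separate processes.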

# Initialize counter file if it doesn't exist. This is safe for multiple processes.
with progress_lock:
    if not os.path.exists(progress_counter_file):
        with open(progress_counter_file, 'w') as f:
            f.write('0')

# Initialize dataframes safely by passing the corresponding lock to each function
create_dataframe(trajectory_csv, trajectory_labels, trajectory_csv_lock)
create_dataframe(mpnn_csv, design_labels, mpnn_csv_lock)
create_dataframe(final_csv, final_labels, final_csv_lock)
generate_filter_pass_csv(failure_csv, args.filters, failure_csv_lock)
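# The lock-aware helpers themselves live in functions.py and are not shown in
# this diff. As a rough sketch of the pattern they need to follow (signature
# and body assumed here purely for illustration):
#
#     def create_dataframe(csv_file, columns, lock):
#         with lock:
#             if not os.path.exists(csv_file):
#                 pd.DataFrame(columns=columns).to_csv(csv_file, index=False)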


####################################
####################################
Expand All @@ -70,11 +107,13 @@

### start design loop
while True:
    ### check if we have the target number of binders
    final_designs_reached = check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels)

    if final_designs_reached:
        # stop design loop execution
    # Check global progress counter before starting a new trajectory
    with progress_lock:
        with open(progress_counter_file, 'r') as f:
            accepted_count = int(f.read())

    if accepted_count >= target_settings["number_of_final_designs"]:
        print(f"Target of {target_settings['number_of_final_designs']} designs reached. Worker process is shutting down.")
        break
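    # Note: the target can be slightly overshot. A worker that passes this
    # check may still be mid-trajectory when another worker saves the final
    # design; it only notices and shuts down on its next loop iteration.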

    ### check if we reached maximum allowed trajectories
@@ -108,7 +147,7 @@
    ### Begin binder hallucination
    trajectory = binder_hallucination(design_name, target_settings["starting_pdb"], target_settings["chains"],
                                      target_settings["target_hotspot_residues"], length, seed, helicity_value,
                                      design_models, advanced_settings, design_paths, failure_csv)
                                      design_models, advanced_settings, design_paths, failure_csv, failure_csv_lock)
    trajectory_metrics = copy_dict(trajectory._tmp["best"]["aux"]["log"]) # contains plddt, ptm, i_ptm, pae, i_pae
    trajectory_pdb = os.path.join(design_paths["Trajectory"], design_name + ".pdb")

@@ -159,7 +198,8 @@
        trajectory_interface_scores['interface_hbond_percentage'], trajectory_interface_scores['interface_delta_unsat_hbonds'], trajectory_interface_scores['interface_delta_unsat_hbonds_percentage'],
        trajectory_alpha_interface, trajectory_beta_interface, trajectory_loops_interface, trajectory_alpha, trajectory_beta, trajectory_loops, trajectory_interface_AA, trajectory_target_rmsd,
        trajectory_time_text, traj_seq_notes, settings_file, filters_file, advanced_file]
    insert_data(trajectory_csv, trajectory_data)
    with trajectory_csv_lock:
        insert_data(trajectory_csv, trajectory_data)
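    # Every shared CSV write from here on follows this pattern: acquire the
    # file's FileLock, write, release. insert_data (defined in functions.py)
    # appends one row to the CSV, so conceptually the guarded call amounts to
    # something like the following sketch (illustrative only):
    #
    #     with trajectory_csv_lock:
    #         pd.DataFrame([trajectory_data]).to_csv(trajectory_csv, mode='a',
    #                                                header=False, index=False)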

if advanced_settings["enable_mpnn"]:
# initialise MPNN counters
@@ -170,7 +210,8 @@

        ### MPNN redesign of starting binder
        mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
        existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)
        with mpnn_csv_lock:
            existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)

        # create set of MPNN sequences with allowed amino acid composition
        restricted_AAs = set(aa.strip().upper() for aa in advanced_settings["omit_AAs"].split(',')) if advanced_settings["force_reject_AA"] else set()
@@ -232,7 +273,7 @@
                mpnn_sequence['seq'], mpnn_design_name,
                target_settings["starting_pdb"], target_settings["chains"],
                length, trajectory_pdb, prediction_models, advanced_settings,
                filters, design_paths, failure_csv)
                filters, design_paths, failure_csv, failure_csv_lock)

            # if AF2 filters are not passed then skip the scoring
            if not pass_af2_filters:
@@ -330,7 +371,6 @@
            mpnn_end_time = time.time() - mpnn_time
            elapsed_mpnn_text = f"{'%d hours, %d minutes, %d seconds' % (int(mpnn_end_time // 3600), int((mpnn_end_time % 3600) // 60), int(mpnn_end_time % 60))}"

            # Insert statistics about MPNN design into CSV; will return None if the corresponding model does not exist
            model_numbers = range(1, 6)
            statistics_labels = ['pLDDT', 'pTM', 'i_pTM', 'pAE', 'i_pAE', 'i_pLDDT', 'ss_pLDDT', 'Unrelaxed_Clashes', 'Relaxed_Clashes', 'Binder_Energy_Score', 'Surface_Hydrophobicity',
@@ -357,7 +397,8 @@
            mpnn_data.extend([elapsed_mpnn_text, seq_notes, settings_file, filters_file, advanced_file])

            # insert data into csv
            insert_data(mpnn_csv, mpnn_data)
            with mpnn_csv_lock:
                insert_data(mpnn_csv, mpnn_data)

            # find best model number by pLDDT
            plddt_values = {i: mpnn_data[i] for i in range(11, 15) if mpnn_data[i] is not None}
@@ -381,13 +422,24 @@

                # insert data into final csv
                final_data = [''] + mpnn_data
                insert_data(final_csv, final_data)
                with final_csv_lock:
                    insert_data(final_csv, final_data)

                # Safely increment the global progress counter
                with progress_lock:
                    with open(progress_counter_file, 'r') as f:
                        current_count = int(f.read())
                    new_count = current_count + 1
                    with open(progress_counter_file, 'w') as f:
                        f.write(str(new_count))
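                # The read-increment-write sequence above must stay inside
                # progress_lock: two workers incrementing without it could both
                # read N and both write N + 1, silently dropping an accepted
                # design from the global count.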

                # copy animation from accepted trajectory
                if advanced_settings["save_design_animations"]:
                    accepted_animation = os.path.join(design_paths["Accepted/Animation"], f"{design_name}.html")
                    if not os.path.exists(accepted_animation):
                        shutil.copy(os.path.join(design_paths["Trajectory/Animation"], f"{design_name}.html"), accepted_animation)
                        source_animation = os.path.join(design_paths["Trajectory/Animation"], f"{design_name}.html")
                        if os.path.exists(source_animation):
                            shutil.copy(source_animation, accepted_animation)

                # copy plots of accepted trajectory
                plot_files = os.listdir(design_paths["Trajectory/Plots"])
@@ -396,25 +448,34 @@
                    source_plot = os.path.join(design_paths["Trajectory/Plots"], accepted_plot)
                    target_plot = os.path.join(design_paths["Accepted/Plots"], accepted_plot)
                    if not os.path.exists(target_plot):
                        shutil.copy(source_plot, target_plot)
                        if os.path.exists(source_plot):
                            shutil.copy(source_plot, target_plot)

                # If this process just saved the final design, it will trigger the ranking.
                if new_count >= target_settings["number_of_final_designs"]:
                    print(f"FINAL DESIGN ({new_count}) FOUND! TRIGGERING FINAL RANKING...")
                    check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock)
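                    # finalization_lock serialises this call, so if two workers
                    # cross the threshold almost simultaneously, the final
                    # ranking cannot run twice in parallel.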

            else:
                print(f"Unmet filter conditions for {mpnn_design_name}")
                failure_df = pd.read_csv(failure_csv)
                special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
                incremented_columns = set()

                for column in filter_conditions:
                    base_column = column
                    for prefix in special_prefixes:
                        if column.startswith(prefix):
                            base_column = column.split('_', 1)[1]

                    if base_column not in incremented_columns:
                        failure_df[base_column] = failure_df[base_column] + 1
                        incremented_columns.add(base_column)

                failure_df.to_csv(failure_csv, index=False)

                with failure_csv_lock:
                    failure_df = pd.read_csv(failure_csv)
                    special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
                    incremented_columns = set()

                    for column in filter_conditions:
                        base_column = column
                        for prefix in special_prefixes:
                            if column.startswith(prefix):
                                base_column = column.split('_', 1)[1]

                        if base_column not in incremented_columns:
                            failure_df[base_column] = failure_df[base_column] + 1
                            incremented_columns.add(base_column)

                    failure_df.to_csv(failure_csv, index=False)

                shutil.copy(best_model_pdb, design_paths["Rejected"])

            # increase MPNN design number
@@ -457,6 +518,5 @@
    gc.collect()

### Script finished
elapsed_time = time.time() - script_start_time
elapsed_text = f"{'%d hours, %d minutes, %d seconds' % (int(elapsed_time // 3600), int((elapsed_time % 3600) // 60), int(elapsed_time % 60))}"
print("Finished all designs. Script execution for "+str(trajectory_n)+" trajectories took: "+elapsed_text)
# The final summary block is removed, as it is not meaningful in a parallel context. Each worker will exit on its own when the global target is met.
print("Finished all designs.")