diff --git a/functions/generic_utils.py b/functions/generic_utils.py
index c5ce6d0..5022066 100644
--- a/functions/generic_utils.py
+++ b/functions/generic_utils.py
@@ -1,6 +1,9 @@
 ####################################
 ################## General functions
 ####################################
+# MODIFICATION FOR PARALLEL EXECUTION
+# Minimal changes from v1.5.1 to allow safe execution as a SLURM job array.
+
 ### Import dependencies
 import os
 import json
@@ -53,75 +56,82 @@ def generate_directories(design_path):
     return design_paths
 
 # generate CSV file for tracking designs not passing filters
-def generate_filter_pass_csv(failure_csv, filter_json):
-    if not os.path.exists(failure_csv):
-        with open(filter_json, 'r') as file:
-            data = json.load(file)
-
-        # Create a list of modified keys
-        names = ['Trajectory_logits_pLDDT', 'Trajectory_softmax_pLDDT', 'Trajectory_one-hot_pLDDT', 'Trajectory_final_pLDDT', 'Trajectory_Contacts', 'Trajectory_Clashes', 'Trajectory_WrongHotspot']
-        special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
-        tracked_filters = set()
-
-        for key in data.keys():
-            processed_name = key # Use the full key by default
-
-            # Check if the key starts with any special prefixes
-            for prefix in special_prefixes:
-                if key.startswith(prefix):
-                    # Strip the prefix and use the remaining part
-                    processed_name = key.split('_', 1)[1]
-                    break
-
-            # Handle 'InterfaceAAs' with appending amino acids
-            if 'InterfaceAAs' in processed_name:
-                # Generate 20 variations of 'InterfaceAAs' with amino acids appended
-                amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
-                for aa in amino_acids:
-                    variant_name = f"InterfaceAAs_{aa}"
-                    if variant_name not in tracked_filters:
-                        names.append(variant_name)
-                        tracked_filters.add(variant_name)
-            elif processed_name not in tracked_filters:
-                # Add processed name if it hasn't been added before
-                names.append(processed_name)
-                tracked_filters.add(processed_name)
-
-        # make dataframe with 0s
-        df = pd.DataFrame(columns=names)
-        df.loc[0] = [0] * len(names)
-
-        df.to_csv(failure_csv, index=False)
+def generate_filter_pass_csv(failure_csv, filter_json, failure_csv_lock):
+    with failure_csv_lock:
+        if not os.path.exists(failure_csv):
+            with open(filter_json, 'r') as file:
+                data = json.load(file)
+
+            # Create a list of modified keys
+            names = ['Trajectory_logits_pLDDT', 'Trajectory_softmax_pLDDT', 'Trajectory_one-hot_pLDDT', 'Trajectory_final_pLDDT', 'Trajectory_Contacts', 'Trajectory_Clashes', 'Trajectory_WrongHotspot']
+            special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
+            tracked_filters = set()
+
+            for key in data.keys():
+                processed_name = key # Use the full key by default
+
+                # Check if the key starts with any special prefixes
+                for prefix in special_prefixes:
+                    if key.startswith(prefix):
+                        # Strip the prefix and use the remaining part
+                        processed_name = key.split('_', 1)[1]
+                        break
+
+                # Handle 'InterfaceAAs' with appending amino acids
+                if 'InterfaceAAs' in processed_name:
+                    # Generate 20 variations of 'InterfaceAAs' with amino acids appended
+                    amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
+                    for aa in amino_acids:
+                        variant_name = f"InterfaceAAs_{aa}"
+                        if variant_name not in tracked_filters:
+                            names.append(variant_name)
+                            tracked_filters.add(variant_name)
+                elif processed_name not in tracked_filters:
+                    # Add processed name if it hasn't been added before
+                    names.append(processed_name)
+                    tracked_filters.add(processed_name)
+
+            # make dataframe with 0s
+            df = pd.DataFrame(columns=names)
+            df.loc[0] = [0] * len(names)
+
+            df.to_csv(failure_csv, index=False)
 
 # update failure rates from trajectories and early predictions
-def update_failures(failure_csv, failure_column_or_dict):
-    failure_df = pd.read_csv(failure_csv)
-
-    def strip_model_prefix(name):
-        # Strips the model-specific prefix if it exists
-        parts = name.split('_')
-        if parts[0].isdigit():
-            return '_'.join(parts[1:])
-        return name
-
-    # update dictionary coming from complex prediction
-    if isinstance(failure_column_or_dict, dict):
-        # Update using a dictionary of failures
-        for filter_name, count in failure_column_or_dict.items():
-            stripped_name = strip_model_prefix(filter_name)
-            if stripped_name in failure_df.columns:
-                failure_df[stripped_name] += count
-            else:
-                failure_df[stripped_name] = count
-    else:
-        # Update a single column from trajectory generation
-        failure_column = strip_model_prefix(failure_column_or_dict)
-        if failure_column in failure_df.columns:
-            failure_df[failure_column] += 1
+def update_failures(failure_csv, failure_column_or_dict, failure_csv_lock):
+    with failure_csv_lock:
+        try:
+            failure_df = pd.read_csv(failure_csv)
+        except pd.errors.EmptyDataError:
+            # This is a safeguard. If the file is empty for any reason, crash is prevented by skipping the update.
+            print(f"WARNING: Could not update failure stats because {os.path.basename(failure_csv)} was empty.")
+            return
+
+        def strip_model_prefix(name):
+            # Strips the model-specific prefix if it exists
+            parts = name.split('_')
+            if parts[0].isdigit():
+                return '_'.join(parts[1:])
+            return name
+
+        # update dictionary coming from complex prediction
+        if isinstance(failure_column_or_dict, dict):
+            # Update using a dictionary of failures
+            for filter_name, count in failure_column_or_dict.items():
+                stripped_name = strip_model_prefix(filter_name)
+                if stripped_name in failure_df.columns:
+                    failure_df[stripped_name] += count
+                else:
+                    failure_df[stripped_name] = count
         else:
-            failure_df[failure_column] = 1
-
-    failure_df.to_csv(failure_csv, index=False)
+            # Update a single column from trajectory generation
+            failure_column = strip_model_prefix(failure_column_or_dict)
+            if failure_column in failure_df.columns:
+                failure_df[failure_column] += 1
+            else:
+                failure_df[failure_column] = 1
+
+        failure_df.to_csv(failure_csv, index=False)
 
 # Check if number of trajectories generated
 def check_n_trajectories(design_paths, advanced_settings):
@@ -134,53 +144,56 @@ def check_n_trajectories(design_paths, advanced_settings):
     return False
 
 # Check if we have required number of accepted targets, rank them, and analyse sequence and structure properties
-def check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels):
-    accepted_binders = [f for f in os.listdir(design_paths["Accepted"]) if f.endswith('.pdb')]
+def check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock):
+    with finalization_lock:
+        accepted_binders = [f for f in os.listdir(design_paths["Accepted"]) if f.endswith('.pdb')]
+
+        if len(accepted_binders) >= target_settings["number_of_final_designs"]:
+            print(f"Target number {str(len(accepted_binders))} of designs reached! Reranking...")
+
+            # clear the Ranked folder in case we added new designs in the meantime so we rerank them all
+            for f in os.listdir(design_paths["Accepted/Ranked"]):
+                os.remove(os.path.join(design_paths["Accepted/Ranked"], f))
+
+            # load dataframe of designed binders
+            with mpnn_csv_lock:
+                design_df = pd.read_csv(mpnn_csv)
+            design_df = design_df.sort_values('Average_i_pTM', ascending=False)
+
+            # create final csv dataframe to copy matched rows, initialize with the column labels
+            final_df = pd.DataFrame(columns=final_labels)
+
+            # check the ranking of the designs and copy them with new ranked IDs to the folder
+            rank = 1
+            for _, row in design_df.iterrows():
+                for binder in accepted_binders:
+                    target_settings["binder_name"], model = binder.rsplit('_model', 1)
+                    if target_settings["binder_name"] == row['Design']:
+                        # rank and copy into ranked folder
+                        row_data = {'Rank': rank, **{label: row[label] for label in design_labels}}
+                        final_df = pd.concat([final_df, pd.DataFrame([row_data])], ignore_index=True)
+                        old_path = os.path.join(design_paths["Accepted"], binder)
+                        new_path = os.path.join(design_paths["Accepted/Ranked"], f"{rank}_{target_settings['binder_name']}_model{model.rsplit('.', 1)[0]}.pdb")
+                        shutil.copyfile(old_path, new_path)
+
+                        rank += 1
+                        break
+
+            # save the final_df to final_csv
+            with final_csv_lock:
+                final_df.to_csv(final_csv, index=False)
+
+            # zip large folders to save space
+            if advanced_settings["zip_animations"]:
+                zip_and_empty_folder(design_paths["Trajectory/Animation"], '.html')
+
+            if advanced_settings["zip_plots"]:
+                zip_and_empty_folder(design_paths["Trajectory/Plots"], '.png')
+
+            return True
 
-    if len(accepted_binders) >= target_settings["number_of_final_designs"]:
-        print(f"Target number {str(len(accepted_binders))} of designs reached! Reranking...")
-
-        # clear the Ranked folder in case we added new designs in the meantime so we rerank them all
-        for f in os.listdir(design_paths["Accepted/Ranked"]):
-            os.remove(os.path.join(design_paths["Accepted/Ranked"], f))
-
-        # load dataframe of designed binders
-        design_df = pd.read_csv(mpnn_csv)
-        design_df = design_df.sort_values('Average_i_pTM', ascending=False)
-
-        # create final csv dataframe to copy matched rows, initialize with the column labels
-        final_df = pd.DataFrame(columns=final_labels)
-
-        # check the ranking of the designs and copy them with new ranked IDs to the folder
-        rank = 1
-        for _, row in design_df.iterrows():
-            for binder in accepted_binders:
-                target_settings["binder_name"], model = binder.rsplit('_model', 1)
-                if target_settings["binder_name"] == row['Design']:
-                    # rank and copy into ranked folder
-                    row_data = {'Rank': rank, **{label: row[label] for label in design_labels}}
-                    final_df = pd.concat([final_df, pd.DataFrame([row_data])], ignore_index=True)
-                    old_path = os.path.join(design_paths["Accepted"], binder)
-                    new_path = os.path.join(design_paths["Accepted/Ranked"], f"{rank}_{target_settings['binder_name']}_model{model.rsplit('.', 1)[0]}.pdb")
-                    shutil.copyfile(old_path, new_path)
-
-                    rank += 1
-                    break
-
-        # save the final_df to final_csv
-        final_df.to_csv(final_csv, index=False)
-
-        # zip large folders to save space
-        if advanced_settings["zip_animations"]:
-            zip_and_empty_folder(design_paths["Trajectory/Animation"], '.html')
-
-        if advanced_settings["zip_plots"]:
-            zip_and_empty_folder(design_paths["Trajectory/Plots"], '.png')
-
-        return True
-
-    else:
-        return False
+        else:
+            return False
 
 # Load required helicity value
 def load_helicity(advanced_settings):
@@ -282,10 +295,11 @@ def load_af2_models(af_multimer_setting):
     return design_models, prediction_models, multimer_validation
 
 # create csv for insertion of data
-def create_dataframe(csv_file, columns):
-    if not os.path.exists(csv_file):
-        df = pd.DataFrame(columns=columns)
-        df.to_csv(csv_file, index=False)
+def create_dataframe(csv_file, columns, lock):
+    with lock:
+        if not os.path.exists(csv_file):
+            df = pd.DataFrame(columns=columns)
+            df.to_csv(csv_file, index=False)
 
 # insert row of statistics into csv
 def insert_data(csv_file, data_array):