functions/generic_utils.py (246 changes: 130 additions & 116 deletions)
@@ -1,6 +1,9 @@
 ####################################
 ################## General functions
 ####################################
+# MODIFICATION FOR PARALLEL EXECUTION
+# Minimal changes from v1.5.1 to allow safe execution as a SLURM job array.
+
 ### Import dependencies
 import os
 import json
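SLURM array tasks are separate processes (potentially on separate nodes), so the lock objects threaded through the functions below cannot be plain threading locks; the diff never shows how they are constructed. A minimal sketch of one way the caller might build them, assuming the third-party filelock package; the helper name make_csv_lock is hypothetical and not part of this PR:

# Sketch only: one inter-process lock per shared CSV file.
# Assumes `pip install filelock`; `make_csv_lock` is a hypothetical helper, not part of this PR.
from filelock import FileLock

def make_csv_lock(csv_path):
    # A sibling .lock file serializes access across array tasks; on a shared
    # filesystem this also covers tasks running on different nodes.
    return FileLock(csv_path + ".lock")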
@@ -53,75 +56,82 @@ def generate_directories(design_path):
     return design_paths
 
 # generate CSV file for tracking designs not passing filters
-def generate_filter_pass_csv(failure_csv, filter_json):
-    if not os.path.exists(failure_csv):
-        with open(filter_json, 'r') as file:
-            data = json.load(file)
-
-        # Create a list of modified keys
-        names = ['Trajectory_logits_pLDDT', 'Trajectory_softmax_pLDDT', 'Trajectory_one-hot_pLDDT', 'Trajectory_final_pLDDT', 'Trajectory_Contacts', 'Trajectory_Clashes', 'Trajectory_WrongHotspot']
-        special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
-        tracked_filters = set()
-
-        for key in data.keys():
-            processed_name = key # Use the full key by default
-
-            # Check if the key starts with any special prefixes
-            for prefix in special_prefixes:
-                if key.startswith(prefix):
-                    # Strip the prefix and use the remaining part
-                    processed_name = key.split('_', 1)[1]
-                    break
-
-            # Handle 'InterfaceAAs' with appending amino acids
-            if 'InterfaceAAs' in processed_name:
-                # Generate 20 variations of 'InterfaceAAs' with amino acids appended
-                amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
-                for aa in amino_acids:
-                    variant_name = f"InterfaceAAs_{aa}"
-                    if variant_name not in tracked_filters:
-                        names.append(variant_name)
-                        tracked_filters.add(variant_name)
-            elif processed_name not in tracked_filters:
-                # Add processed name if it hasn't been added before
-                names.append(processed_name)
-                tracked_filters.add(processed_name)
-
-        # make dataframe with 0s
-        df = pd.DataFrame(columns=names)
-        df.loc[0] = [0] * len(names)
-
-        df.to_csv(failure_csv, index=False)
+def generate_filter_pass_csv(failure_csv, filter_json, failure_csv_lock):
+    with failure_csv_lock:
+        if not os.path.exists(failure_csv):
+            with open(filter_json, 'r') as file:
+                data = json.load(file)
+
+            # Create a list of modified keys
+            names = ['Trajectory_logits_pLDDT', 'Trajectory_softmax_pLDDT', 'Trajectory_one-hot_pLDDT', 'Trajectory_final_pLDDT', 'Trajectory_Contacts', 'Trajectory_Clashes', 'Trajectory_WrongHotspot']
+            special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
+            tracked_filters = set()
+
+            for key in data.keys():
+                processed_name = key # Use the full key by default
+
+                # Check if the key starts with any special prefixes
+                for prefix in special_prefixes:
+                    if key.startswith(prefix):
+                        # Strip the prefix and use the remaining part
+                        processed_name = key.split('_', 1)[1]
+                        break
+
+                # Handle 'InterfaceAAs' with appending amino acids
+                if 'InterfaceAAs' in processed_name:
+                    # Generate 20 variations of 'InterfaceAAs' with amino acids appended
+                    amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
+                    for aa in amino_acids:
+                        variant_name = f"InterfaceAAs_{aa}"
+                        if variant_name not in tracked_filters:
+                            names.append(variant_name)
+                            tracked_filters.add(variant_name)
+                elif processed_name not in tracked_filters:
+                    # Add processed name if it hasn't been added before
+                    names.append(processed_name)
+                    tracked_filters.add(processed_name)
+
+            # make dataframe with 0s
+            df = pd.DataFrame(columns=names)
+            df.loc[0] = [0] * len(names)
+
+            df.to_csv(failure_csv, index=False)
 
 # update failure rates from trajectories and early predictions
-def update_failures(failure_csv, failure_column_or_dict):
-    failure_df = pd.read_csv(failure_csv)
-
-    def strip_model_prefix(name):
-        # Strips the model-specific prefix if it exists
-        parts = name.split('_')
-        if parts[0].isdigit():
-            return '_'.join(parts[1:])
-        return name
-
-    # update dictionary coming from complex prediction
-    if isinstance(failure_column_or_dict, dict):
-        # Update using a dictionary of failures
-        for filter_name, count in failure_column_or_dict.items():
-            stripped_name = strip_model_prefix(filter_name)
-            if stripped_name in failure_df.columns:
-                failure_df[stripped_name] += count
-            else:
-                failure_df[stripped_name] = count
-    else:
-        # Update a single column from trajectory generation
-        failure_column = strip_model_prefix(failure_column_or_dict)
-        if failure_column in failure_df.columns:
-            failure_df[failure_column] += 1
-        else:
-            failure_df[failure_column] = 1
-
-    failure_df.to_csv(failure_csv, index=False)
+def update_failures(failure_csv, failure_column_or_dict, failure_csv_lock):
+    with failure_csv_lock:
+        try:
+            failure_df = pd.read_csv(failure_csv)
+        except pd.errors.EmptyDataError:
+            # Safeguard: if the file is empty for any reason, skip the update instead of crashing.
+            print(f"WARNING: Could not update failure stats because {os.path.basename(failure_csv)} was empty.")
+            return
+
+        def strip_model_prefix(name):
+            # Strips the model-specific prefix if it exists
+            parts = name.split('_')
+            if parts[0].isdigit():
+                return '_'.join(parts[1:])
+            return name
+
+        # update dictionary coming from complex prediction
+        if isinstance(failure_column_or_dict, dict):
+            # Update using a dictionary of failures
+            for filter_name, count in failure_column_or_dict.items():
+                stripped_name = strip_model_prefix(filter_name)
+                if stripped_name in failure_df.columns:
+                    failure_df[stripped_name] += count
+                else:
+                    failure_df[stripped_name] = count
+        else:
+            # Update a single column from trajectory generation
+            failure_column = strip_model_prefix(failure_column_or_dict)
+            if failure_column in failure_df.columns:
+                failure_df[failure_column] += 1
+            else:
+                failure_df[failure_column] = 1
+
+        failure_df.to_csv(failure_csv, index=False)
 
 # Check if number of trajectories generated
 def check_n_trajectories(design_paths, advanced_settings):
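The two functions above follow the same pattern: the first caller creates the failure CSV, and every later update is a read-modify-write cycle performed entirely under the lock, so concurrent array tasks cannot interleave their updates. A usage sketch under the filelock assumption from earlier (the paths and column name are illustrative, not from the PR):

# Hypothetical usage from a single array task.
failure_csv = "output/failure_csv.csv"
failure_csv_lock = make_csv_lock(failure_csv)

# First task to arrive creates the CSV; later tasks see it exists and skip.
generate_filter_pass_csv(failure_csv, "settings/default_filters.json", failure_csv_lock)

# Record one trajectory that failed the clash check.
update_failures(failure_csv, "Trajectory_Clashes", failure_csv_lock)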
@@ -134,53 +144,56 @@ def check_n_trajectories(design_paths, advanced_settings):
         return False
 
 # Check if we have required number of accepted targets, rank them, and analyse sequence and structure properties
-def check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels):
-    accepted_binders = [f for f in os.listdir(design_paths["Accepted"]) if f.endswith('.pdb')]
-
-    if len(accepted_binders) >= target_settings["number_of_final_designs"]:
-        print(f"Target number {str(len(accepted_binders))} of designs reached! Reranking...")
-
-        # clear the Ranked folder in case we added new designs in the meantime so we rerank them all
-        for f in os.listdir(design_paths["Accepted/Ranked"]):
-            os.remove(os.path.join(design_paths["Accepted/Ranked"], f))
-
-        # load dataframe of designed binders
-        design_df = pd.read_csv(mpnn_csv)
-        design_df = design_df.sort_values('Average_i_pTM', ascending=False)
-
-        # create final csv dataframe to copy matched rows, initialize with the column labels
-        final_df = pd.DataFrame(columns=final_labels)
-
-        # check the ranking of the designs and copy them with new ranked IDs to the folder
-        rank = 1
-        for _, row in design_df.iterrows():
-            for binder in accepted_binders:
-                target_settings["binder_name"], model = binder.rsplit('_model', 1)
-                if target_settings["binder_name"] == row['Design']:
-                    # rank and copy into ranked folder
-                    row_data = {'Rank': rank, **{label: row[label] for label in design_labels}}
-                    final_df = pd.concat([final_df, pd.DataFrame([row_data])], ignore_index=True)
-                    old_path = os.path.join(design_paths["Accepted"], binder)
-                    new_path = os.path.join(design_paths["Accepted/Ranked"], f"{rank}_{target_settings['binder_name']}_model{model.rsplit('.', 1)[0]}.pdb")
-                    shutil.copyfile(old_path, new_path)
-
-                    rank += 1
-                    break
-
-        # save the final_df to final_csv
-        final_df.to_csv(final_csv, index=False)
-
-        # zip large folders to save space
-        if advanced_settings["zip_animations"]:
-            zip_and_empty_folder(design_paths["Trajectory/Animation"], '.html')
-
-        if advanced_settings["zip_plots"]:
-            zip_and_empty_folder(design_paths["Trajectory/Plots"], '.png')
-
-        return True
-
-    else:
-        return False
+def check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock):
+    with finalization_lock:
+        accepted_binders = [f for f in os.listdir(design_paths["Accepted"]) if f.endswith('.pdb')]
+
+        if len(accepted_binders) >= target_settings["number_of_final_designs"]:
+            print(f"Target number {str(len(accepted_binders))} of designs reached! Reranking...")
+
+            # clear the Ranked folder in case we added new designs in the meantime so we rerank them all
+            for f in os.listdir(design_paths["Accepted/Ranked"]):
+                os.remove(os.path.join(design_paths["Accepted/Ranked"], f))
+
+            # load dataframe of designed binders
+            with mpnn_csv_lock:
+                design_df = pd.read_csv(mpnn_csv)
+            design_df = design_df.sort_values('Average_i_pTM', ascending=False)
+
+            # create final csv dataframe to copy matched rows, initialize with the column labels
+            final_df = pd.DataFrame(columns=final_labels)
+
+            # check the ranking of the designs and copy them with new ranked IDs to the folder
+            rank = 1
+            for _, row in design_df.iterrows():
+                for binder in accepted_binders:
+                    target_settings["binder_name"], model = binder.rsplit('_model', 1)
+                    if target_settings["binder_name"] == row['Design']:
+                        # rank and copy into ranked folder
+                        row_data = {'Rank': rank, **{label: row[label] for label in design_labels}}
+                        final_df = pd.concat([final_df, pd.DataFrame([row_data])], ignore_index=True)
+                        old_path = os.path.join(design_paths["Accepted"], binder)
+                        new_path = os.path.join(design_paths["Accepted/Ranked"], f"{rank}_{target_settings['binder_name']}_model{model.rsplit('.', 1)[0]}.pdb")
+                        shutil.copyfile(old_path, new_path)
+
+                        rank += 1
+                        break
+
+            # save the final_df to final_csv
+            with final_csv_lock:
+                final_df.to_csv(final_csv, index=False)
+
+            # zip large folders to save space
+            if advanced_settings["zip_animations"]:
+                zip_and_empty_folder(design_paths["Trajectory/Animation"], '.html')
+
+            if advanced_settings["zip_plots"]:
+                zip_and_empty_folder(design_paths["Trajectory/Plots"], '.png')
+
+            return True
+
+        else:
+            return False
 
 # Load required helicity value
 def load_helicity(advanced_settings):
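Note that check_accepted_designs acquires mpnn_csv_lock and final_csv_lock while already holding finalization_lock, so all callers must use a consistent acquisition order: a task that grabbed one of the CSV locks first and then waited on finalization_lock could deadlock against a finalizing task. A sketch of the ordering this implies (lock names as in the diff; an inference from the code, not stated in the PR):

# Assumed ordering: finalization_lock is outermost; the CSV locks are only
# ever taken on their own, or while finalization_lock is already held.
with finalization_lock:
    with mpnn_csv_lock:
        pass  # read mpnn_csv
    with final_csv_lock:
        pass  # write final_csv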
@@ -282,10 +295,11 @@ def load_af2_models(af_multimer_setting):
     return design_models, prediction_models, multimer_validation
 
 # create csv for insertion of data
-def create_dataframe(csv_file, columns):
-    if not os.path.exists(csv_file):
-        df = pd.DataFrame(columns=columns)
-        df.to_csv(csv_file, index=False)
+def create_dataframe(csv_file, columns, lock):
+    with lock:
+        if not os.path.exists(csv_file):
+            df = pd.DataFrame(columns=columns)
+            df.to_csv(csv_file, index=False)
 
 # insert row of statistics into csv
 def insert_data(csv_file, data_array):
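create_dataframe extends the same create-once-under-lock pattern to any shared stats CSV; insert_data itself is collapsed in this diff, so its callers presumably hold the matching lock around each row insertion. A usage sketch with illustrative names (file path and columns are not from the PR):

# Hypothetical usage: initialize a shared stats CSV from any array task.
mpnn_csv = "output/mpnn_design_stats.csv"
mpnn_csv_lock = make_csv_lock(mpnn_csv)

# Safe to call from every array task: only the first creates the file.
create_dataframe(mpnn_csv, ["Design", "Average_i_pTM"], mpnn_csv_lock)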