Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from tqdm import tqdm
import pandas as pd
import argschema as ags
from morph_utils.ccf import projection_matrix_for_swc
import numpy as np
from morph_utils.proj_mat_utils import roll_up_proj_mat,normalize_projection_columns_per_cell


class IO_Schema(ags.ArgSchema):
output_directory = ags.fields.OutputDir(description="output directory")
Expand All @@ -12,21 +14,6 @@ class IO_Schema(ags.ArgSchema):
normalize_proj_mat = ags.fields.Boolean(default=True)


def normalize_projection_columns_per_cell(input_df, projection_column_identifiers=['ipsi', 'contra']):
"""
:param input_df: input projection df
:param projection_column_identifiers: list of identifiers for projection columns. i.e. strings that identify projection columns from metadata columns
:return: normalized projection matrix
"""
proj_cols = [c for c in input_df.columns if any([ider in c for ider in projection_column_identifiers])]
input_df[proj_cols] = input_df[proj_cols].fillna(0)

res = input_df[proj_cols].T / input_df[proj_cols].sum(axis=1)
input_df[proj_cols] = res.T

return input_df


def main(output_directory,
output_projection_csv,
projection_threshold,
Expand Down Expand Up @@ -54,6 +41,7 @@ def main(output_directory,
# proj_df_mask = pd.DataFrame(branch_and_tip_projection_records).T.fillna(0)

proj_df.to_csv(output_projection_csv)
roll_up_proj_mat(infile=output_projection_csv, outfile=output_projection_csv.replace(".csv",'_rollup.csv'))
# proj_df_mask.to_csv(output_projection_csv_tip_branch_mask)

if projection_threshold != 0:
Expand All @@ -67,6 +55,7 @@ def main(output_directory,
proj_df_arr[proj_df_arr < projection_threshold] = 0
proj_df = pd.DataFrame(proj_df_arr, columns=proj_df.columns, index=proj_df.index)
proj_df.to_csv(output_projection_csv)
roll_up_proj_mat(output_projection_csv, output_projection_csv.replace(".csv","_rollup.csv"))

# proj_df_mask_arr = proj_df_mask.values
# proj_df_mask_arr[proj_df_mask_arr < projection_threshold] = 0
Expand All @@ -79,6 +68,7 @@ def main(output_directory,

proj_df = normalize_projection_columns_per_cell(proj_df)
proj_df.to_csv(output_projection_csv)
roll_up_proj_mat(infile=output_projection_csv, outfile=output_projection_csv.replace(".csv",'_rollup.csv'))

# proj_df_mask = normalize_projection_columns_per_cell(proj_df_mask)
# proj_df_mask.to_csv(output_projection_csv_tip_branch_mask)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import subprocess
from morph_utils.ccf import projection_matrix_for_swc
from morph_utils.proj_mat_utils import roll_up_proj_mat


class IO_Schema(ags.ArgSchema):
Expand Down Expand Up @@ -88,17 +89,17 @@ def main(ccf_swc_directory,
results.append(res)

else:

this_output_projection_csv = os.path.join(single_sp_proj_dir, swc_fn.replace(".swc",".csv"))
this_output_projection_csv_check = os.path.join(single_sp_proj_dir, swc_fn.replace(".swc",f"_{mask_method}.csv"))

if not os.path.exists(this_output_projection_csv):
if not os.path.exists(this_output_projection_csv_check):

job_file = os.path.join(job_dir,swc_fn.replace(".swc",".sh"))
log_file = os.path.join(job_dir,swc_fn.replace(".swc",".log"))

command = "morph_utils_extract_projection_matrix_single_cell "
command = command+ f" --input_swc_file '{swc_pth}'"
command = command+ f" --output_projection_csv {this_output_projection_csv}"
command = command+ f" --output_projection_csv '{this_output_projection_csv}'"
command = command+ f" --projection_threshold {projection_threshold}"
command = command+ f" --normalize_proj_mat {normalize_proj_mat}"
command = command+ f" --mask_method {mask_method}"
Expand All @@ -114,15 +115,15 @@ def main(ccf_swc_directory,
command_list = [activate_command, command]

slurm_kwargs = {
"--job-name": f"{swc_fn}",
"--job-name": f"'{swc_fn}'",
"--mail-type": "NONE",
"--cpus-per-task": "1",
"--nodes": "1",
"--kill-on-invalid-dep": "yes",
"--mem": "24gb",
"--time": "1:00:00",
"--partition": "celltypes",
"--output": log_file
"--output": f"'{log_file}'"
}

dag_node = {
Expand All @@ -147,15 +148,20 @@ def main(ccf_swc_directory,
job_f.write(val)
job_f.write('\n')


command = "sbatch {}".format(job_file)
command_list = command.split(" ")
result = subprocess.run(command_list, stdout=subprocess.PIPE)
# print("Going to submit this job file")
# print(job_file)
command_list = ["sbatch", job_file]
result = subprocess.run(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

std_out = result.stdout.decode('utf-8')
std_err = result.stderr.decode('utf-8')

job_id = std_out.split("Submitted batch job ")[-1].replace("\n", "")
single_cell_job_ids.append(job_id)
# time.sleep(0.1)
if result.returncode != 0:
print("Error submitting job:", std_err)
else:
job_id = std_out.split("Submitted batch job ")[-1].strip()
single_cell_job_ids.append(job_id)


if run_host!='local':
# aggregate single projection files into proj mat
Expand Down Expand Up @@ -206,18 +212,19 @@ def main(ccf_swc_directory,
job_f.write(val)
job_f.write('\n')

command = "sbatch --dependency=afterany"
for p_jid in single_cell_job_ids:
command = command + f":{p_jid}"
command = command + " {}".format(job_file)
command_list = command.split(" ")
# print(command)
result = subprocess.run(command_list, stdout=subprocess.PIPE)
dependency_str = f"afterany:{':'.join(single_cell_job_ids)}"
command_list = ["sbatch", f"--dependency={dependency_str}", job_file]
result = subprocess.run(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
std_out = result.stdout.decode('utf-8')
std_err = result.stderr.decode('utf-8')

if result.returncode != 0:
print("Error submitting dependent job:", std_err)
else:
job_id = std_out.split("Submitted batch job ")[-1].strip()
print(f"Submitted aggregation job ID: {job_id}")


job_id = std_out.split("Submitted batch job ")[-1].replace("\n", "")


if results != []:

output_projection_csv = output_projection_csv.replace(".csv", f"_{mask_method}.csv")
Expand All @@ -235,6 +242,7 @@ def main(ccf_swc_directory,
# proj_df_mask = pd.DataFrame(branch_and_tip_projection_records).T.fillna(0)

proj_df.to_csv(output_projection_csv)
roll_up_proj_mat(output_projection_csv, output_projection_csv.replace(".csv","_rollup.csv"))
# proj_df_mask.to_csv(output_projection_csv_tip_branch_mask)

if projection_threshold != 0:
Expand All @@ -248,6 +256,8 @@ def main(ccf_swc_directory,
proj_df_arr[proj_df_arr < projection_threshold] = 0
proj_df = pd.DataFrame(proj_df_arr, columns=proj_df.columns, index=proj_df.index)
proj_df.to_csv(output_projection_csv)
roll_up_proj_mat(output_projection_csv, output_projection_csv.replace(".csv","_rollup.csv"))


# proj_df_mask_arr = proj_df_mask.values
# proj_df_mask_arr[proj_df_mask_arr < projection_threshold] = 0
Expand All @@ -260,6 +270,7 @@ def main(ccf_swc_directory,

proj_df = normalize_projection_columns_per_cell(proj_df)
proj_df.to_csv(output_projection_csv)
roll_up_proj_mat(output_projection_csv, output_projection_csv.replace(".csv","_rollup.csv"))

# proj_df_mask = normalize_projection_columns_per_cell(proj_df_mask)
# proj_df_mask.to_csv(output_projection_csv_tip_branch_mask)
Expand Down
Empty file added morph_utils/proj_mat_utils.py
Empty file.