diff --git a/README.md b/README.md index 6d6d432..1d3ce6a 100644 --- a/README.md +++ b/README.md @@ -27,10 +27,6 @@ If you'd like to reproduce this analysis on our data, check out the following do - [Preprocessing documentation](./docs/PREPROCESSING.md) for generating pointclouds and SDFs from our input movies (step 2). - [Development documentation](./docs/DEVELOPMENT.md) for guidance working on the code in this repository. -# Using the models - -Coming soon - # Quickstart Use the cytodl api to run any given experiment. For example, train a rotation invariant point cloud autoencoder on the PCNA dataset diff --git a/configs/results/npm1.yaml b/configs/results/npm1.yaml index 0a4ad9f..5f65f60 100644 --- a/configs/results/npm1.yaml +++ b/configs/results/npm1.yaml @@ -26,4 +26,12 @@ data_paths: "/data/npm1/pc.yaml", ] classification_label: ["STR_connectivity_cc_thresh"] -regression_label: ["mean_surface_area", "mean_volume", "avg_dists", "std_dists"] +regression_label: + [ + "mean_surface_area", + "std_surface_area", + "mean_volume", + "std_volume", + "avg_dists", + "std_dists", + ] diff --git a/configs/results/npm1_64_res.yaml b/configs/results/npm1_64_res.yaml index a78a156..2766629 100644 --- a/configs/results/npm1_64_res.yaml +++ b/configs/results/npm1_64_res.yaml @@ -3,11 +3,6 @@ image_path: ./morphology_appropriate_representation_learning/preprocessed_data/n pc_path: ./morphology_appropriate_representation_learning/preprocessed_data/npm1_64_res/manifest.csv model_checkpoints: [ - "./morphology_appropriate_representation_learning/model_checkpoints/npm1/Classical_image_seg.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/npm1/Rotation_invariant_image_seg.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/npm1/Classical_image_SDF.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/npm1/Rotation_invariant_image_SDF.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/npm1/Rotation_invariant_pointcloud_SDF.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/npm1_64_res/Classical_image_seg.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/npm1_64_res/Rotation_invariant_image_seg.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/npm1_64_res/Classical_image_SDF.ckpt", @@ -21,19 +16,9 @@ names: [ "Classical_image_seg", "Rotation_invariant_image_seg", "Classical_image_SDF", "Rotation_invariant_image_SDF", "Rotation_invariant_pointcloud_SDF", - "Classical_image_seg_64", - "Rotation_invariant_image_seg_64", - "Classical_image_SDF_64", - "Rotation_invariant_image_SDF_64", - "Rotation_invariant_pointcloud_SDF_64", ] data_paths: [ - "/data/npm1/classical_image_seg.yaml", - "/data/npm1/so3_image_seg.yaml", - "/data/npm1/classical_image_sdf.yaml", - "/data/npm1/so3_image_sdf.yaml", - "/data/npm1/pc.yaml", "/data/npm1_64_res/classical_image_seg.yaml", "/data/npm1_64_res/so3_image_seg.yaml", "/data/npm1_64_res/classical_image_sdf.yaml", @@ -41,4 +26,12 @@ data_paths: "/data/npm1_64_res/pc.yaml", ] classification_label: ["STR_connectivity_cc_thresh"] -regression_label: ["mean_surface_area", "mean_volume", "avg_dists", "std_dists"] +regression_label: + [ + "mean_surface_area", + "std_surface_area", + "mean_volume", + "std_volume", + "avg_dists", + "std_dists", + ] diff --git a/configs/results/other_polymorphic.yaml b/configs/results/other_polymorphic.yaml index 78b7e19..1b31cf3 100644 --- a/configs/results/other_polymorphic.yaml +++ b/configs/results/other_polymorphic.yaml @@ -3,27 +3,27 
@@ image_path: ./morphology_appropriate_representation_learning/preprocessed_data/o pc_path: ./morphology_appropriate_representation_learning/preprocessed_data/other_polymorphic/manifest.csv model_checkpoints: [ - "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Rotation_invariant_pointcloud_SDF.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Rotation_invariant_image_SDF.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Classical_image_seg.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Rotation_invariant_image_seg.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Classical_image_SDF.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Classical_image_seg.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Rotation_invariant_image_SDF.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Rotation_invariant_pointcloud_SDF.ckpt", ] names: [ - "Rotation_invariant_pointcloud_SDF", - "Rotation_invariant_image_SDF", + "Classical_image_seg", "Rotation_invariant_image_seg", "Classical_image_SDF", - "Classical_image_seg", + "Rotation_invariant_image_SDF", + "Rotation_invariant_pointcloud_SDF", ] data_paths: [ - "/data/other_polymorphic/pc.yaml", - "/data/other_polymorphic/so3_image_sdf.yaml", + "/data/other_polymorphic/classical_image_seg.yaml", "/data/other_polymorphic/so3_image_seg.yaml", "/data/other_polymorphic/classical_image_sdf.yaml", - "/data/other_polymorphic/classical_image_seg.yaml", + "/data/other_polymorphic/so3_image_sdf.yaml", + "/data/other_polymorphic/pc.yaml", ] classification_label: ["structure_name"] regression_label: diff --git a/docs/DEVELOPMENT.md b/docs/DEVELOPMENT.md index 735752b..35e7ce4 100644 --- a/docs/DEVELOPMENT.md +++ b/docs/DEVELOPMENT.md @@ -19,15 +19,16 @@ benchmarking_representations │ │ ├── cellpack │ │ ├── chandrasekaran_et_al │ │ ├── data +│   │   │   ├── cellpack <- Config files associated with cellpack simulations │   │   │   ├── get_datamodules.py <- Get final list of datamodules per dataset │ │ │ └── preprocessing <- Preprocessing scripts to generate point clouds and SDFs -│ │ ├── features │ │   ├── features <- Metrics for benchmarking each model │ │   │   ├── archetype.py <- Archetype analysis functions │ │   │   ├── classification.py <- Test set classification accuracies using logistic regression classifiers │ │   │   ├── outlier_compactness.py <- Intrinsic dimensionality calculation and outlier classification │ │   │   ├── reconstruction.py <- Functions for reconstruction viz across models │ │   │   ├── regression.py <- Linear regression test set r^2 +│ │   │   ├── evolve.py <- Evolution energy and interpolation distance metrics │ │   │   ├── rotation_invariance.py <- Sensitivity to four 90 degree rotations in embedding space │ │   │   └── plot.py <- Polar plot viz across metrics │   │   ├── models <- Training and inference scripts @@ -36,8 +37,9 @@ benchmarking_representations │   │   │   ├── save_embeddings.py <- Save embeddings using inference functions │   │   │   ├── load_models.py <- Load trained models based on checkpoint paths │   │   │   └── compute_features.py <- Compute multi-metric features for each model based on saved embeddings -│ │ ├── notebooks <- Jupyter notebooks │ │ └── visualization +│   │   │   
├── mitsuba_render_image.py <- Mitsuba rendering for image segmentations +│   │   │   └── mitsuba_render_pc.py <- Mitsuba rendering for pointclouds │ └── pointcloudutils │      ├── datamodules <- Custom datamodules │      │   └── cellpack.py <- CellPACK data specific datamodule diff --git a/docs/PREPROCESSING.md b/docs/PREPROCESSING.md index dfdb0ae..76de45b 100644 --- a/docs/PREPROCESSING.md +++ b/docs/PREPROCESSING.md @@ -52,7 +52,8 @@ src # Polymorphic structures: Generate SDFs -Use the segmentation data for polymorphic structures as input to the SDF generation step. +Use the segmentation data for polymorphic structures as input to the SDF generation step. + ``` src └── br diff --git a/docs/USAGE.md b/docs/USAGE.md index 07148ad..66cb405 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -68,7 +68,11 @@ export CYTODL_CONFIG_PATH=$PWD/configs/ ## Steps to download pre-processed data -Coming soon. +Preprocessing the data can take several hours. To skip this step, download the preprocessed data for each dataset. This will use around 740 GB. + +```bash +aws s3 cp --no-sign-request --recursive s3://allencell/aics/morphology_appropriate_representation_learning/preprocessed_data/ morphology_appropriate_representation_learning/preprocessed_data/ +``` ## Steps to train models @@ -130,7 +134,7 @@ $ pwd 3. Download the 30 models. This will use almost 4 GB. ```bash aws s3 cp --no-sign-request --recursive s3://allencell/aics/morphology_appropriate_representation_learning/model_checkpoints/ morphology_appropriate_representation_learning/model_checkpoints/ ``` ### Option 2: Download individual checkpoints @@ -141,25 +145,23 @@ By default, the checkpoint files are expected in `benchmarking_representations/m ## Compute embeddings -To compute embeddings from the trained models, update the data paths in the [datamodule files](../configs/data/) to point to your pre-processed data. -Then, run the following commands. +Skip to the [next section](#3-interpretability-analysis) if you'd like to just use our pre-computed embeddings. Otherwise, to compute embeddings from the trained models, update the data paths in the [datamodule files](../configs/data/) to point to your pre-processed data. Then, run the following commands. 
-| Dataset | Embedding command | -| ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------- | -| cellpack | `python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf False --dataset_name cellpack --batch_size 5 --debug False` | -| npm1_perturb | `python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf True --dataset_name npm1_perturb --batch_size 5 --debug False` | -| npm1 | `python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf True --dataset_name npm1 --batch_size 5 --debug False` | -| other_polymorphic | `python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf True --dataset_name other_polymorphic --batch_size 5 --debug False` | -| other_punctate | `python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf False --dataset_name other_punctate --batch_size 5 --debug False` | -| pcna | `python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf False --dataset_name pcna --batch_size 5 --debug False` | +| Dataset | Embedding command | +| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| cellpack | `python src/br/analysis/run_embeddings.py --save_path "./outputs_cellpack/" --sdf False --dataset_name cellpack --batch_size 5 --debug False` | +| npm1_perturb | `python src/br/analysis/run_embeddings.py --save_path "./outputs_npm1_perturb/" --sdf True --dataset_name npm1_perturb --batch_size 5 --debug False` | +| npm1 | `python src/br/analysis/run_embeddings.py --save_path "./outputs_npm1/" --sdf True --dataset_name npm1 --batch_size 5 --debug False` | +| npm1_64_res | `python src/br/analysis/run_embeddings.py --save_path "./outputs_npm1_64_res/" --sdf True --dataset_name npm1_64_res --batch_size 5 --debug False --eval_scaled_img_resolution 64` | +| other_polymorphic | `python src/br/analysis/run_embeddings.py --save_path "./outputs_other_polymorphic/" --sdf True --dataset_name other_polymorphic --batch_size 5 --debug False` | +| other_punctate | `python src/br/analysis/run_embeddings.py --save_path "./outputs_other_punctate/" --sdf False --dataset_name other_punctate --batch_size 5 --debug False` | +| pcna | `python src/br/analysis/run_embeddings.py --save_path "./outputs_pcna/" --sdf False --dataset_name pcna --batch_size 5 --debug False` | # 3. Interpretability analysis ## Steps to download pre-computed embeddings -Many of the results from the paper can be reproduced just from the embeddings produced by the model. However, some results rely on statistics about the costs of running the models, which are not included with the embeddings. - -You can download our pre-computed embeddings here. +Many of the results from the paper can be reproduced just from the embeddings produced by the model. You can download our pre-computed embeddings here. - [cellPACK synthetic dataset](https://open.quiltdata.com/b/allencell/tree/aics/morphology_appropriate_representation_learning/model_embeddings/cellpack/) - [DNA replication foci dataset](https://open.quiltdata.com/b/allencell/tree/aics/morphology_appropriate_representation_learning/model_embeddings/pcna/) @@ -172,24 +174,30 @@ You can download our pre-computed embeddings here. 1. To compute benchmarking features from the embeddings and trained models, run the following commands. 
-| Dataset | Benchmarking features | -| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| cellpack | `python src/br/analysis/run_features.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --sdf False --dataset_name "cellpack" --debug False` | -| npm1 | `python src/br/analysis/run_features.py --save_path "./outputs_npm1/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1" --sdf True --dataset_name "npm1" --debug False` | -| npm1_64_res | `python src/br/analysis/run_features.py --save_path "./outputs_npm1_64_res/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1_64_res" --sdf True --dataset_name "npm1_64_res" --debug False` | -| other_polymorphic | `python src/br/analysis/run_features.py --save_path "./outputs_other_polymorphic/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_polymorphic" --sdf True --dataset_name "other_polymorphic" --debug False` | -| other_punctate | `python src/br/analysis/run_features.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate" --sdf False --dataset_name "other_punctate" --debug False` | -| pcna | `python src/br/analysis/run_features.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --sdf False --dataset_name "pcna" --debug False` | +| Dataset | Benchmarking features command | +| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| cellpack | `python src/br/analysis/run_features.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --sdf False --dataset_name "cellpack" --debug False` | +| npm1 | `python src/br/analysis/run_features.py --save_path "./outputs_npm1/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1" --sdf True --dataset_name "npm1" --debug False` | +| npm1_64_res | `python src/br/analysis/run_features.py --save_path "./outputs_npm1_64_res/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1_64_res" --sdf True --dataset_name "npm1_64_res" --debug False --eval_scaled_img_resolution 64` | +| other_polymorphic | `python src/br/analysis/run_features.py --save_path "./outputs_other_polymorphic/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_polymorphic" --sdf True --dataset_name "other_polymorphic" --debug False` | +| other_punctate | `python src/br/analysis/run_features.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate" --sdf False --dataset_name "other_punctate" --debug False` | +| pcna | `python src/br/analysis/run_features.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" 
--sdf False --dataset_name "pcna" --debug False` | + +To combine features from different runs and compare, run + +``` +python src/br/analysis/run_features_combine.py --feature_path_1 './outputs_npm1/' --feature_path_2 './outputs_npm1_64_res/' --save_path "./outputs_npm1_combine/" --dataset_name_1 "npm1" --dataset_name_2 "npm1_64_res" +``` 2. To run analysis like latent walks and archetype analysis on the embeddings and trained models, run the following commands. -| Dataset | Benchmarking features | -| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| cellpack | `python src/br/analysis/run_analysis.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --dataset_name "cellpack" --run_name "Rotation_invariant_pointcloud_jitter" --sdf False` | -| npm1 | `python src/br/analysis/run_analysis.py --save_path "./outputs_npm1/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1" --dataset_name "npm1" --run_name "Rotation_invariant_pointcloud_SDF" --sdf True` | -| other_polymorphic | `python src/br/analysis/run_analysis.py --save_path "./outputs_other_polymorphic/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_polymorphic" --dataset_name "other_polymorphic" --run_name "Rotation_invariant_pointcloud_SDF" --sdf True` | -| other_punctate | `python src/br/analysis/run_analysis.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate" --dataset_name "other_punctate" --run_name "Rotation_invariant_pointcloud_structurenorm" --sdf False` | -| pcna | `python src/br/analysis/run_analysis.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" --run_name "Rotation_invariant_pointcloud_jitter" --sdf False` | +| Dataset | Analysis command | +| ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| cellpack | `python src/br/analysis/run_analysis.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --dataset_name "cellpack" --run_name "Rotation_invariant_pointcloud_jitter" --sdf False --pacmap False` | +| npm1 | `python src/br/analysis/run_analysis.py --save_path "./outputs_npm1/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1" --dataset_name "npm1" --run_name "Rotation_invariant_pointcloud_SDF" --sdf True --pacmap False` | +| other_polymorphic | `python src/br/analysis/run_analysis.py --save_path "./outputs_other_polymorphic/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_polymorphic" --dataset_name "other_polymorphic" --run_name "Rotation_invariant_pointcloud_SDF" --sdf True --pacmap True` | +| other_punctate | `python src/br/analysis/run_analysis.py --save_path "./outputs_other_punctate/" 
--embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate" --dataset_name "other_punctate" --run_name "Rotation_invariant_pointcloud_structurenorm" --sdf False --pacmap True` | +| pcna | `python src/br/analysis/run_analysis.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" --run_name "Rotation_invariant_pointcloud_jitter" --sdf False --pacmap False` | 3. To run drug perturbation analysis using the pre-computed features, run diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 379455f..b28bd74 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -9,8 +9,10 @@ import matplotlib.pyplot as plt import mesh_to_sdf import numpy as np +import pacmap import pandas as pd import pyvista as pv +import seaborn as sns import torch import trimesh import yaml @@ -157,7 +159,7 @@ def setup_gpu(): print("No suitable GPU or MIG ID found. Exiting...") -def setup_evaluation_params(manifest, run_names): +def setup_evaluation_params(manifest, run_names, eval_scaled_img_resolution=None): """Return evaluation params related to. 1. loss_eval_list - which loss to use for each model (Defaults to Chamfer loss) @@ -171,6 +173,7 @@ def setup_evaluation_params(manifest, run_names): eval_scaled_img_params = [{}] * len(run_names) if "SDF" in "\t".join(run_names): + eval_scaled_img = [True] * len(run_names) loss_eval_list = [torch.nn.MSELoss(reduction="none")] * len(run_names) sample_points_list = [False] * len(run_names) skew_scale = None @@ -184,18 +187,21 @@ def setup_evaluation_params(manifest, run_names): skew_scale, ) - eval_scaled_img_resolution = 32 + if not eval_scaled_img_resolution: + eval_scaled_img_resolution = 32 gt_mesh_dir = manifest["mesh_folder"].iloc[0] gt_sampled_pts_dir = manifest["pointcloud_folder"].iloc[0] + if len(os.listdir(gt_sampled_pts_dir)) == 1: + gt_sampled_pts_dir = gt_sampled_pts_dir + "0/" gt_scale_factor_dict_path = manifest["scale_factor"].iloc[0] eval_scaled_img_params = [] for name_ in run_names: if "seg" in name_: model_type = "seg" - elif "SDF" in name_: - model_type = "sdf" elif "pointcloud" in name_: model_type = "iae" + else: + model_type = "sdf" eval_scaled_img_params.append( { "eval_scaled_img_model_type": model_type, @@ -215,7 +221,13 @@ def setup_evaluation_params(manifest, run_names): sample_points_list.append(True) else: sample_points_list.append(False) - return eval_scaled_img, eval_scaled_img_params, loss_eval_list, sample_points_list, skew_scale + return ( + eval_scaled_img, + eval_scaled_img_params, + loss_eval_list, + sample_points_list, + skew_scale, + ) def setup_evolve_params(run_names, data_config_list, keys): @@ -231,9 +243,9 @@ def setup_evolve_params(run_names, data_config_list, keys): for name_ in run_names: if "seg" in name_: model_type = "seg" - elif "pointcloud_SDF" in name_: + elif "pointcloud" in name_: model_type = "iae" - elif "SDF" in name_: + else: model_type = "sdf" eval_meshed_img_model_type.append(model_type) @@ -622,6 +634,44 @@ def archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features) arch_dict.to_csv(this_save_path / "archetypes.csv") +def make_pacmap(this_save_path, all_ret, feats_archs): + + cols = [i for i in all_ret.columns if "mu" in i] + feats = all_ret[cols].values + embedding = pacmap.PaCMAP(n_components=2, n_neighbors=10, MN_ratio=0.5, FP_ratio=2.0) + X_transformed = embedding.fit_transform(feats, init="pca") 
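+ # embed the archetypes into the fitted PaCMAP space: transform() positions new points relative to the fit, with `basis` set to the array the reducer was fitted on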
+ + archs_transform = embedding.transform(feats_archs, init="pca", basis=feats) + labels = all_ret["structure_name"].values + colors = sns.color_palette("Paired", len(np.unique(labels))) + + cdict = {i: colors[j] for j, i in enumerate(np.unique(labels))} + + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + for g in np.unique(labels): + ix = np.where(labels == g) + ax.scatter( + X_transformed[ix, 0], + X_transformed[ix, 1], + c=cdict[g], + label=g, + s=0.6, + alpha=0.6, + ) + ax.legend() + lgnd = plt.legend(loc="upper right", numpoints=1, fontsize=10) + + # change the marker size manually for both lines + for handle in lgnd.legend_handles: + handle.set_sizes([6.0]) + + ax.scatter(archs_transform[:, 0], archs_transform[:, 1], c="k", s=20, marker="x") + ax.set_xlabel("PaCMAP dim 1") + ax.set_ylabel("PaCMAP dim 2") + fig.savefig(this_save_path / "pacmap_archetypes.png", bbox_inches="tight", dpi=300) + fig.savefig(this_save_path / "pacmap_archetypes.pdf", bbox_inches="tight", dpi=300) + + def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, device, save_path): with torch.no_grad(): for j, model in enumerate(all_models): @@ -652,7 +702,10 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d ) batch["points"] = uni_sample_points xhat, z, z_params = model( - move(batch, device), decode=True, inference=True, return_params=True + move(batch, device), + decode=True, + inference=True, + return_params=True, ) recon = xhat[this_key].detach().cpu().numpy().squeeze() recon = recon.reshape( @@ -798,7 +851,11 @@ def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax, dataset_na if (index_ == 2) and (canon_z_ind != z_ind): xy_inds = [i for i in [0, 1, 2] if i != canon_z_ind] axes[index_].scatter( - this_p[:, xy_inds[0]], this_p[:, xy_inds[1]], c="black", s=2, alpha=0.5 + this_p[:, xy_inds[0]], + this_p[:, xy_inds[1]], + c="black", + s=2, + alpha=0.5, ) else: if not cmap: @@ -874,7 +931,14 @@ def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax, dataset_na df=all_df_input, pcts=[5, 95] ) fig = _plot_pc( - input, recon, recon_canonical, struct, cmap, vmin, vmax, dataset_name + input, + recon, + recon_canonical, + struct, + cmap, + vmin, + vmax, + dataset_name, ) this_save_path_ = Path(reconstructions_path) / Path(m) @@ -957,7 +1021,9 @@ def save_supplemental_figure_sdf_reconstructions(df, test_ids, reconstructions_p gt_seg = gt_segs[i] gt_sdf = np.clip(gt_sdfs[i], -2, 2) gt_sdf_i = gt_test_i_sdfs[i].reshape( - eval_scaled_img_resolution, eval_scaled_img_resolution, eval_scaled_img_resolution + eval_scaled_img_resolution, + eval_scaled_img_resolution, + eval_scaled_img_resolution, ) row_index = i recons = [] @@ -1003,7 +1069,11 @@ def save_supplemental_figure_sdf_reconstructions(df, test_ids, reconstructions_p axs[row_index, i + 2].set_title("") # run_to_displ_name[model_order[i]]) axs[row_index, 4].imshow( - gt_sdf[:, :, mid_slice_].T, cmap="seismic", origin="lower", vmin=-2, vmax=2 + gt_sdf[:, :, mid_slice_].T, + cmap="seismic", + origin="lower", + vmin=-2, + vmax=2, ) axs[row_index, 4].axis("off") axs[row_index, 4].set_title("") # (f'GT SDF CellId {c}') @@ -1014,7 +1084,9 @@ def save_supplemental_figure_sdf_reconstructions(df, test_ids, reconstructions_p axs[row_index, i + 5].set_title("") # run_to_displ_name[model_order[i]]) axs[row_index, 7].imshow( - gt_sdf_i[:, :, mid_slice_].T.clip(-0.5, 0.5), cmap="seismic", origin="lower" + gt_sdf_i[:, :, mid_slice_].T.clip(-0.5, 0.5), + cmap="seismic", + origin="lower", ) 
axs[row_index, 7].axis("off") axs[row_index, 7].set_title("") # (f'GT SDF CellId {c}') diff --git a/src/br/analysis/run_analysis.py b/src/br/analysis/run_analysis.py index c018690..91d98b9 100644 --- a/src/br/analysis/run_analysis.py +++ b/src/br/analysis/run_analysis.py @@ -11,6 +11,7 @@ dataset_specific_subsetting, latent_walk_polymorphic, latent_walk_save_recons, + make_pacmap, pseudo_time_analysis, setup_gpu, str2bool, @@ -76,6 +77,9 @@ def main(args): this_save_path = Path(args.save_path) / Path("archetypes") this_save_path.mkdir(parents=True, exist_ok=True) + if args.pacmap: + make_pacmap(this_save_path, all_ret, archetypes_df) + if args.sdf: archetypes_polymorphic(this_save_path, archetypes_df, all_ret, matrix) else: @@ -93,7 +97,10 @@ def main(args): ) parser.add_argument("--run_name", type=str, required=True, help="Name of model") parser.add_argument( - "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." + "--embeddings_path", + type=str, + required=True, + help="Path to the saved embeddings.", ) parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") parser.add_argument( @@ -102,6 +109,13 @@ def main(args): required=True, help="boolean indicating whether the model involves SDFs", ) + parser.add_argument( + "--pacmap", + type=str2bool, + required=False, + default=False, + help="boolean indicating whether to plot a pacmap projection of the representations and archetypes", + ) args = parser.parse_args() # Validate that required paths are provided diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 673ab6b..82148c0 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -1,4 +1,3 @@ -# Free up cache import argparse import os import sys @@ -30,7 +29,7 @@ def main(args): loss_eval_list, sample_points_list, skew_scale, - ) = setup_evaluation_params(manifest, run_names) + ) = setup_evaluation_params(manifest, run_names, args.eval_scaled_img_resolution) # make save path directory Path(args.save_path).mkdir(parents=True, exist_ok=True) @@ -74,6 +73,13 @@ def main(args): parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") parser.add_argument("--batch_size", type=int, default=2, help="Batch size for processing.") parser.add_argument("--debug", type=str2bool, default=True, help="Enable debug mode.") + parser.add_argument( + "--eval_scaled_img_resolution", + type=int, + default=None, + required=False, + help="Resolution for SDF reconstruction", + ) args = parser.parse_args() diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index 09b5490..6bf74ca 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -1,4 +1,3 @@ -# Free up cache import argparse import os import sys @@ -17,6 +16,11 @@ from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_emissions +REMOVE_RESULT = { + "cellpack": "Rotation_invariant_pointcloud_jitter", + "pcna": "Rotation_invariant_pointcloud_jitter", +} + def main(args): # Setup GPUs and set the device @@ -41,83 +45,94 @@ def main(args): ) = get_data_and_models(args.dataset_name, batch_size, config_path + "/results/", args.debug) max_embed_dim = min(latent_dims) + if args.dataset_name in REMOVE_RESULT.keys(): + remove_name = REMOVE_RESULT[args.dataset_name] + ind = run_names.index(remove_name) + del data_list[ind] + del all_models[ind] + del run_names[ind] + del model_sizes[ind] + # make save path 
directory Path(args.save_path).mkdir(parents=True, exist_ok=True) - # Save model sizes to CSV - sizes_ = pd.DataFrame() - sizes_["model"] = run_names - sizes_["model_size"] = model_sizes - sizes_.to_csv(os.path.join(args.save_path, "model_sizes.csv")) - - # Load evaluation params - ( - eval_scaled_img, - eval_scaled_img_params, - loss_eval_list, - sample_points_list, - skew_scale, - ) = setup_evaluation_params(manifest, run_names) - - # Save emission stats for each model - max_batches = 40 - save_emissions( - args.save_path, - data_list, - all_models, - run_names, - max_batches, - args.debug, - device, - loss_eval_list, - sample_points_list, - skew_scale, - eval_scaled_img, - eval_scaled_img_params, - ) - - # Compute multi-metric benchmarking params - ( - rot_inv_params, - compactness_params, - classification_params, - evolve_params, - regression_params, - ) = get_feature_params(config_path + "/results/", args.dataset_name, manifest, keys, run_names) - - metric_list = [ - "Rotation Invariance Error", - "Evolution Energy", - "Reconstruction", - "Classification", - "Compactness", - ] - if regression_params["target_cols"]: - metric_list.append("Regression") - - # Compute multi-metric benchmarking features - compute_features( - dataset=args.dataset_name, - results_path=config_path + "/results/", - embeddings_path=args.embeddings_path, - save_folder=args.save_path, - data_list=data_list, - all_models=all_models, - run_names=run_names, - use_sample_points_list=sample_points_list, - keys=keys, - device=device, - max_embed_dim=max_embed_dim, - splits_list=["train", "val", "test"], - compute_embeds=False, - classification_params=classification_params, - regression_params=regression_params, - metric_list=metric_list, - loss_eval_list=loss_eval_list, - evolve_params=evolve_params, - rot_inv_params=rot_inv_params, - compactness_params=compactness_params, - ) + if not args.skip_features: + # Save model sizes to CSV + sizes_ = pd.DataFrame() + sizes_["model"] = run_names + sizes_["model_size"] = model_sizes + sizes_.to_csv(os.path.join(args.save_path, "model_sizes.csv")) + + # Load evaluation params + ( + eval_scaled_img, + eval_scaled_img_params, + loss_eval_list, + sample_points_list, + skew_scale, + ) = setup_evaluation_params(manifest, run_names, args.eval_scaled_img_resolution) + + # Save emission stats for each model + max_batches = 40 + save_emissions( + args.save_path, + data_list, + all_models, + run_names, + max_batches, + args.debug, + device, + loss_eval_list, + sample_points_list, + skew_scale, + eval_scaled_img, + eval_scaled_img_params, + ) + + # Compute multi-metric benchmarking params + ( + rot_inv_params, + compactness_params, + classification_params, + evolve_params, + regression_params, + ) = get_feature_params( + config_path + "/results/", args.dataset_name, manifest, keys, run_names + ) + + metric_list = [ + "Rotation Invariance Error", + "Evolution Energy", + "Reconstruction", + "Classification", + "Compactness", + ] + if regression_params["target_cols"]: + metric_list.append("Regression") + + # Compute multi-metric benchmarking features + compute_features( + dataset=args.dataset_name, + results_path=config_path + "/results/", + embeddings_path=args.embeddings_path, + save_folder=args.save_path, + data_list=data_list, + all_models=all_models, + run_names=run_names, + use_sample_points_list=sample_points_list, + keys=keys, + device=device, + max_embed_dim=max_embed_dim, + splits_list=["train", "val", "test"], + compute_embeds=False, + classification_params=classification_params, + 
regression_params=regression_params, + metric_list=metric_list, + loss_eval_list=loss_eval_list, + evolve_params=evolve_params, + rot_inv_params=rot_inv_params, + compactness_params=compactness_params, + ) # Polar plot visualization # Load saved csvs @@ -130,7 +145,15 @@ def main(args): unique_metrics = [i for i in csvs if "classification" in i or "regression" in i] # Collect dataframe and make plots df, df_non_agg = collect_outputs(args.save_path, "std", run_names, csvs) - plot(args.save_path, df, run_names, args.dataset_name, "std", unique_metrics, df_non_agg) + plot( + args.save_path, + df, + run_names, + args.dataset_name, + "std", + unique_metrics, + df_non_agg, + ) if __name__ == "__main__": @@ -139,7 +162,10 @@ def main(args): "--save_path", type=str, required=True, help="Path to save the embeddings." ) parser.add_argument( - "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." + "--embeddings_path", + type=str, + required=True, + help="Path to the saved embeddings.", ) parser.add_argument( "--meta_key", @@ -156,6 +182,19 @@ def main(args): ) parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") parser.add_argument("--debug", type=str2bool, default=False, help="Enable debug mode.") + parser.add_argument( + "--skip_features", + type=str2bool, + default=False, + help="Boolean indicating whether to skip feature calculation and load pre-computed csvs", + ) + parser.add_argument( + "--eval_scaled_img_resolution", + type=int, + default=None, + required=False, + help="Resolution for SDF reconstruction", + ) args = parser.parse_args() @@ -171,4 +210,6 @@ def main(args): python src/br/analysis/run_features.py --save_path "./outputs/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --sdf False --dataset_name "pcna" python src/br/analysis/run_features.py --save_path "/outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --sdf False --dataset_name "cellpack" --debug False + + python src/br/analysis/run_features.py --save_path "./outputs_npm1_64_res_remake/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1_64_res" --sdf True --dataset_name "npm1_64_res" --debug False """ diff --git a/src/br/analysis/run_features_combine.py b/src/br/analysis/run_features_combine.py new file mode 100644 index 0000000..2afb1e0 --- /dev/null +++ b/src/br/analysis/run_features_combine.py @@ -0,0 +1,81 @@ +import argparse +import os +import sys +from pathlib import Path + +import pandas as pd + +from br.features.plot import collect_outputs, plot +from br.models.utils import get_all_configs_per_dataset + + +def main(args): + + Path(args.save_path).mkdir(parents=True, exist_ok=True) + + # Get config path from CYTODL_CONFIG_PATH + config_path = os.environ.get("CYTODL_CONFIG_PATH") + run_names_1 = get_all_configs_per_dataset(config_path + "/results/")[args.dataset_name_1][ + "names" + ] + + # Polar plot visualization + # Load saved csvs + csvs = [i for i in os.listdir(args.feature_path_1) if i.split(".")[-1] == "csv"] + csvs = [i.split(".")[0] for i in csvs] + # Remove non metric related csvs + csvs = [i for i in csvs if i not in run_names_1 and i not in ["image", "pcloud"]] + + for csv in csvs: + df1 = pd.read_csv(args.feature_path_1 + csv + ".csv") + df2 = pd.read_csv(args.feature_path_2 + csv + ".csv") + df2["model"] = df2["model"].apply(lambda x: x + "_2") + df = pd.concat([df1, df2], 
axis=0).reset_index(drop=True) + df.to_csv(args.save_path + csv + ".csv") + + run_names_2 = [i + "_2" for i in run_names_1] + run_names = run_names_1 + run_names_2 + csvs = [i for i in os.listdir(args.save_path) if i.split(".")[-1] == "csv"] + csvs = [i.split(".")[0] for i in csvs] + # Remove non metric related csvs + csvs = [i for i in csvs if i not in run_names_1 and i not in ["image", "pcloud"]] + + # classification and regression metrics are unique to each dataset + unique_metrics = [i for i in csvs if "classification" in i or "regression" in i] + # Collect dataframe and make plots + df, df_non_agg = collect_outputs(args.save_path, "std", run_names, csvs) + plot( + args.save_path, + df, + run_names, + args.dataset_name_1 + "_" + args.dataset_name_2, + "std", + unique_metrics, + df_non_agg, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script for computing features") + parser.add_argument( + "--feature_path_1", + type=str, + required=True, + help="Path to features for dataset 1.", + ) + parser.add_argument( + "--feature_path_2", + type=str, + required=True, + help="Path to features for dataset 2.", + ) + parser.add_argument("--save_path", type=str, required=True, help="Path to save results.") + parser.add_argument("--dataset_name_1", type=str, required=True, help="Name of the dataset 1.") + parser.add_argument("--dataset_name_2", type=str, required=True, help="Name of the dataset 2.") + args = parser.parse_args() + main(args) + + """ + Example run: + python src/br/analysis/run_features_combine.py --feature_path_1 './outputs_npm1_remake/' --feature_path_2 './outputs_npm1_64_res_remake/' --save_path "./outputs_npm1_combine/" --dataset_name_1 "npm1" --dataset_name_2 "npm1_64_res" + """ diff --git a/src/br/analysis/save_reconstructions.py b/src/br/analysis/save_reconstructions.py index 6b1bc8a..11700b2 100644 --- a/src/br/analysis/save_reconstructions.py +++ b/src/br/analysis/save_reconstructions.py @@ -1,4 +1,3 @@ -# Free up cache import argparse import os import sys @@ -32,7 +31,15 @@ "50b52c3e-4756-4684-a281-0141525ded9f", "8713eea5-da72-4644-96fe-ba8340edb67d", ], - "other_punctate": ["721646", "873680", "994027", "490385", "451974", "811336", "835431"], + "other_punctate": [ + "721646", + "873680", + "994027", + "490385", + "451974", + "811336", + "835431", + ], "npm1": ["964798", "661110", "644401", "967887", "703621"], "other_polymorphic": ["691110", "723687", "816468", "800894"], } @@ -96,10 +103,16 @@ def main(args): parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") parser.add_argument("--debug", type=str2bool, default=False, help="Enable debug mode.") parser.add_argument( - "--sdf", type=str2bool, default=True, help="Whether the experiments involve SDFs" + "--sdf", + type=str2bool, + default=True, + help="Whether the experiments involve SDFs", ) parser.add_argument( - "--test_ids", default=False, nargs="+", help="List of test set cellids to reconstruct" + "--test_ids", + default=False, + nargs="+", + help="List of test set cellids to reconstruct", ) parser.add_argument( "--generate_reconstructions", diff --git a/src/br/chandrasekaran_et_al/utils.py b/src/br/chandrasekaran_et_al/utils.py index b584b48..65d0844 100644 --- a/src/br/chandrasekaran_et_al/utils.py +++ b/src/br/chandrasekaran_et_al/utils.py @@ -201,24 +201,36 @@ def _plot(all_rep, save_path, run_names): x_order = ordered_drugs test = all_rep.sort_values(by="q_value").reset_index(drop=True) + test["1/q_value"] = test["q_value"].apply(lambda x: 
1 / x) g = sns.catplot( data=test, x="Drugs", - y="q_value", + y="1/q_value", kind="bar", hue="model", order=x_order, hue_order=run_names, - palette=["#A6ACE0", "#6277DB", "#D9978E", "#D8553B", "#2ED9FF", "#91db57", "#db57d3"], - aspect=2, + palette=[ + "#A6ACE0", + "#6277DB", + "#D9978E", + "#D8553B", + "#2ED9FF", + "#91db57", + "#db57d3", + ], + aspect=3, height=5, dodge=True, ) g.set_xticklabels(rotation=90) - g.set(ylim=(0, 0.1)) - plt.axhline(y=0.05, color="black") - g.set(ylabel="q value") + g.set(ylim=(0, 100)) + # the line at 1/q = 20 corresponds to the q = 0.05 significance threshold + plt.axhline(y=20, color="black", linestyle="--") + g.set(ylabel="1/q value") this_path = Path(save_path) Path(this_path).mkdir(parents=True, exist_ok=True) g.savefig(this_path / "q_values.png", dpi=300, bbox_inches="tight") diff --git a/src/br/data/preprocessing/sdf_preprocessing/get_max_bounding_box.py b/src/br/data/preprocessing/sdf_preprocessing/get_max_bounding_box.py index 69a4754..914c48a 100644 --- a/src/br/data/preprocessing/sdf_preprocessing/get_max_bounding_box.py +++ b/src/br/data/preprocessing/sdf_preprocessing/get_max_bounding_box.py @@ -1,19 +1,18 @@ +import argparse +from multiprocessing import Pool +from pathlib import Path import pandas as pd from aicsimageio import AICSImage from monai.transforms import FillHoles from tqdm import tqdm -from br.data.utils import ( - get_mesh_from_image, ) -from multiprocessing import Pool -from pathlib import Path -import argparse + +from br.data.utils import get_mesh_from_image def get_bounds(r): - return_df = {'x_delta': [], 'y_delta': [], 'z_delta': [], 'cell_id': [], 'max_delta': []} - cellid = r['CellId'] + return_df = {"x_delta": [], "y_delta": [], "z_delta": [], "cell_id": [], "max_delta": []} + cellid = r["CellId"] hole_fill_transform = FillHoles() seg = AICSImage(r["crop_seg_masked"]).data.squeeze() @@ -26,14 +25,15 @@ def get_bounds(r): y_delta = bounds[3] - bounds[2] z_delta = bounds[5] - bounds[4] max_delta = max([x_delta, y_delta, z_delta]) - return_df['x_delta'].append(x_delta) - return_df['y_delta'].append(y_delta) - return_df['z_delta'].append(z_delta) - return_df['cell_id'].append(cellid) - return_df['max_delta'].append(max_delta) + return_df["x_delta"].append(x_delta) + return_df["y_delta"].append(y_delta) + return_df["z_delta"].append(z_delta) + return_df["cell_id"].append(cellid) + return_df["max_delta"].append(max_delta) return return_df + def main(args): # make save path directory Path(args.save_path).mkdir(parents=True, exist_ok=True) @@ -61,7 +61,7 @@ def main(args): jobs = [i for i in jobs if i is not None] return_df = pd.DataFrame(jobs).reset_index(drop=True) - return_df.to_csv(Path(args.save_path) / Path('bounds.csv')) + return_df.to_csv(Path(args.save_path) / Path("bounds.csv")) if __name__ == "__main__": @@ -83,7 +83,6 @@ def main(args): help="Path to append to relative paths in preprocessed manifest", ) - args = parser.parse_args() main(args) @@ -91,4 +90,3 @@ Example run: python get_max_bounding_box.py --save_path './test_img/' --manifest "../../../../../morphology_appropriate_representation_learning/preprocessed_data/npm1/manifest.csv" --global_path "../../../../../" """ - diff --git a/src/br/features/evolve.py b/src/br/features/evolve.py index df3fd1a..951b54e 100644 --- a/src/br/features/evolve.py +++ b/src/br/features/evolve.py @@ -208,24 +208,27 @@ def model_pass_reconstruct( target_bounds_initial = get_mesh_bbox_shape(mesh_initial) target_bounds_final = 
get_mesh_bbox_shape(mesh_final) target_bounds = [max(i, j) for i, j in zip(target_bounds_initial, target_bounds_final)] - recon_int, recon_initial, recon_final = voxelize_recon_meshes( - [mesh, mesh_initial, mesh_final], target_bounds - ) - recon_initial = np.where(recon_initial > 0.5, 1, 0) - recon_int = np.where(recon_int > 0.5, 1, 0) - recon_final = np.where(recon_final > 0.5, 1, 0) - - mse_total = 1 - jaccard_similarity_score( - recon_final.flatten(), recon_initial.flatten(), pos_label=1 - ) - mse_initial = 1 - jaccard_similarity_score( - recon_initial.flatten(), recon_int.flatten(), pos_label=1 - ) - mse_final = 1 - jaccard_similarity_score( - recon_final.flatten(), recon_int.flatten(), pos_label=1 - ) - energy = (mse_initial + mse_final) / mse_total - return energy.item() + try: + recon_int, recon_initial, recon_final = voxelize_recon_meshes( + [mesh, mesh_initial, mesh_final], target_bounds + ) + recon_initial = np.where(recon_initial > 0.5, 1, 0) + recon_int = np.where(recon_int > 0.5, 1, 0) + recon_final = np.where(recon_final > 0.5, 1, 0) + + mse_total = 1 - jaccard_similarity_score( + recon_final.flatten(), recon_initial.flatten(), pos_label=1 + ) + mse_initial = 1 - jaccard_similarity_score( + recon_initial.flatten(), recon_int.flatten(), pos_label=1 + ) + mse_final = 1 - jaccard_similarity_score( + recon_final.flatten(), recon_int.flatten(), pos_label=1 + ) + energy = (mse_initial + mse_final) / mse_total + return energy.item() + except Exception: + # voxelization or IoU evaluation can fail for degenerate meshes; score as missing + return np.nan else: if (key == "pcloud") and (not use_sample_points) and (init_x.shape[-1] <= 4): init_x = init_x[:, :, :3] diff --git a/src/br/features/plot.py b/src/br/features/plot.py index f8f73df..0c5608f 100644 --- a/src/br/features/plot.py +++ b/src/br/features/plot.py @@ -20,6 +20,7 @@ METRIC_DICT = { "reconstruction": {"metric": ["loss"], "min": [True]}, "regression": {"metric": ["test_r2"], "min": [False]}, + "regression_dists": {"metric": ["test_r2"], "min": [False]}, "classification": {"metric": ["top_1_acc"], "min": [False]}, "emissions": {"metric": ["emissions", "inference_time"], "min": [True, True]}, "evolution_energy": { @@ -140,6 +141,7 @@ def collect_outputs(path, norm, model_order=None, metric_list=None): rep_dict_var = { "reconstruction_loss": "Reconstruction", "regression_test_r2": "Feature Regression", + "regression_dists_test_r2": "Feature Regression_dists", "compactness_compactness": "Compactness", "rotation_invariance_error_value": "Rotation Invariance Error", "evolution_energy_closest_embedding_distance": "Embedding Distance", @@ -262,10 +264,10 @@ def plot( fig.write_image(path / f"{title}.pdf", scale=3) if df_non_agg is not None: - sns.set(font_scale=1.1) - sns.set_theme(style="white") + sns.set(style="white", font_scale=1.3) for var in df_non_agg["variable"].unique(): this_df = df_non_agg.loc[df_non_agg["variable"] == var].reset_index(drop=True) + g = sns.catplot( data=this_df, y="model", @@ -290,15 +292,30 @@ def plot( alpha=0.6, ec="k", linewidth=1, - s=1, + s=2, ) - g.set( - xlim=[ - np.nanquantile(this_df["value"].values, 0.05), - np.nanquantile(this_df["value"].values, 0.95), - ] - ) + if (var != "Model Size") and (var != "Emissions"): + g.set( + xlim=[ + np.nanquantile(this_df["value"].values, 0.001), + np.nanquantile(this_df["value"].values, 0.999), + ] + ) + elif var == "Emissions": + g.set( + xlim=[ + np.nanquantile(this_df["value"].values, 0.05), + np.nanquantile(this_df["value"].values, 0.95), + ] + ) + else: + g.set( + xlim=[ + this_df["value"].values.min() - 0.1 * this_df["value"].values.min(), + 
this_df["value"].values.max() + 0.1 * this_df["value"].values.max(), + ] + ) g.set(yticklabels=[]) diff --git a/src/br/features/regression.py b/src/br/features/regression.py index 2504084..a3b4edf 100644 --- a/src/br/features/regression.py +++ b/src/br/features/regression.py @@ -19,6 +19,15 @@ def get_regression_df(all_ret, target_cols, feature_df_path, df_feat=None): this_mo = all_ret.loc[all_ret["model"] == model].reset_index(drop=True) if df_feat is not None and target not in this_mo.columns: this_mo = this_mo.merge(df_feat, on="CellId") + + if "avg_dists" in target: + this_mo = this_mo[this_mo[target] != 0] + print(len(this_mo)) + if "std_dists" in target: + avg_col = f"avg{target.split('std')[1]}" + this_mo = this_mo[this_mo[avg_col] != 0] + print(len(this_mo)) + test_r2, test_mse = get_regression(this_mo, target) for i in range(len(test_r2)): ret_dict5["model"].append(model) diff --git a/src/br/models/predict_model.py b/src/br/models/predict_model.py index 991fd00..defde13 100644 --- a/src/br/models/predict_model.py +++ b/src/br/models/predict_model.py @@ -223,7 +223,9 @@ def base_forward( key = "image" if eval_scaled_img and eval_scaled_img_model_type == "iae": - uni_sample_points = get_iae_reconstruction_3d_grid() + uni_sample_points = get_iae_reconstruction_3d_grid( + bb_min=-0.5, bb_max=0.5, resolution=eval_scaled_img_resolution, padding=0.1 + ) uni_sample_points = uni_sample_points.unsqueeze(0).repeat(this_batch[key].shape[0], 1, 1) this_batch["points"] = uni_sample_points xhat, z, z_params = model( @@ -317,6 +319,7 @@ def base_forward( ) for cellid, recon_data in zip(cellids, recon_data_list) ] + with Pool(processes=8) as pool: errs = pool.map(multi_proc_scale_img_eval, args) @@ -459,7 +462,6 @@ def process_batch( emissions_df["inference_time"] = time else: out, z, loss, x_vis_list = [*model_outputs] - all_outputs.append(out) all_embeds.append(z) all_emissions.append(emissions_df) diff --git a/src/br/models/save_embeddings.py b/src/br/models/save_embeddings.py index f7c7224..483e69c 100644 --- a/src/br/models/save_embeddings.py +++ b/src/br/models/save_embeddings.py @@ -298,7 +298,7 @@ def save_emissions( loss_eval = get_pc_loss() if loss_eval_list is None else loss_eval_list[j_ind] with torch.no_grad(): count = 0 - for i in tqdm(this_data.test_dataloader()): + for i in tqdm(this_data.train_dataloader()): if count < max_batches: track_emissions = True else: