diff --git a/.travis.yml b/.travis.yml
index ec1f39b..e11f6f6 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,15 +1,12 @@
-dist: xenial
+dist: focal
 language: python
 python:
   - '3.6'
 install:
   - pip3 install -r requirements.txt
-script:
-  - pytest
+script: pytest
 deploy:
   provider: pypi
   user: theaustinator
-  on:
-    tags: true
   password:
-    secure:
+    secure: P5b36T+ulGokWYCGqt/JRN9l4p3HwTlQBazpWxhqz3bHQ8GhmIAYRn5xPno0R4hA0dyO0hKtdp8Vkd+RqTmdvv0HJSHvLp4HC5A9JSugChnA5CRnfbBQ8VWMxNrGGJDADaaRFyT7GCzE317c4LOpcLsCqRUnAIytNqm7VvLCL00nQmeQu2b0sYsjXHbHbQQVcWEdMlMptr+hUqGiDXRjG8/1luMvP7ZUe0IBcVcHKp2hwIjTzwqPCyNF0J92l7DEmTHVyfmPYp5ioqdQDvJ5DjN6bBNELD/uEbmq3A8zagm47m46aGc0uBiT7qpKLW4w+fYlVkph2Uvj3qRvnAzWCcRLtItkJyIj7V2ovKSs7btheUmJHM93JnQc6hRqYfdggKLBDgosvLLOab4xBzJW5K7JwD4Wa9NCs3pidJWcF3WSTm6xpTFj19uEd8m9xQ/KPn3UfOAwSTImi+7Ya/LpYMV1xMwuKyQMWqanlOfpi5CFSzBrrBC44lL/ZhXjgEfocxSvCeqm+aK2gObDq2Eymze/OVvvOrBgnilU5D3EMai5ustkj8RjnjSEFJKAJYeXwGE4Yx73Ger0AZDqvLWbn8Pf7rHfcqpSJJx7jIJUeOwhWP6vme1vpqvW8V/VuwJP6ZobkyQ7xmnl5ceNGTYPOtfzhcpw5S7WAukSzpYQKmQ=
diff --git a/dataforest/__init__.py b/dataforest/__init__.py
index a21a918..2273ab6 100644
--- a/dataforest/__init__.py
+++ b/dataforest/__init__.py
@@ -1,6 +1,8 @@
-from dataforest.utils.loaders.update_config import get_current_config, update_config
+from dataforest.utils.loaders.config import CONFIG_OPTIONS, load_config as _load_config
+from dataforest.utils.loaders.update_config import get_current_config, get_config_updater as _get_config_updater
 from dataforest.core.Interface import Interface
 
+update_config = _get_config_updater(_load_config)
 load = Interface.load
 from_input_dirs = Interface.from_input_dirs
diff --git a/dataforest/config/MetaPlotMethods.py b/dataforest/config/MetaPlotMethods.py
index 428ccca..68dfcc0 100644
--- a/dataforest/config/MetaPlotMethods.py
+++ b/dataforest/config/MetaPlotMethods.py
@@ -1,5 +1,10 @@
+from pathlib import Path
+from typing import List
+
 from dataforest.config.MetaConfig import MetaConfig
 from dataforest.utils.loaders.collectors import collect_plots
+from dataforest.utils.loaders.path import get_module_paths
+from dataforest.utils.plots_config import build_process_plot_method_lookup, parse_plot_kwargs
 
 
 class MetaPlotMethods(MetaConfig):
@@ -8,5 +13,33 @@ def PLOT_METHOD_LOOKUP(cls):
         return {k: v for source in cls.CONFIG["plot_sources"] for k, v in collect_plots(source).items()}
 
     @property
-    def PLOT_METHODS(cls):
-        return cls.CONFIG["plot_methods"]
+    def PLOT_MAP(cls):
+        return cls.CONFIG.get("plot_map", dict())
+
+    @property
+    def PROCESS_PLOT_METHODS(cls):
+        try:
+            plot_methods = cls.CONFIG["plot_methods"]
+        except KeyError:
+            plot_methods = build_process_plot_method_lookup(cls.CONFIG.get("plot_map", dict()))
+
+        return plot_methods
+
+    @property
+    def PLOT_KWARGS_DEFAULTS(cls):
+        return cls.CONFIG.get("plot_kwargs_defaults", dict())
+
+    @property
+    def PLOT_KWARGS(cls):  # TODO-QC: mapping of process, plot to plot_kwargs
+        plot_map = cls.CONFIG.get("plot_map", dict())
+        plot_kwargs_defaults = cls.CONFIG.get("plot_kwargs_defaults", dict())
+        plot_kwargs = parse_plot_kwargs(plot_map, plot_kwargs_defaults)
+        return plot_kwargs
+
+    # @property
+    # def R_FUNCTIONS_FILEPATH(cls) -> Path:
+    #     return get_module_paths([cls.CONFIG["r_functions_sources"]])[0]
+
+    @property
+    def R_PLOT_SOURCES(cls) -> List[Path]:
+        return get_module_paths(cls.CONFIG["r_plot_sources"])
diff --git a/dataforest/config/MetaProcessSchema.py b/dataforest/config/MetaProcessSchema.py
index 92d3a87..f1b3e68 100644
--- a/dataforest/config/MetaProcessSchema.py
+++ b/dataforest/config/MetaProcessSchema.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 
 from dataforest.config.MetaConfig import MetaConfig
+from dataforest.utils.plots_config import parse_plot_map
 
 
 class MetaProcessSchema(MetaConfig):
@@ -19,8 +20,10 @@ def FILE_MAP(cls):
         return cls["file_map"]
 
     @property
-    def PLOT_MAP(cls):
-        return cls["plot_map"]
+    def PLOT_MAP(cls):  # TODO-QC: process plot map starting here? Make it into a class where you can fetch plot_kwargs?
+        plot_map = cls.CONFIG.get("plot_map", dict())
+        plot_kwargs_defaults = cls.CONFIG.get("plot_kwargs_defaults", dict())
+        return parse_plot_map(plot_map, plot_kwargs_defaults)
 
     @property
     def LAYERS(cls):
diff --git a/dataforest/core/BranchSpec.py b/dataforest/core/BranchSpec.py
index ab65c25..4434757 100644
--- a/dataforest/core/BranchSpec.py
+++ b/dataforest/core/BranchSpec.py
@@ -2,6 +2,9 @@
 from copy import deepcopy
 from typing import Union, List, Dict
 
+from typeguard import typechecked
+
+from dataforest.core.RunGroupSpec import RunGroupSpec
 from dataforest.core.RunSpec import RunSpec
 from dataforest.utils.exceptions import DuplicateProcessName
 
@@ -19,41 +22,21 @@ class BranchSpec(list):
     >>> # NOTE: conceptual illustration only, not real processes
     >>> branch_spec = [
     >>>     {
-    >>>         "_PROCESS_": "normalize",
-    >>>         "_PARAMS_": {
-    >>>             "min_genes": 5,
-    >>>             "max_genes": 5000,
-    >>>             "min_cells": 5,
-    >>>             "nfeatures": 30,
-    >>>             "perc_mito_cutoff": 20,
-    >>>             "method": "seurat_default",
-    >>>         }
-    >>>         "_SUBSET_": {
-    >>>             "indication": {"disease_1", "disease_3"},
-    >>>             "collection_center": "mass_general",
-    >>>         },
-    >>>         "_FILTER_": {
-    >>>             "donor": "D115"
-    >>>         }
-    >>>     },
-    >>>     {
     >>>         "_PROCESS_": "reduce",  # dimensionality reduction
     >>>         "_ALIAS_": "linear_dim_reduce",
     >>>         "_PARAMS_": {
     >>>             "algorithm": "pca",
-    >>>             "n_pcs": 30,
+    >>>             "n_pcs": 30
     >>>         }
     >>>     },
     >>>     {
     >>>         "_PROCESS_": "reduce",
     >>>         "_ALIAS_": "nonlinear_dim_reduce",
-    >>>         "_PARAMS_": {
-    >>>             "algorithm": "umap",
-    >>>             "n_neighbors": 15,
-    >>>             "min_dist": 0.1,
-    >>>             "n_components": 2,
-    >>>             "metric": "euclidean"
-    >>>         }
+    >>>         "_PARAMS_": ...
+    >>>     },
+    >>>     {
+    >>>         "_PROCESS_": "dispersity",
+    >>>         "_PARAMS_": ...
     >>>     }
     >>> ]
     >>> branch_spec = BranchSpec(branch_spec)
@@ -69,6 +52,8 @@ class BranchSpec(list):
         process_order:
     """
 
+    _RUN_SPEC_CLASS = RunSpec
+
     def __init__(self, spec: Union[str, List[dict], "BranchSpec[RunSpec]"]):
         if isinstance(spec, str):
             spec = json.loads(spec)
@@ -84,6 +69,10 @@ def __init__(self, spec: Union[str, List[dict], "BranchSpec[RunSpec]"]):
         )
         self.process_order: List[str] = [spec_item.name for spec_item in self]
 
+    @property
+    def processes(self):
+        return [run_spec["_PROCESS_"] for run_spec in self]
+
     @property
     def shell_str(self):
         """string version which can be passed via shell and loaded via json"""
@@ -188,9 +177,9 @@ def _get_data_operation_list(self, process_name: str, operation_name: str) -> Li
             operation_list.append(operation)
         return operation_list
 
-    def _build_run_spec_lookup(self) -> Dict[str, "RunSpec"]:
+    def _build_run_spec_lookup(self) -> Dict[str, Union["RunSpec", "RunGroupSpec"]]:
         """See class definition"""
-        run_spec_lookup = {"root": RunSpec({})}
+        run_spec_lookup = {"root": self._RUN_SPEC_CLASS({})}
         for run_spec in self:
             try:
                 process_name = run_spec.name
@@ -217,18 +206,25 @@ def _build_precursors_lookup(self, incl_root: bool = False, incl_current: bool =
                 current_precursors = current_precursors + [spec_item.name]
         return precursors
 
-    def __getitem__(self, item: Union[str, int]) -> "RunSpec":
+    @typechecked
+    def __getitem__(self, key: Union[str, int, slice]) -> Union["RunSpec", "BranchSpec"]:
         """Get `RunSpec` either via `int` index or `name`"""
-        if not isinstance(item, int):
-            try:
-                return self._run_spec_lookup[item]
-            except Exception as e:
-                raise e
+        if isinstance(key, str):
+            return self._run_spec_lookup[key]
+        elif isinstance(key, slice):
+            if isinstance(key.stop, str):
+                if key.start or key.step:
+                    raise ValueError("Can only use stop with string slice (ex. [:'process_name'])")
+                precursors_lookup = self.get_precursors_lookup(incl_current=True)
+                precursors = precursors_lookup[key.stop]
+                return self.__class__([self._run_spec_lookup[process] for process in precursors])
+            else:
+                return self.__class__(super().__getitem__(key))
         else:
-            return super().__getitem__(item)
+            return super().__getitem__(key)
 
     def __setitem__(self, k, v):
-        raise ValueError("Cannot set items dynamically. All items must be defined at init")
+        raise NotImplementedError("Cannot set items dynamically. All items must be defined at init")
 
     def __contains__(self, item):
         return item in self._run_spec_lookup
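Review note: a minimal usage sketch (not from this patch; process names hypothetical) of the three access modes the new `BranchSpec.__getitem__` supports:

    spec = BranchSpec([{"_PROCESS_": "normalize"}, {"_PROCESS_": "reduce"}])
    spec["reduce"]     # RunSpec, looked up by process name
    spec[0]            # RunSpec, looked up by integer index
    spec[:"reduce"]    # BranchSpec containing "reduce" and all of its precursors
    # a string slice with a start or step raises ValueError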
All items must be defined at init") def __contains__(self, item): return item in self._run_spec_lookup diff --git a/dataforest/core/DataBase.py b/dataforest/core/DataBase.py index ea241c5..c1929d4 100644 --- a/dataforest/core/DataBase.py +++ b/dataforest/core/DataBase.py @@ -1,5 +1,6 @@ +from abc import ABC from pathlib import Path -from typing import Union, Optional, List, Dict +from typing import Union, Optional, List, AnyStr from dataforest.core.PlotMethods import PlotMethods @@ -9,14 +10,21 @@ class DataBase: Mixin for `DataTree`, `DataBranch`, and derived class """ + _PLOT_METHODS = PlotMethods + def __init__(self): - self.plot = PlotMethods(self) + self.root = None + self.plot = self._PLOT_METHODS(self) + + @property + def root_built(self): + return (Path(self.root) / "meta.tsv").exists() @staticmethod def _combine_datasets( - root: Union[str, Path], - metadata: Optional[Union[str, Path]] = None, - input_paths: Optional[List[Union[str, Path]]] = None, + root: AnyStr, + metadata: Optional[AnyStr] = None, + input_paths: Optional[List[AnyStr]] = None, mode: Optional[str] = None, ): raise NotImplementedError("Must be implemented by subclass") diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index c19458f..427a6ad 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -5,6 +5,8 @@ import pandas as pd from pathlib import Path +from typeguard import typechecked + from dataforest.core.DataBase import DataBase from dataforest.structures.cache.PathCache import PathCache from dataforest.structures.cache.IOCache import IOCache @@ -62,7 +64,7 @@ class DataBranch(DataBase): {reader, writer} from `{READER, WRITER}_METHODS` which was selected by `_map_default_{reader, writer}_methods` is not appropriate. Keys are `file_alias`es, and values are methods which take `filename`s - and `kwargs`. + and `plot_kwargs`. {READER, WRITER}_KWARGS_MAP: optional overloads for any file for which the default keyword arguments to the {reader, writer} should be @@ -87,7 +89,6 @@ class DataBranch(DataBase): _writer_method_map: """ - PLOT_METHODS: Type = PlotMethods PROCESS_METHODS: Type = ProcessMethods SCHEMA_CLASS: Type = ProcessSchema READER_METHODS: Type = ReaderMethods @@ -100,7 +101,7 @@ class DataBranch(DataBase): _METADATA_NAME: dict = NotImplementedError("Should be implemented by superclass") _COPY_KWARGS: dict = { "root": "root", - "branch_spec": "branch_spec", + "branch_spec": "spec", "verbose": "verbose", "current_process": "current_process", } @@ -187,6 +188,10 @@ def meta(self) -> pd.DataFrame: self._meta = self._get_meta(self.current_process) return self._meta + def add_metadata(self, df: pd.DataFrame): + # TODO: need to write to keep persistent? + self.set_meta(pd.concat([self.meta, df], axis=1)) + def goto_process(self, process_name: str) -> "DataBranch": """ Updates the state of the `DataBranch` object to reflect the @@ -212,6 +217,29 @@ def goto_process(self, process_name: str) -> "DataBranch": self.clear_data(path_map_changes) return self + def groupby(self, by: Union[str, list, set, tuple], **kwargs) -> Tuple[str, "DataBranch"]: + """ + Operates like a pandas group_labels, but does not return a GroupBy object, + and yields (name, DataBranch), where each DataBranch is subset according to `by`, + which corresponds to columns of `self.meta`. + This is useful for batching analysis across various conditions, where + each run requires an DataBranch. 
+ Args: + by: variables over which to group (like pandas) + **kwargs: for pandas group_labels on `self.meta` + + Yields: + name: values for DataBranch `subset` according to keys specified in `by` + branch: new DataBranch which inherits `self.spec` with additional `subset`s + from `by` + """ + if isinstance(by, (tuple, set)): + by = list(by) + for (name, df) in self.meta.groupby(by, **kwargs): + branch = self.copy() + branch.set_meta(df) + yield name, branch + def fork(self, branch_spec: Union[list, BranchSpec]) -> "DataBranch": """ "Forks" the current branch to create a copy with an altered branch_spec, but @@ -263,19 +291,6 @@ def clear_data(self, attrs: Optional[Iterable[str]] = None, all_data: bool = Fal data_attr = f"_{attr_name}" setattr(self, data_attr, None) - def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): - if self.is_process_plots_exist("root"): - self.logger.info( - f"plots already present for `root` at {self['root'].plots_path}. To regenerate plots, delete directory" - ) - return - plot_kwargs = plot_kwargs if plot_kwargs else dict() - root_plot_methods = self.plot.plot_methods.get("root", []) - for name in root_plot_methods: - kwargs = plot_kwargs.get(name, dict()) - method = getattr(self.plot, name) - method(**kwargs) - def is_process_plots_exist(self, process_name: str) -> bool: return self[process_name].plots_path.exists() @@ -337,12 +352,27 @@ def get_temp_meta_path(self: "DataBranch", process_name: str): def __getitem__(self, process_name: str) -> ProcessRun: if process_name not in self._process_runs: - if process_name in ("root", None): - process_name = "root" + process_name = "root" if process_name is None else process_name process = self.spec[process_name].process if process_name != "root" else "root" self._process_runs[process_name] = ProcessRun(self, process_name, process) return self._process_runs[process_name] + @property + def current_path(self) -> str: + """ + The paths at current `process_run` + """ + return self._paths[self._current_process] + + @staticmethod + def _combine_datasets( + root_dir: Union[str, Path], + metadata: Optional[Union[str, Path]] = None, + input_paths: Optional[List[Union[str, Path]]] = None, + mode: Optional[str] = None, + ): + raise NotImplementedError("Must be implemented by subclass") + def _get_meta(self, process_name): raise NotImplementedError("This method should be implemented by `DataBranch` subclasses") @@ -365,9 +395,15 @@ def _apply_data_ops(self, process_name: str, df: Optional[pd.DataFrame] = None): df = self.meta.copy() for (subset, filter_) in zip(subset_list, filter_list): for column, val in subset.items(): - df = self._do_subset(df, column, val) + if "_MULTI_" in column: + df = self._do_subset(df, val) + elif val is not None: + df = self._do_subset(df, {column: val}) for column, val in filter_.items(): - df = self._do_filter(df, column, val) + if "_MULTI_" in column: + df = self._do_filter(df, val) + elif val is not None: + df = self._do_filter(df, {column: val}) df.replace(" ", "_", regex=True, inplace=True) partitions_list = self.spec.get_partition_list(process_name) partitions = set().union(*partitions_list) @@ -375,36 +411,41 @@ def _apply_data_ops(self, process_name: str, df: Optional[pd.DataFrame] = None): df = label_df_partitions(df, partitions, encodings=True) return df - @staticmethod - def _do_subset(df: pd.DataFrame, column: str, val: Any) -> pd.DataFrame: - prev_df = df.copy() - if isinstance(val, (list, set)): - df = df[df[column].isin(val)] - else: - df = df[df[column] == val] - 
if len(df) == len(prev_df): - logging.warning(f"Subset didn't change num of rows and may be unnecessary - column: {column}, val: {val}") - elif len(df) == 0: - raise BadSubset(column, val) + @classmethod + def _do_subset(cls, df: pd.DataFrame, subset_dict: Dict[str, Any]) -> pd.DataFrame: + df_selector = cls._get_df_selector(df, subset_dict) + df = df.loc[df_selector] + if len(df) == 0: + raise BadSubset(subset_dict) return df - @staticmethod - def _do_filter(df: pd.DataFrame, column: str, val: Any) -> pd.DataFrame: - prev_df = df.copy() - if isinstance(val, (list, set)): - df = df[~df[column].isin(val)] - else: - df = df[df[column] != val] - if len(df) == len(prev_df): - logging.warning(f"Filter didn't change num of rows and may be unnecessary - column: {column}, val: {val}") - elif len(df) == 0: - raise BadFilter(column, val) + @classmethod + def _do_filter(cls, df: pd.DataFrame, filter_dict: Dict[str, Any]) -> pd.DataFrame: + df_selector = cls._get_df_selector(df, filter_dict) + df = df.loc[~df_selector] + if len(df) == 0: + raise BadFilter(filter_dict) return df + @staticmethod + def _get_df_selector(df: pd.DataFrame, op_dict: Dict[str, Any]) -> pd.Series: + def _check_row_eq(row: pd.Series) -> bool: + for key, val in row.iteritems(): + if isinstance(val, set): + if not op_dict[key] in val: + return False + else: + if not op_dict[key] == val: + return False + return True + + df_selector = pd.DataFrame(pd.DataFrame(df[list(op_dict)]).apply(_check_row_eq, axis=1)).all(axis=1) + return df_selector + def _map_file_io(self) -> Tuple[Dict[str, IOCache], Dict[str, IOCache]]: """ Builds a lookup of lazy loading caches for file readers and writers, - which have implicit access to paths, methods, and kwargs for each file. + which have implicit access to paths, methods, and plot_kwargs for each file. Returns: {reader, writer}_map: Key: file_alias (e.g. "rna") diff --git a/dataforest/core/DataTree.py b/dataforest/core/DataTree.py index 1d517d7..c77fbcc 100644 --- a/dataforest/core/DataTree.py +++ b/dataforest/core/DataTree.py @@ -1,22 +1,28 @@ +import logging from pathlib import Path from typing import Union, Optional, List, Dict from dataforest.core.DataBase import DataBase from dataforest.core.DataBranch import DataBranch +from dataforest.core.PlotTreeMethods import PlotTreeMethods +from dataforest.core.ProcessTreeRun import ProcessTreeRun +from dataforest.core.TreeDataFrame import DataFrameList from dataforest.core.TreeSpec import TreeSpec from dataforest.processes.core.TreeProcessMethods import TreeProcessMethods from dataforest.structures.cache.BranchCache import BranchCache class DataTree(DataBase): - BRANCH_CLASS = DataBranch + _LOG = logging.getLogger("DataTree") + _BRANCH_CLASS = DataBranch + _PLOT_METHODS = PlotTreeMethods def __init__( self, root: Union[str, Path], tree_spec: Optional[List[dict]] = None, + twigs: Optional[List[dict]] = None, verbose: bool = False, - current_process: Optional[str] = None, remote_root: Optional[Union[str, Path]] = None, ): # TODO: add something that tells them how many of each process will be run @@ -24,41 +30,77 @@ def __init__( # it finishes. 
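Review note: a usage sketch (not from this patch; the column name and call are hypothetical) of the new `DataBranch.groupby` generator added above:

    for donor, sub_branch in branch.groupby("donor"):
        # each sub_branch is a copy of `branch` whose meta is restricted to one donor
        print(donor, len(sub_branch.meta))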
diff --git a/dataforest/core/DataTree.py b/dataforest/core/DataTree.py
index 1d517d7..c77fbcc 100644
--- a/dataforest/core/DataTree.py
+++ b/dataforest/core/DataTree.py
@@ -1,22 +1,28 @@
+import logging
 from pathlib import Path
 from typing import Union, Optional, List, Dict
 
 from dataforest.core.DataBase import DataBase
 from dataforest.core.DataBranch import DataBranch
+from dataforest.core.PlotTreeMethods import PlotTreeMethods
+from dataforest.core.ProcessTreeRun import ProcessTreeRun
+from dataforest.core.TreeDataFrame import DataFrameList
 from dataforest.core.TreeSpec import TreeSpec
 from dataforest.processes.core.TreeProcessMethods import TreeProcessMethods
 from dataforest.structures.cache.BranchCache import BranchCache
 
 
 class DataTree(DataBase):
-    BRANCH_CLASS = DataBranch
+    _LOG = logging.getLogger("DataTree")
+    _BRANCH_CLASS = DataBranch
+    _PLOT_METHODS = PlotTreeMethods
 
     def __init__(
         self,
         root: Union[str, Path],
         tree_spec: Optional[List[dict]] = None,
+        twigs: Optional[List[dict]] = None,
         verbose: bool = False,
-        current_process: Optional[str] = None,
         remote_root: Optional[Union[str, Path]] = None,
     ):
         # TODO: add something that tells them how many of each process will be run
@@ -24,41 +30,77 @@ def __init__(
         #  it finishes. That way, we can
         super().__init__()
         self.root = root
-        self.tree_spec = self._init_spec(tree_spec)
+        self.tree_spec = self._init_spec(tree_spec, twigs)
+        self._twigs = twigs
         self._verbose = verbose
-        self._current_process = current_process
+        self._current_process = None
         self.remote_root = remote_root
-        self._branch_cache = BranchCache(
-            root, self.tree_spec.branch_specs, self.BRANCH_CLASS, verbose, current_process, remote_root,
-        )
-        self.process = TreeProcessMethods(self.tree_spec, self._branch_cache)
+        self._branch_cache = BranchCache(root, self.tree_spec.branch_specs, self._BRANCH_CLASS, verbose, remote_root)
+        self._process_tree_runs = dict()
+        self.process = TreeProcessMethods(self)
+
+    @property
+    def meta(self):
+        return DataFrameList([branch.meta for branch in self.branches])
 
     @property
     def n_branches(self):
         return len(self.tree_spec.branch_specs)
 
-    def update_process_spec(self, process_name: str, process_spec: dict):
-        self.tree_spec[process_name] = process_spec
-        self.update_spec(self.tree_spec)
+    @property
+    def current_process(self):
+        return self._current_process if self._current_process else "root"
+
+    @property
+    def has_sweeps(self) -> bool:
+        """Whether or not any sweeps are in the spec"""
+        return bool(set.union(*self.tree_spec.sweep_dict.values()))
+
+    @property
+    def has_twigs(self):
+        return bool(self._twigs)
+
+    @property
+    def branches(self):
+        self._branch_cache.load_all()
+        return list(self._branch_cache.values())
 
-    def update_spec(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]):
-        tree_spec = list(tree_spec)
-        self.tree_spec = self._init_spec(tree_spec)
-        self._branch_cache.update_branch_specs(self.tree_spec.branch_specs)
-        self.process = TreeProcessMethods(self.tree_spec, self._branch_cache)
+    def goto_process(self, process_name: str):
+        self._branch_cache.load_all()
+        for branch in self._branch_cache.values():
+            branch.goto_process(process_name)
+        self._current_process = process_name
+
+    def load_all(self):
+        self._branch_cache.load_all()
 
     def run_all(self, workers: int = 1, batch_queue: Optional[str] = None):
         return [method() for method in self.process.process_methods]
 
+    def unique_branches_at_process(self, process_name: str) -> Dict[str, "DataBranch"]:
+        """
+        Gets a subset of branches representing those unique up to the specified
+        process. From two branches which only become distinguished after
+        `process_name`, just one will be selected.
+        """
+        return {str(branch.spec[:process_name]): branch for branch in self.branches}
+
     def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None):
         rand_spec = self.tree_spec.branch_specs[0]
         rand_branch = self._branch_cache[str(rand_spec)]
-        rand_branch.create_root_plots(plot_kwargs)
+        rand_branch.plot._generate_root_plots(plot_kwargs)
+
+    def __getitem__(self, process_name: str) -> ProcessTreeRun:
+        if process_name not in self._process_tree_runs:
+            process_name = "root" if process_name is None else process_name
+            process = self.tree_spec[process_name].process if process_name != "root" else "root"
+            self._process_tree_runs[process_name] = ProcessTreeRun(self, process_name, process)
+        return self._process_tree_runs[process_name]
 
     @staticmethod
-    def _init_spec(tree_spec: Union[list, TreeSpec]) -> TreeSpec:
+    def _init_spec(tree_spec: Union[list, TreeSpec], twigs: Optional[List[dict]]) -> TreeSpec:
         if tree_spec is None:
             tree_spec = list()
         if not isinstance(tree_spec, TreeSpec):
-            tree_spec = TreeSpec(tree_spec)
+            tree_spec = TreeSpec(tree_spec, twigs)
         return tree_spec
diff --git a/dataforest/core/Interface.py b/dataforest/core/Interface.py
index b96cbcd..c7f9f59 100644
--- a/dataforest/core/Interface.py
+++ b/dataforest/core/Interface.py
@@ -1,5 +1,6 @@
+from copy import deepcopy
 from pathlib import Path
-from typing import Union, Optional, Iterable, List
+from typing import Union, Optional, Iterable, List, AnyStr
 
 import pandas as pd
 
@@ -24,7 +25,7 @@ def load(
     ) -> Union["DataBranch", "CellBranch", "DataTree", "CellTree"]:
         # TODO: replace kwargs with explicit to make it easier for users
         """
-        Loads `cls.TREE_CLASS` if `tree_spec` passed, otherwise `cls.BRANCH_CLASS`
+        Loads `cls.TREE_CLASS` if `tree_spec` passed, otherwise `cls._BRANCH_CLASS`
         Args:
             root:
             branch_spec:
@@ -37,8 +38,8 @@ def load(
         """
         kwargs = {
-            "branch_spec": branch_spec,
-            "tree_spec": tree_spec,
+            "branch_spec": deepcopy(branch_spec),
+            "tree_spec": deepcopy(tree_spec),
             "verbose": verbose,
             "current_process": current_process,
             "remote_root": remote_root,
@@ -46,19 +47,24 @@ def load(
         kwargs = cls._prune_kwargs(kwargs)
         interface_cls = cls._get_interface_class(kwargs)
-        return interface_cls(root, **kwargs)
+        inst = interface_cls(root, **kwargs)
+        if not inst.root_built:
+            raise ValueError(
+                "Root must be built once before it can be loaded. Use `from_input_dirs` or `from_sample_metadata`"
+            )
+        return inst
 
     @classmethod
     def from_input_dirs(
         cls,
-        root: Union[str, Path],
-        input_paths: Optional[Union[str, Path, Iterable[Union[str, Path]]]] = None,
+        root: AnyStr,
+        input_paths: Optional[Union[AnyStr, Iterable[AnyStr]]] = None,
         mode: Optional[str] = None,
         branch_spec: Optional[List[dict]] = None,
         tree_spec: Optional[List[dict]] = None,
         verbose: bool = False,
         current_process: Optional[str] = None,
-        remote_root: Optional[Union[str, Path]] = None,
+        remote_root: Optional[AnyStr] = None,
         root_plots: bool = True,
         plot_kwargs: Optional[dict] = None,
         overwrite_plots: Optional[Iterable[str]] = None,
@@ -99,7 +105,7 @@ def from_input_dirs(
         kwargs = {**additional_kwargs, **kwargs}
         inst = interface_cls(root, **kwargs)
         if root_plots:
-            inst.create_root_plots(plot_kwargs)
+            inst.plot.generate_plots(["root"], plot_kwargs=plot_kwargs)
         return inst
 
     @classmethod
@@ -161,7 +167,7 @@ def from_sample_metadata(
         kwargs = {**additional_kwargs, **kwargs}
         inst = interface_cls(root, **kwargs)
         if root_plots:
-            inst.create_root_plots(plot_kwargs)
+            inst.plot.generate_plots(["root"], plot_kwargs=plot_kwargs)
         return inst
 
     @staticmethod
diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py
index 09781a2..333293c 100644
--- a/dataforest/core/PlotMethods.py
+++ b/dataforest/core/PlotMethods.py
@@ -1,9 +1,19 @@
+from copy import deepcopy
 from functools import wraps
+import logging
+from pathlib import Path
+from typing import Optional, Dict, Iterable, TYPE_CHECKING, Set
+
+from IPython.display import Image, display
+from typeguard import typechecked
 
 from dataforest.config.MetaPlotMethods import MetaPlotMethods
 from dataforest.utils import tether, copy_func
 from dataforest.utils.ExceptionHandler import ExceptionHandler
 
+if TYPE_CHECKING:
+    from dataforest.core.DataBranch import DataBranch
+
 
 class PlotMethods(metaclass=MetaPlotMethods):
     """
@@ -18,7 +28,37 @@ def __init__(self, branch: "DataBranch"):
             callable_ = copy_func(plot_method)
             callable_.__name__ = name
             setattr(self, name, self._wrap(callable_))
-        tether(self, "branch")
+        tether(self, "branch", incl_methods=list(self.plot_method_lookup.keys()))
+        self._img_cache = {}
+
+    def show(self, process_name: str):
+        plots_path = self.branch[process_name].plots_path
+        for img_path in plots_path.iterdir():
+            display(str(img_path))
+            display(Image(img_path))
+
+    @typechecked
+    def generate_plots(
+        self,
+        processes: Optional[Iterable[str]] = None,
+        plot_map: Optional[Dict[str, dict]] = None,
+        plot_kwargs: Optional[Dict[str, dict]] = None,
+    ):
+        plot_map = self.plot_map if not plot_map else plot_map
+        plot_kwargs = self.plot_kwargs if not plot_kwargs else plot_kwargs
+        if processes is not None:
+            plot_map = {proc: proc_plot_map for proc, proc_plot_map in plot_map.items() if proc in processes}
+            plot_kwargs = {
+                proc: proc_plot_kwargs for proc, proc_plot_kwargs in plot_kwargs.items() if proc in processes
+            }
+        for process, proc_plot_map in plot_map.items():
+            method_names_config = tuple(proc_plot_map.keys())
+            for name_config in method_names_config:
+                method_name = self.key_method_lookup[name_config]
+                method = getattr(self, method_name)
+                kwarg_sets = plot_kwargs[process][name_config].values()
+                for kwargs in kwarg_sets:
+                    method(**kwargs)
 
     @property
     def plot_method_lookup(self):
@@ -26,7 +66,69 @@ def plot_method_lookup(self):
 
     @property
     def plot_methods(self):
-        return self.__class__.PLOT_METHODS
+        return self.__class__.PROCESS_PLOT_METHODS
+
+    @property
+    def method_lookup(self):
+        return {k: getattr(self, method_name) for k, method_name in self.key_method_lookup.items()}
+
+    @property
+    def plot_kwargs_defaults(self):
+        return self.__class__.PLOT_KWARGS_DEFAULTS
+
+    @property
+    def plot_map(self):
+        # TODO: rename to plot_settings for clarity and to avoid confusion w/ process run?
+        return self.__class__.PLOT_MAP
+
+    @property
+    def plot_kwargs(self):
+        return self.__class__.PLOT_KWARGS
+
+    @property
+    def methods(self) -> Dict[str, Set[str]]:
+        """
+        Key: process_name at which plot becomes unlocked
+        Value: plot method names
+        """
+        is_plot_method = lambda s: s.startswith("plot") and callable(getattr(self, s))
+        method_names = list(filter(is_plot_method, dir(self)))
+        avail_dict = {}
+
+        def _assign(method_name):
+            method = getattr(self, method_name)
+            if hasattr(method, "_requires"):
+                required = getattr(method, "_requires")
+            else:
+                required = "root"
+            if required not in avail_dict:
+                avail_dict[required] = set()
+            avail_dict[required].add(method_name)
+
+        list(map(_assign, method_names))
+        return avail_dict
+
+    @property
+    def key_method_lookup(self):
+        """
+        Key: method key in format of config (e.g. "_UMAP_EMBEDDINGS_SCAT_")
+        Value: method name (e.g. "plot_umap_embeddings_scat")
+        """
+        convert_to_key = lambda s: "_" + s.upper()[5:] + "_"
+        return {convert_to_key(x): x for k, v in self.methods.items() for x in v}
+
+    @property
+    def method_key_lookup(self):
+        """inverted `key_method_lookup`"""
+        return {v: k for k, v in self.key_method_lookup.items()}
+
+    @property
+    def keys(self) -> Dict[str, Set[str]]:
+        """
+        Key: process name at which plot becomes unlocked
+        Value: keys for plots in config
+        """
+        return {k: set(map(lambda x: self.method_key_lookup[x], v)) for k, v in self.methods.items()}
 
     def _wrap(self, method):
         """Wrap with mkdirs and logging"""
@@ -35,11 +137,40 @@ def _wrap(self, method):
         def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs):
             try:
                 process_run = branch[branch.current_process]
-                plot_dir = process_run.plot_map[method_name].parent
-                plot_dir.mkdir(exist_ok=True)
+                plot_name = branch.plot.method_key_lookup.get(method_name, None)
+                if plot_name in process_run._plot_map:
+                    plt_filename_lookup = process_run._plot_map[plot_name]
+                    if plt_filename_lookup:
+                        _, plt_filepath = next(iter(plt_filename_lookup.items()))
+                        plt_dir = plt_filepath.parent
+                    else:
+                        plt_dir = Path("/tmp")
+                    plt_dir.mkdir(exist_ok=True)
                 return method(branch, *args, **kwargs)
             except Exception as e:
                 err_filename = method.__name__
-                ExceptionHandler.handle(self.branch, e, err_filename, stop_on_error)
+                ExceptionHandler.handle(self.branch, e, f"PLOT_{err_filename}.err", stop_on_error)
 
         return wrapped
+
+    def _generate_root_plots(
+        self, plot_kwargs: Optional[Dict[str, dict]] = None, overwrite: bool = False, stop_on_error: bool = False
+    ):
+        if self.branch.is_process_plots_exist("root") and not overwrite:
+            logging.info(
+                f"plots already present for `root` at {self.branch['root'].plots_path}. To regenerate plots, delete dir"
+            )
+            return
+
+        if plot_kwargs is None:
+            plot_kwargs = self.plot_kwargs["root"]
+        root_plot_map = self.branch["root"]._plot_map
+        root_plot_methods = self.plot_methods.get("root", dict())
+
+        for plot_name, plot_method in root_plot_methods.items():
+            kwargs_sets = plot_kwargs.get(plot_name, dict())
+            for plot_kwargs_key, _kwargs in kwargs_sets.items():
+                method = getattr(self, plot_method)
+                kwargs = deepcopy(_kwargs)
+                kwargs["plot_path"] = root_plot_map[plot_name][plot_kwargs_key]
+                method(stop_on_error=stop_on_error, **kwargs)
diff --git a/dataforest/core/PlotTreeMethods.py b/dataforest/core/PlotTreeMethods.py
new file mode 100644
index 0000000..5cebedb
--- /dev/null
+++ b/dataforest/core/PlotTreeMethods.py
@@ -0,0 +1,18 @@
+from dataforest.core.PlotMethods import PlotMethods
+from dataforest.plot.PlotWidget import PlotWidget
+
+
+class PlotTreeMethods(PlotMethods):
+    def __init__(self, tree):
+        self._tree = tree
+        for method_name in self.plot_method_lookup.keys():
+            setattr(self, method_name, self._widget_wrap(method_name))
+
+    def _widget_wrap(self, method_name):
+        def wrap(**kwargs):
+            plot_key = self.method_key_lookup[method_name]
+            widget = PlotWidget(self._tree, plot_key, **kwargs)
+            return widget.build_control(show=True, stop_on_error=True)
+
+        wrap.__name__ = method_name
+        return wrap
"_UMIS_PER_CELL_HIST_") + Value: list of plot paths of the type specified by key + """ + return {plot_key: list(plot_path_map.values()) for plot_key, plot_path_map in self._plot_map.items()} @property def path_map(self) -> Dict[str, Path]: @@ -99,12 +111,6 @@ def path_map(self) -> Dict[str, Path]: self._path_map = self._build_path_map(incl_current=True) return self._path_map - @property - def plot_map(self) -> Dict[str, Path]: - if self._plot_map is None: - self._plot_map = self._build_path_map(incl_current=True, plot_map=True) - return self._plot_map - @property def path_map_prior(self) -> Dict[str, Path]: """Like `path_map`, but for excluding the current process""" @@ -129,6 +135,10 @@ def file_map_done(self) -> Dict[str, str]: @property def done(self) -> bool: + """ + Whether or not process has been executed to completion, regardless of + success + """ if self.path.exists() and not (self.path / "INCOMPLETE").exists(): if len(self.files) > 0 or self.logs_path.exists(): return True @@ -148,8 +158,10 @@ def success(self) -> bool: @property def failed(self) -> bool: if self.path is not None and self.path.exists(): - error_files = ["process.err", "hooks.err"] - if not set(error_files).intersection(self.logs_path.iterdir()): + error_prefixes = ["PROCESS__", "HOOKS__"] + is_error_file = lambda s: any(map(lambda prefix: s.name.startswith(prefix), error_prefixes)) + contains_error_file = any(map(is_error_file, self.logs_path.iterdir())) + if contains_error_file: return True return False @@ -162,27 +174,29 @@ def logs(self): """ Prints stdout and stderr log files """ - log_dir = self.path / "_logs" - log_files = log_dir.iterdir() - stdouts = list(filter(lambda x: str(x).endswith(".out"), log_files)) - stderrs = list(filter(lambda x: str(x).endswith(".err"), log_files)) - if (len(stdouts) == 0) and (len(stderrs) == 0): - raise ValueError(f"No logs for processes: {self.process_name}") - for stdout in stdouts: - name = str(stdout.name).split(".out")[0] - cprint(f"STDOUT: {name}", "cyan", "on_grey") - with open(str(stdout), "r") as f: - print(f.read()) - for stderr in stderrs: - name = str(stderr.name).split(".err")[0] - cprint(f"STDERR: {name}", "magenta", "on_grey") - with open(str(stderr), "r") as f: - print(f.read()) + self._print_logs() def subprocess_runs(self, process_name: str) -> pd.DataFrame: """DataFrame of branch_spec info for all runs of a given subprocess""" raise NotImplementedError() + @property + def _process_plot_map(self) -> Dict[str, Path]: + plot_map_dict = {} + for plot_name in self._plot_lookup: + plot_map_dict[plot_name] = {} + for plot_kwargs_key, plot_filepath in self._plot_lookup[plot_name].items(): + plot_map_dict[plot_name][plot_kwargs_key] = self.plots_path / plot_filepath + + return plot_map_dict + + @property + def _plot_map(self) -> Dict[str, Path]: + # TODO: confusing with new plot_map name in config - rename that to plot_settings? 
+ if self._plot_map_cache is None: + self._plot_map_cache = self._build_path_map(incl_current=True, plot_map=True) + return self._plot_map_cache + def _build_layers_files(self) -> Dict[str, str]: """ Builds {file_alias: filename} lookup for additional layers specified @@ -214,14 +228,32 @@ def _build_path_map(self, incl_current: bool = False, plot_map: bool = False) -> spec = self.branch.spec precursor_lookup = spec.get_precursors_lookup(incl_current=incl_current, incl_root=True) precursors = precursor_lookup[self.process_name] - process_runs = [self.branch[process_name] for process_name in precursors] - pr_attr = "process_plot_map" if plot_map else "process_path_map" + process_runs = [self] if plot_map else [self.branch[process_name] for process_name in precursors] + pr_attr = "_process_plot_map" if plot_map else "process_path_map" process_path_map_list = [getattr(pr, pr_attr) for pr in process_runs] path_map = dict() for process_path_map in process_path_map_list: path_map.update(process_path_map) return path_map + def _print_logs(self): + log_dir = self.path / "_logs" + log_files = list(log_dir.iterdir()) + stdouts = list(filter(lambda x: str(x).endswith(".out"), log_files)) + stderrs = list(filter(lambda x: str(x).endswith(".err"), log_files)) + if (len(stdouts) == 0) and (len(stderrs) == 0): + raise ValueError(f"No logs for processes: {self.process_name}") + for stdout in stdouts: + name = str(stdout.name).split(".out")[0] + cprint(f"STDOUT: {name}", "cyan", "on_grey") + with open(str(stdout), "r") as f: + print(f.read()) + for stderr in stderrs: + name = str(stderr.name).split(".err")[0] + cprint(f"STDERR: {name}", "magenta", "on_grey") + with open(str(stderr), "r") as f: + print(f.read()) + def __repr__(self): repr_ = super().__repr__()[:-1] # remove closing bracket to append repr_ += f" process: {self.process}; process_name: {self.process_name}; done: {self.done}>" diff --git a/dataforest/core/ProcessTreeRun.py b/dataforest/core/ProcessTreeRun.py new file mode 100644 index 0000000..f1f4897 --- /dev/null +++ b/dataforest/core/ProcessTreeRun.py @@ -0,0 +1,38 @@ +import logging +from typing import TYPE_CHECKING, Callable, Set, Tuple + +from dataforest.core.DataBranch import DataBranch + +if TYPE_CHECKING: + from dataforest.core.DataTree import DataTree + + +class ProcessTreeRun: + def __init__(self, tree: "DataTree", process_name: str, process: str): + self._LOG = logging.getLogger(f"ProcessRun - {process_name}") + if process_name not in tree.tree_spec and process_name != "root": + raise ValueError(f'key "{process_name}" not in tree_spec: {tree.tree_spec}') + self.process_name = process_name + self.process = process + self._tree = tree + + @property + def done(self): + self._tree.load_all() + return all(map(lambda branch: branch[self.process_name].done, self._tree._branch_cache.values())) + + @property + def failed(self): + filter_ = lambda branch, process: branch[process].done and not branch[process].success + return self._filter_branches(filter_) + + @property + def success(self): + filter_ = lambda branch, process: branch[process].done and branch[process].success + return self._filter_branches(filter_) + + def _filter_branches(self, filter_: Callable) -> tuple: + self._tree.load_all() + branches = self._tree._branch_cache.values() + process = self.process_name + return tuple([branch for branch in branches if filter_(branch, process)]) diff --git a/dataforest/core/RunGroupSpec.py b/dataforest/core/RunGroupSpec.py index 6786f18..88ad8f8 100644 --- a/dataforest/core/RunGroupSpec.py +++ 
diff --git a/dataforest/core/RunGroupSpec.py b/dataforest/core/RunGroupSpec.py
index 6786f18..88ad8f8 100644
--- a/dataforest/core/RunGroupSpec.py
+++ b/dataforest/core/RunGroupSpec.py
@@ -1,5 +1,5 @@
 from itertools import product
-from typing import Union, Iterable, List
+from typing import Union, Iterable, List, Tuple, Any
 
 from dataforest.core.RunSpec import RunSpec
 from dataforest.core.Sweep import Sweep
@@ -17,6 +17,7 @@ class RunGroupSpec(RunSpec):
 
     def __init__(self, dict_: Union[dict, "RunGroupSpec"]):
         super().__init__(**dict_)
+        self.sweeps = set()
         self.run_specs = self._build_run_specs()
 
     def _build_run_specs(self):
@@ -32,11 +33,12 @@ def _build_run_specs(self):
             run_specs: list of `RunSpec`s representing all combinations of values
                 from sweeps
         """
-        sub_groups = self._expand_sweeps(self, (list, set, tuple))
+        sub_groups, combos = self._expand_sweeps(self, (list, set, tuple))
         # if there were no operations in array format to be expanded
         if isinstance(sub_groups, dict):
             self._map_sweeps()
-            self_expand = {k: self._expand_sweeps(v, Sweep) for k, v in self.items() if isinstance(v, dict)}
+            # expand _SWEEP_ keys
+            self_expand = {k: self._expand_sweeps(v, Sweep)[0] for k, v in self.items() if isinstance(v, dict)}
             self_expand.update({k: v for k, v in self.items() if not isinstance(v, dict)})
             # if any operations are still in array format
             if any(map(lambda x: isinstance(x, list), self_expand.values())):
@@ -57,11 +59,14 @@ def _map_sweeps(self):
             if isinstance(operation_dict, dict):
                 for key, val in operation_dict.items():
                     if isinstance(val, dict) and "_SWEEP_" in val:
+                        sweep_info = (self["_PROCESS_"], operation, key, tuple(val["_SWEEP_"]))
+                        self.sweeps.add(sweep_info)
                         sweep_obj = val["_SWEEP_"]
                         self[operation][key] = Sweep(operation, key, sweep_obj)
 
     @staticmethod
-    def _expand_sweeps(dict_: dict, types: Union[type, Iterable[type]]) -> Union[dict, List[dict]]:
+    # TODO: type hint combos
+    def _expand_sweeps(dict_: dict, types: Union[type, Iterable[type]]) -> Tuple[Union[dict, List[dict]], Any]:
         """
         Expand sweeps into a list of dicts representing all possible combinations.
         Args:
@@ -70,7 +75,9 @@ def _expand_sweeps(dict_: dict, types: Union[type, Iterable[type]]) -> Union[dic
         """
         sweeps_part = order_dict({k: v for k, v in dict_.items() if isinstance(v, types)})
         static_part = {k: v for k, v in dict_.items() if k not in sweeps_part}
-        combos = product(*sweeps_part.values())
+        combos = list(product(*sweeps_part.values()))
         dicts = [{**static_part, **dict(zip(sweeps_part.keys(), combo))} for combo in combos]
-        dicts = dicts[0] if len(dicts) == 1 else dicts
-        return dicts
+        if len(dicts) == 1:
+            dicts = dicts[0]
+            combos = combos[0]
+        return dicts, combos
diff --git a/dataforest/core/Sweep.py b/dataforest/core/Sweep.py
index acf0059..40ffd5f 100644
--- a/dataforest/core/Sweep.py
+++ b/dataforest/core/Sweep.py
@@ -7,6 +7,10 @@ def __init__(self, operation, key, sweep_obj):
         self.key = key
         super().__init__(self._get_values(sweep_obj))
 
+    @property
+    def dtype(self):
+        return type(self[0])
+
     @staticmethod
     def _get_values(sweep_obj):
         if isinstance(sweep_obj, dict):
@@ -27,7 +31,7 @@ def _get_values(sweep_obj):
         elif isinstance(sweep_obj, (list, set, tuple)):
             values = list(sweep_obj)
         else:
-            raise TypeError(f"`sweep_obj` must be of types: [dict, list, set, tuple]")
+            raise TypeError(f"`sweep_obj` expected type: [dict, list, set, tuple], not {type(sweep_obj)}. {sweep_obj}")
         return values
 
     def __str__(self):
diff --git a/dataforest/core/TreeDataFrame.py b/dataforest/core/TreeDataFrame.py
new file mode 100644
index 0000000..53a1327
--- /dev/null
+++ b/dataforest/core/TreeDataFrame.py
@@ -0,0 +1,73 @@
+from functools import wraps
+from typing import TYPE_CHECKING, Callable, List, Iterable, Union
+
+import pandas as pd
+
+from dataforest.structures.cache.BranchCache import BranchCache
+
+if TYPE_CHECKING:
+    from dataforest.core.DataTree import DataTree
+
+
+class DataFrameList(list):
+    TETHER_EXCLUDE = {"__class__", "__init__", "__weakref__", "__dict__", "__getitem__", "__setitem__"}
+
+    def __init__(self, df_list: Iterable[Union[pd.DataFrame, pd.Series]]):
+        super().__init__(list(df_list))
+        self._elem_class = self._get_elem_class(self)
+        self._tether_df_methods()
+
+    def get_elem(self, i: int):
+        """
+        Get the df or series from the list-like structure since getitem is
+        overloaded by the pandas method
+        Args:
+            i: index in list-like structure
+        """
+        return list.__getitem__(self, i)
+
+    @staticmethod
+    def _get_elem_class(container):
+        if all(isinstance(x, pd.DataFrame) for x in container):
+            return pd.DataFrame
+        elif all(isinstance(x, pd.Series) for x in container):
+            return pd.Series
+
+    def _tether_df_methods(self):
+        if self._elem_class is None:
+            return
+        names = set(dir(self._elem_class)).difference(self.TETHER_EXCLUDE)
+        for name in names:
+            distributed_method = self._build_distributed_method(name)
+            setattr(self, name, distributed_method)
+        self._getitem = self._build_distributed_method("__getitem__")
+        self._setitem = self._build_distributed_method("__setitem__")
+
+    def _build_distributed_method(self, method_name) -> Callable:
+        df_method = getattr(self._elem_class, method_name)
+
+        @wraps(df_method)
+        def _distributed_method(*args, **kwargs):
+            def _split_args(i, arg):
+                if isinstance(arg, self.__class__):
+                    return arg.get_elem(i)
+                return arg
+
+            def _single_kernel(i, df):
+                args_ = [_split_args(i, arg) for arg in args]
+                kwargs_ = {k: _split_args(i, v) for k, v in kwargs.items()}
+                return df_method(df, *args_, **kwargs_)
+
+            def _distributed_kernel():
+                ret = [_single_kernel(*x) for x in enumerate(self)]
+                return self.__class__(ret)
+
+            return _distributed_kernel()
+
+        return _distributed_method
+
+    def __getitem__(self, item):
+        return self._getitem(item)
+
+    def __setitem__(self, key, value):
+        return self._setitem(key, value)
diff --git a/dataforest/core/TreeSpec.py b/dataforest/core/TreeSpec.py
index 0eac29b..2abdff4 100644
--- a/dataforest/core/TreeSpec.py
+++ b/dataforest/core/TreeSpec.py
@@ -1,9 +1,14 @@
+from copy import deepcopy
 from itertools import product
-from typing import Union, List
+from typing import Union, List, Dict, TYPE_CHECKING, Optional
 
 from dataforest.core.BranchSpec import BranchSpec
 from dataforest.core.RunGroupSpec import RunGroupSpec
 
+if TYPE_CHECKING:
+    from dataforest.core.DataBranch import DataBranch
+    from dataforest.core.RunSpec import RunSpec
+
 
 class TreeSpec(BranchSpec):
     """
@@ -28,7 +33,7 @@ class TreeSpec(BranchSpec):
     >>>         },
     >>>     ],
     >>>     "_SUBSET_": {
-    >>>         "sample": {"_SWEEP_": ["sample_1", "sample_2"]}
+    >>>         "sample_id": {"_SWEEP_": ["sample_1", "sample_2"]}
     >>>     },
     >>> },
     >>> {
@@ -45,13 +50,61 @@ class TreeSpec(BranchSpec):
     >>> tree_spec = TreeSpec(tree_spec)
     """
 
-    def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]):
+    _RUN_SPEC_CLASS = RunGroupSpec
+
+    def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"], twigs: Optional[List[dict]] = None):
         super(list, self).__init__()
         self.extend([RunGroupSpec(item) for item in tree_spec])
-        self.branch_specs = self._build_branch_specs()
+        self.twig_lookup = {str(twig): twig for twig in (twigs or [])}
+        # TODO: abstract to method
+        if self.twig_lookup:
+            self.twig_lookup = {"base": [], **self.twig_lookup}
+        self.branch_specs = self._build_branch_specs(twigs)
+        self.sweep_dict = {x["_PROCESS_"]: x.sweeps for x in self}
+        self.sweep_dict["root"] = set()
+        self._raw = tree_spec
+        self._run_spec_lookup: Dict[str, "RunGroupSpec"] = self._build_run_spec_lookup()
+        self._precursors_lookup: Dict[str, List[str]] = self._build_precursors_lookup()
+        self._precursors_lookup_incl_curr: Dict[str, List[str]] = self._build_precursors_lookup(incl_current=True)
+        self._precursors_lookup_incl_root: Dict[str, List[str]] = self._build_precursors_lookup(incl_root=True)
+        self._precursors_lookup_incl_root_curr: Dict[str, List[str]] = self._build_precursors_lookup(
+            incl_root=True, incl_current=True
+        )
+
+    @staticmethod
+    def add_twig(template: "BranchSpec", twig: Union[tuple, list]):
+        spec = deepcopy(template)
+        if isinstance(twig, tuple):
+            twig = [twig]
+        for mod in twig:
+            val = mod[-1]
+            final_key = mod[-2]
+            accessors = mod[:-2]
+            scope = spec
+            for key in accessors:
+                scope = scope[key]
+            scope[final_key] = val
+        return spec
+
+    @staticmethod
+    def _apply_add_twigs(specs: List["BranchSpec"], twigs: Optional[Union[tuple, list]]):
+        if twigs is None:
+            return list()
+        specs = [spec for x in specs for spec in TreeSpec._add_twigs(x, twigs)]
+        return specs
+
+    @staticmethod
+    def _add_twigs(template: "BranchSpec", twigs: List[Union[tuple, list]]):
+        specs = list()
+        for twig in twigs:
+            spec = TreeSpec.add_twig(template, twig)
+            specs.append(spec)
+        return specs
 
-    def _build_branch_specs(self):
-        return list(product(*[run_group_spec.run_specs for run_group_spec in self]))
+    def _build_branch_specs(self, twigs: Optional[List[dict]]):
+        specs = list(map(BranchSpec, product(*[run_group_spec.run_specs for run_group_spec in self])))
+        specs.extend(self._apply_add_twigs(specs, twigs))
+        return specs
 
     def __setitem__(self, key, value):
         if not isinstance(key, int):
diff --git a/dataforest/hooks/dataprocess/dataprocess.py b/dataforest/hooks/dataprocess/dataprocess.py
index b28224b..a86773b 100644
--- a/dataforest/hooks/dataprocess/dataprocess.py
+++ b/dataforest/hooks/dataprocess/dataprocess.py
@@ -18,7 +18,7 @@ class dataprocess(metaclass=MetaDataProcess):
     ✓ checks to ensure that specified input data is present
     - may make new tables and update them in the future
     kwargs can be used to pass custom attributes
     """
 
     def __init__(
@@ -61,7 +61,7 @@ def wrapper(branch, run_name, stop_on_error=True, stop_on_hook_error=True, *args
             try:
                 return func(self.branch, run_name, *args, **kwargs)
             except Exception as e:
-                ExceptionHandler.handle(branch, e, "process.err", stop_on_error)
+                ExceptionHandler.handle(branch, e, f"PROCESS_{func.__name__}.err", stop_on_error)
             finally:
                 self._run_hooks(self.clean_hooks, try_all=True, stop_on_hook_error=stop_on_hook_error)
@@ -83,7 +83,7 @@ def _run_hooks(self, hooks: Iterable[Callable], try_all: bool = False, stop_on_h
                 hook_exceptions[str(hook.__name__)] = e
                 if not try_all:
                     e = HookException(self.name, hook_exceptions)
-                    ExceptionHandler.handle(self.branch, e, "hooks.err", stop_on_hook_error)
+                    ExceptionHandler.handle(self.branch, e, f"HOOK_{hook.__name__}.err", stop_on_hook_error)
         if hook_exceptions:
             e = HookException(self.name, hook_exceptions)
-            ExceptionHandler.handle(self.branch, e, "hooks.err", stop_on_hook_error)
+            ExceptionHandler.handle(self.branch, e, f"HOOK_{hook.__name__}.err", stop_on_hook_error)
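Review note: the twig format consumed by `TreeSpec.add_twig` above is a list of modifier tuples, where leading elements are accessors into the spec and the last two are (key, value). A hypothetical example:

    twig = [("normalize", "_PARAMS_", "min_genes", 10)]
    spec = TreeSpec.add_twig(template_spec, twig)
    # equivalent to spec["normalize"]["_PARAMS_"]["min_genes"] = 10 on a deep copy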
diff --git a/dataforest/hooks/hook.py b/dataforest/hooks/hook.py
index 3744fd8..9d79523 100644
--- a/dataforest/hooks/hook.py
+++ b/dataforest/hooks/hook.py
@@ -1,4 +1,4 @@
-from functools import partial, update_wrapper, wraps
+from functools import update_wrapper, wraps
 
 from typing import Callable, Optional, List, Tuple, Union
diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py
index 75a9e1a..269f550 100644
--- a/dataforest/hooks/hooks/core/hooks.py
+++ b/dataforest/hooks/hooks/core/hooks.py
@@ -1,6 +1,8 @@
+from copy import deepcopy
 import gc
 import logging
 from pathlib import Path
+import shutil
 
 import pandas as pd
 import yaml
@@ -122,8 +124,45 @@ def hook_catalogue(dp):
         raise ValueError(f"run_id: {run_id} is not equal to stored: {run_id_stored} for {str(run_spec)}")
 
 
+# TODO-QC: take a check here
 @hook
 def hook_generate_plots(dp: dataprocess):
-    plot_methods = dp.branch.plot.plot_method_lookup
-    for method in plot_methods.values():
-        method(dp.branch)
+    plot_sources = dp.branch.plot.plot_method_lookup
+    current_process = dp.branch.current_process
+    all_plot_kwargs_sets = dp.branch.plot.plot_kwargs.get(current_process, dict())
+    if not all_plot_kwargs_sets:
+        logging.warning(f"Plot map for '{current_process}' is undefined, skipping plots hook")
+        return
+    process_plot_methods = dp.branch.plot.plot_methods.get(current_process, dict())
+    process_plot_map = dp.branch[current_process]._plot_map
+    requested_plot_methods = deepcopy(process_plot_methods)
+
+    exceptions = []
+    for method in plot_sources.values():
+        plot_method_name = method.__name__
+        if plot_method_name in requested_plot_methods.values():
+            plot_name = dp.branch.plot.method_key_lookup[plot_method_name]
+            plot_kwargs_sets = all_plot_kwargs_sets[plot_name]
+            for plot_kwargs_key in plot_kwargs_sets.keys():
+                plot_path = process_plot_map[plot_name][plot_kwargs_key]
+                kwargs = deepcopy(plot_kwargs_sets[plot_kwargs_key])
+                kwargs["plot_path"] = plot_path
+                try:
+                    method(dp.branch, show=False, **kwargs)
+                except Exception as e:
+                    exceptions.append(e)
+            requested_plot_methods.pop(plot_name)
+
+    if len(requested_plot_methods) > 0:  # if not all requested mapped to functions in plot sources
+        logging.warning(
+            f"Requested plotting methods {requested_plot_methods} are not implemented so they were skipped."
+        )
+    for e in exceptions:
+        raise e
+
+
+@hook
+def hook_clear_logs(dp: dataprocess):
+    logs_path = dp.branch[dp.name].logs_path
+    if logs_path.exists():
+        shutil.rmtree(str(logs_path), ignore_errors=True)
diff --git a/dataforest/hyperparams/Sweep.py b/dataforest/hyperparams/Sweep.py
index 75aafdb..e587782 100644
--- a/dataforest/hyperparams/Sweep.py
+++ b/dataforest/hyperparams/Sweep.py
@@ -88,7 +88,7 @@ def plot(
             plot_method_kwargs = dict()
         if figsize is None:
             figsize = tuple(self.DEFAULT_SUBPLOT_SIZE * np.array(self.shape))
         fig, ax = plt.subplots(*self.shape, sharex="col", sharey="row", figsize=figsize)
         for (i, j) in self.indices:
             branch = self.branch_matrix[i, j]
             if branch is None:
diff --git a/dataforest/plot/PlotPreparator.py b/dataforest/plot/PlotPreparator.py
new file mode 100644
index 0000000..2ae5420
--- /dev/null
+++ b/dataforest/plot/PlotPreparator.py
@@ -0,0 +1,115 @@
+from math import ceil
+from typing import Callable, Optional, Literal, Union, List, Any, Tuple, Iterable
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
+from dataforest.core.DataBranch import DataBranch
+
+
+class PlotPreparator:
+    DEFAULT_PLOT_RESOLUTION_PX = (500, 500)  # width, height in pixels
+    NONE_VARIANTS = [None, "none", "None", "NULL", "NA"]
+    _DEFAULT_N_COLS = 3
+
+    def __init__(self, branch: "DataBranch"):
+        self.branch_df = pd.DataFrame({"branch": [branch]})
+        self.ax_arr = None
+        self.fig = None
+        self.facet_cols = None
+        self.strat_cols = None
+
+    @property
+    def facet_vals(self):
+        if not self.facet_cols:
+            return None
+        return sorted(self.branch_df["facet"].unique())
+
+    @property
+    def strat_vals(self):
+        if not self.strat_cols:
+            return None
+        return sorted(self.branch_df["stratify"].unique())
+
+    def prepare(self, plot_kwargs: dict):
+        xlim = plot_kwargs.pop("xlim", None)
+        ylim = plot_kwargs.pop("ylim", None)
+        xscale = plot_kwargs.pop("xscale", None)
+        yscale = plot_kwargs.pop("yscale", None)
+
+        @np.vectorize
+        def _apply_ax_kwargs(ax_: plt.Axes):
+            if xlim:
+                ax_.set_xlim(xlim)
+            if ylim:
+                ax_.set_ylim(ylim)
+            if xscale:
+                ax_.set_xscale(xscale)
+            if yscale:
+                ax_.set_yscale(yscale)
+
+        plot_size = plot_kwargs.pop("plot_size", self.DEFAULT_PLOT_RESOLUTION_PX)
+        figsize = plot_kwargs.pop("figsize", None)
+        ax_arr = plot_kwargs.pop("ax", None)
+        fig = plot_kwargs.pop("fig", plt.gcf())
+        if self.ax_arr is not None:
+            ax_arr = self.ax_arr
+        elif ax_arr is None:
+            fig, ax_arr = plt.subplots(1, 1)
+        if not isinstance(ax_arr, np.ndarray):
+            ax_arr = np.array([[ax_arr]])
+        elif ax_arr.ndim == 1:
+            ax_arr = np.expand_dims(ax_arr, axis=0)
+        _apply_ax_kwargs(ax_arr)
+        dpi = fig.get_dpi()
+        # scale to pixel resolution, irrespective of screen resolution
+        fig.set_size_inches(plot_size[0] / float(dpi), plot_size[1] / float(dpi))
+        if figsize:
+            fig.set_size_inches(*figsize)
+        self.fig = fig
+        self.ax_arr = ax_arr
+
+    def facet(self, cols: Union[str, List[str]], n_rows: Optional[int] = None, n_cols: Optional[int] = None):
+        self._branch_groupby(cols, "facet")
+        self.facet_cols = cols
+        labels = self.branch_df["facet"].unique()
+        dim = self._get_facet_dim(labels, n_rows, n_cols)
+        _, self.ax_arr = plt.subplots(*dim, sharex="col", sharey="row")
+
+    def stratify(self, cols: Union[str, List[str]], plot_kwargs):
+        self._branch_groupby(cols, "stratify")
+        self.strat_cols = cols
+
+    def _branch_groupby(self, cols: Union[str, Iterable[str]], key_colname: Literal["facet", "stratify"]):
+        df = self.branch_df
+        df["grp"] = df["branch"].apply(lambda branch: list(branch.groupby(cols)))
+        df = df.explode("grp").reset_index(drop=True)
+        df[[key_colname, "branch"]] = pd.DataFrame(df["grp"].tolist(), index=df.index)
+        self.branch_df = df
+
+    def _get_facet_dim(
+        self, labels: List[Any], n_rows: Optional[int] = None, n_cols: Optional[int] = None
+    ) -> Tuple[int, int]:
+        """
+        Get dimensions of subplot grid for facet
+        Args:
+            labels: unique labels in facet col
+            n_rows:
+            n_cols:
+
+        Returns: (n_rows, n_cols)
+        """
+        # TODO: if two cols and rows/cols within range, have option to use each as an axis
+        if not n_rows and not n_cols:
+            n_cols = min(len(labels), self._DEFAULT_N_COLS)
+        if n_rows and n_cols and n_rows * n_cols < len(labels):
+            raise ValueError(
+                f"If both `n_rows` and `n_cols` are specified, their product "
+                f"must accommodate the number of facet labels ({len(labels)})"
+            )
+        if n_cols:
+            n_rows = ceil(len(labels) / n_cols)
+        else:
+            n_cols = ceil(len(labels) / n_rows)
+        return n_rows, n_cols
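Review note: grid sizing in `PlotPreparator._get_facet_dim` above, worked through for a hypothetical facet column with 7 unique labels and no dimensions supplied:

    # n_cols = min(7, _DEFAULT_N_COLS) = 3
    # n_rows = ceil(7 / 3) = 3
    # -> 3 x 3 grid, with the last two axes left unused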
self.branch_df + df["grp"] = df["branch"].apply(lambda branch: list(branch.groupby(cols))) + df = df.explode("grp").reset_index(drop=True) + df[[key_colname, "branch"]] = pd.DataFrame(df["grp"].tolist(), index=df.index) + self.branch_df = df + + def _get_facet_dim( + self, labels: List[Any], n_rows: Optional[int] = None, n_cols: Optional[int] = None + ) -> Tuple[int, int]: + """ + Get dimensions of subplot grid for facet + Args: + labels: unique labels in facet col + n_rows: + n_cols: + + Returns: (n_rows, n_cols) + """ + # TODO: if two cols and rows/cols within range, have option to use each as an axis + if not n_rows and not n_cols: + n_cols = min(len(labels), self._DEFAULT_N_COLS) + if n_rows and n_cols and n_rows != len(labels) / n_cols: + raise ValueError( + f"If both `n_rows` and `n_cols` are specified, their product " + f"must be appropriate for ({len(labels)})" + ) + if n_cols: + n_rows = ceil(len(labels) / n_cols) + else: + n_cols = ceil(len(labels) / n_rows) + return n_rows, n_cols diff --git a/dataforest/plot/PlotWidget.py b/dataforest/plot/PlotWidget.py new file mode 100644 index 0000000..b9748b7 --- /dev/null +++ b/dataforest/plot/PlotWidget.py @@ -0,0 +1,113 @@ +import logging +from copy import deepcopy +from typing import TYPE_CHECKING, Dict, Any, Optional + +from IPython.display import Image, display +import ipywidgets as widgets +from matplotlib.figure import Figure + +from dataforest.core.TreeSpec import TreeSpec + +if TYPE_CHECKING: + from dataforest.core.DataBranch import DataBranch + from dataforest.core.DataTree import DataTree + + +class PlotWidget: + def __init__(self, tree: "DataTree", plot_key: str, use_saved: bool = True, **plot_kwargs): + """ + + Args: + tree: + plot_key: key format of plot name (e.g. "_UMIS_VS_PERC_HSP_SCAT_") + use_saved: use saved plots or regenerate + **plot_kwargs: + """ + self.branch = None + self._tree = tree + self._branch_spec_template = deepcopy(tree.tree_spec.branch_specs[0]) + self._branch_spec = deepcopy(self._branch_spec_template) + self._plot_key = plot_key + self._use_saved = use_saved + self._bypass_kwargs = plot_kwargs.pop("bypass_kwargs", dict()) + self._plot_kwargs = plot_kwargs + self._sweeps = self._tree.tree_spec.sweep_dict + self._plot_cache = dict() + + def build_control(self, **plotter_kwargs): + twig_lookup = self._tree.tree_spec.twig_lookup + _kwargs = {} + + def _prep_sweeps_kwargs(): + process = self._tree.current_process + precursors = self._tree.tree_spec.get_precursors_lookup(incl_current=True)[process] + sweeps = set().union(*[self._sweeps[precursor] for precursor in precursors]) + # {"key:operation:process": value, ...} (e.g. {"num_pcs:_PARAMS_:normalize": 30, ...}) + param_sweeps = {":".join(swp[:3][::-1]): list(swp[3]) for swp in sweeps} + _kwargs.update({**param_sweeps, **self._plot_kwargs}) + + if self._tree.has_sweeps: + _prep_sweeps_kwargs() + if self._tree.has_twigs: + _kwargs["twig_str"] = list(twig_lookup.keys()) + + @widgets.interact(**_kwargs) + # TODO: this seems to be slower now that the spec is recalculated every time -- is it replotted? 
diff --git a/dataforest/plot/PlotWidget.py b/dataforest/plot/PlotWidget.py
new file mode 100644
index 0000000..b9748b7
--- /dev/null
+++ b/dataforest/plot/PlotWidget.py
@@ -0,0 +1,113 @@
+import logging
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, Any, Optional
+
+from IPython.display import Image, display
+import ipywidgets as widgets
+from matplotlib.figure import Figure
+
+from dataforest.core.TreeSpec import TreeSpec
+
+if TYPE_CHECKING:
+    from dataforest.core.DataBranch import DataBranch
+    from dataforest.core.DataTree import DataTree
+
+
+class PlotWidget:
+    def __init__(self, tree: "DataTree", plot_key: str, use_saved: bool = True, **plot_kwargs):
+        """
+        Interactive plot controller for a `DataTree`.
+        Args:
+            tree: tree whose sweeps and twigs parameterize the widget controls
+            plot_key: key format of plot name (e.g. "_UMIS_VS_PERC_HSP_SCAT_")
+            use_saved: use saved plots or regenerate
+            **plot_kwargs: passed through to the underlying plot method
+        """
+        self.branch = None
+        self._tree = tree
+        self._branch_spec_template = deepcopy(tree.tree_spec.branch_specs[0])
+        self._branch_spec = deepcopy(self._branch_spec_template)
+        self._plot_key = plot_key
+        self._use_saved = use_saved
+        self._bypass_kwargs = plot_kwargs.pop("bypass_kwargs", dict())
+        self._plot_kwargs = plot_kwargs
+        self._sweeps = self._tree.tree_spec.sweep_dict
+        self._plot_cache = dict()
+
+    def build_control(self, **plotter_kwargs):
+        twig_lookup = self._tree.tree_spec.twig_lookup
+        _kwargs = {}
+
+        def _prep_sweeps_kwargs():
+            process = self._tree.current_process
+            precursors = self._tree.tree_spec.get_precursors_lookup(incl_current=True)[process]
+            sweeps = set().union(*[self._sweeps[precursor] for precursor in precursors])
+            # {"key:operation:process": value, ...} (e.g. {"num_pcs:_PARAMS_:normalize": 30, ...})
+            param_sweeps = {":".join(swp[:3][::-1]): list(swp[3]) for swp in sweeps}
+            _kwargs.update({**param_sweeps, **self._plot_kwargs})
+
+        if self._tree.has_sweeps:
+            _prep_sweeps_kwargs()
+        if self._tree.has_twigs:
+            _kwargs["twig_str"] = list(twig_lookup.keys())
+
+        @widgets.interact(**_kwargs)
+        # TODO: this seems to be slower now that the spec is recalculated every time -- is it replotted?
+        def _control(**kwargs: Dict[str, Any]):
+            spec = deepcopy(self._branch_spec_template)
+            twig_str = kwargs.pop("twig_str", None)
+            # add sweeps
+            for param_keys_str, value in kwargs.items():
+                if param_keys_str in self._plot_kwargs:
+                    self._plot_kwargs[param_keys_str] = value
+                    continue
+                elif isinstance(value, float):
+                    value = int(value) if int(value) == value else value
+                (name, operation, process) = param_keys_str.split(":")
+                spec[process][operation][name] = value
+            # TODO: do we need to do something with kwargs? Or taken care of?
+            for k in self._plot_kwargs:
+                kwargs.pop(k, None)
+            # add twigs
+            if twig_str:
+                twig = twig_lookup[twig_str]
+                spec = TreeSpec.add_twig(spec, twig)
+            # get branch
+            self.branch = self._tree._branch_cache[str(spec)]
+            return self._get_plot(self.branch, **plotter_kwargs)
+
+        return _control
+
+    def _get_plot(self, branch, **kwargs):
+        kwargs = {**self._bypass_kwargs, **kwargs}
+        plot_map = branch[self._tree.current_process]._plot_map
+        plot_path_lookup = {plot_key: next(iter(path_dict.values())) for plot_key, path_dict in plot_map.items()}
+        spec = branch.spec[: branch.current_process]
+        cache_key = str({"spec": spec, "plot_kwargs": self._plot_kwargs})
+        if self._plot_key in plot_path_lookup and self._use_saved:
+            plot_path = plot_path_lookup.get(self._plot_key)
+            if plot_path.exists():
+                return Image(plot_path)
+        generated = False
+        plot_obj = self._plot_cache.get(cache_key, None)
+        if not plot_obj:
+            generated = True
+            plot_obj = self._generate_plot(branch, **kwargs)
+            self._plot_cache[cache_key] = plot_obj
+        if isinstance(plot_obj, tuple) and isinstance(plot_obj[0], Figure):
+            if not generated:
+                return plot_obj[0]
+        elif isinstance(plot_obj, Image):
+            return plot_obj
+        elif isinstance(plot_obj, list) and all(isinstance(x, Image) for x in plot_obj):
+            for img in plot_obj:
+                display(img)
+            return plot_obj
+        else:
+            raise TypeError(
+                f"Expected types (matplotlib.figure.Figure, IPython.display.Image, "
+                f"List[IPython.display.Image]). Got {type(plot_obj)}"
+            )
+
+    def _generate_plot(self, branch: "DataBranch", **kwargs):
+        method = branch.plot.method_lookup[self._plot_key]
+        kwargs = {**self._plot_kwargs, **kwargs, "show": False}
+        return method(**kwargs)
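A sketch of intended notebook usage, assuming a `tree` with sweeps already loaded (the plot key and kwargs are illustrative, not a fixed API):

```python
from dataforest.plot.PlotWidget import PlotWidget

widget = PlotWidget(tree, "_UMIS_VS_PERC_HSP_SCAT_", use_saved=True, plot_size=(500, 500))
widget.build_control()  # renders one ipywidgets control per sweep param, plus a twig selector
```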
diff --git a/dataforest/plot/__init__.py b/dataforest/plot/__init__.py
new file mode 100644
index 0000000..753bf3f
--- /dev/null
+++ b/dataforest/plot/__init__.py
@@ -0,0 +1 @@
+from .wrappers import plot_py, plot_r, requires
diff --git a/dataforest/plot/wrappers.py b/dataforest/plot/wrappers.py
new file mode 100644
index 0000000..2bbc20c
--- /dev/null
+++ b/dataforest/plot/wrappers.py
@@ -0,0 +1,153 @@
+from functools import wraps
+import json
+import logging
+from pathlib import Path
+from itertools import product
+from typing import Tuple, Optional, AnyStr, Union
+
+from IPython.display import Image, display
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+
+from dataforest.core.DataBranch import DataBranch
+from dataforest.core.PlotMethods import PlotMethods
+from dataforest.plot.PlotPreparator import PlotPreparator
+
+_DEFAULT_BIG_PLOT_RESOLUTION_PX = (1000, 1000)  # width, height in pixels
+_PLOT_FILE_EXT = ".png"
+
+
+def plot_py(plot_func):
+    @wraps(plot_func)
+    def wrapper(
+        branch: "DataBranch",
+        stratify: Optional[str] = None,
+        facet: Optional[str] = None,
+        plot_path: Optional[AnyStr] = None,
+        facet_dim: tuple = (),
+        show: bool = True,
+        **kwargs,
+    ) -> Union[plt.Figure, Tuple[plt.Figure, np.ndarray]]:
+        prep = PlotPreparator(branch)
+        stratify = None if stratify in prep.NONE_VARIANTS else stratify
+        facet = None if facet in prep.NONE_VARIANTS else facet
+        if facet is not None:
+            prep.facet(facet, *facet_dim)  # builds the subplot grid on `prep.ax_arr`
+        if stratify is not None:
+            prep.stratify(stratify, kwargs)
+        if plot_path is not None:
+            matplotlib.use("Agg")  # don't plot on screen
+        prep.prepare(kwargs)
+        facet_inds = list(product(*map(range, prep.ax_arr.shape)))
+        for _, row in prep.branch_df.iterrows():
+            ax = prep.ax_arr[0, 0]
+            if facet is not None:
+                i = prep.facet_vals.index(row["facet"])
+                ax_i = facet_inds[i]
+                ax = prep.ax_arr[ax_i]
+                ax.set_title(row["facet"])
+            if stratify is not None:
+                kwargs["label"] = row["stratify"]
+            plot_func(row["branch"], ax=ax, **kwargs)
+            if stratify is not None:
+                ax.legend()
+        if plot_path is not None:
+            logging.info(f"saving py figure to {plot_path}")
+            prep.fig.savefig(plot_path)
+        if show:
+            display(prep.fig)
+        return prep.fig, prep.ax_arr
+
+    return wrapper
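For orientation, a conforming plot function under this decorator only needs a branch and a target `ax`; the wrapper owns faceting, stratification, and saving. The `"n_genes"` metadata column here is an assumption:

```python
@plot_py
def plot_genes_per_cell_hist(branch: "DataBranch", ax=None, **kwargs):
    # "n_genes" is a hypothetical branch.meta column; `ax` is supplied per facet
    ax.hist(branch.meta["n_genes"], bins=50, **kwargs)
```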
+
+
+def plot_r(plot_func):
+    @wraps(plot_func)
+    def wrapper(
+        branch: "DataBranch",
+        stratify: Optional[str] = None,
+        facet: Optional[str] = None,
+        plot_path: Optional[AnyStr] = None,
+        show: bool = True,
+        **kwargs,
+    ):
+        def _get_plot_script():
+            for _plot_source in PlotMethods.R_PLOT_SOURCES:
+                _r_script = _plot_source / (plot_func.__name__ + ".R")
+                if _r_script.exists():
+                    return _plot_source, _r_script
+            raise FileNotFoundError(f"no R script named {plot_func.__name__}.R in PlotMethods.R_PLOT_SOURCES")
+
+        stratify = None if stratify in PlotPreparator.NONE_VARIANTS else stratify
+        facet = None if facet in PlotPreparator.NONE_VARIANTS else facet
+        plot_path = plot_path if plot_path else "/tmp/plot.png"
+
+        if facet is not None:
+            logging.warning(f"{plot_func.__name__} facet not yet implemented for R plots")
+        if stratify is not None:
+            logging.warning(f"{plot_func.__name__} stratify not yet supported for R plots")
+        plot_source, r_script = _get_plot_script()
+        plot_size = kwargs.pop("plot_size", PlotPreparator.DEFAULT_PLOT_RESOLUTION_PX)
+        if stratify is not None:
+            if stratify in branch.meta:  # col exists in metadata
+                kwargs["group.by"] = stratify
+            else:
+                logging.warning(f"{plot_func.__name__} with key '{stratify}' skipped because key is not in metadata")
+                return
+        subset_vals = [None]
+        if facet is not None:
+            subset_vals = sorted(branch.meta[facet].unique())
+        img_arr = []
+
+        for val in subset_vals:
+            # corresponding arguments in r/plot_entry_point.R
+            args = [
+                plot_source,  # r_plot_scripts_path
+                branch.paths["root"],  # root_dir
+                branch.spec.shell_str,  # spec_str
+                facet,  # subset_key
+                val,  # subset_val
+                branch.current_process,  # current_process
+                plot_path,  # plot_file_path
+                plot_size[0],  # plot_width_px
+                plot_size[1],  # plot_height_px
+                json.dumps(
+                    "kwargs = " + str(kwargs if kwargs else {})
+                ),  # TODO-QC: is there a better way to handle this?
+            ]
+            plot_func(branch, r_script, args)  # plot_kwargs already included in args
+            if not Path(plot_path).exists():
+                logging.warning(f"{plot_func.__name__} output file not found after execution for {val}")
+            else:
+                logging.info(f"saved R figure to {plot_path}")
+                img = Image(plot_path)
+                if show:
+                    display(img)
+                img_arr.append(img)
+        return img_arr
+
+    return wrapper
+
+
+# noinspection PyPep8Naming
+class requires:
+    def __init__(self, req_process):
+        self._req_process = req_process
+
+    def __call__(self, func):
+        func._requires = self._req_process
+
+        @wraps(func)
+        def wrapper(branch: "DataBranch", *args, **kwargs):
+            if not branch[self._req_process].success:
+                raise ValueError(f"`{func.__name__}` requires a complete and successful process: `{self._req_process}`")
+            precursors = branch.spec.get_precursors_lookup(incl_current=True)[branch.current_process]
+            if self._req_process not in precursors:
+                proc = self._req_process
+                raise ValueError(
+                    f"This plot method requires a branch at `{proc}` or later. Processes run so far: {precursors}. If "
+                    f"`{proc}` has already been run, please use `branch.goto_process`. Otherwise, please run `{proc}`."
+                )
+            return func(branch, *args, **kwargs)
+
+        return wrapper
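And a sketch of gating a plot on a prerequisite process (the process name is hypothetical); if `"normalize"` is not among the branch's precursors, the wrapper raises before the plot body runs:

```python
@requires("normalize")
def plot_perc_mito_dens(branch: "DataBranch", **kwargs):
    ...  # body only runs once `normalize` has completed successfully
```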
diff --git a/dataforest/processes/core/TreeProcessMethods.py b/dataforest/processes/core/TreeProcessMethods.py
index e3220ba..7ed4441 100644
--- a/dataforest/processes/core/TreeProcessMethods.py
+++ b/dataforest/processes/core/TreeProcessMethods.py
@@ -1,13 +1,20 @@
-from typing import Callable, List, Union
+import logging
+from typing import Callable, List, Union, TYPE_CHECKING
+
+from joblib import Parallel, delayed
 
 from dataforest.core.TreeSpec import TreeSpec
 from dataforest.structures.cache.BranchCache import BranchCache
 
+if TYPE_CHECKING:
+    from dataforest.core.DataTree import DataTree
+
 
 class TreeProcessMethods:
-    def __init__(self, tree_spec: TreeSpec, branch_cache: BranchCache):
-        self._tree_spec = tree_spec
-        self._branch_cache = branch_cache
+    _N_CPUS_EXCLUDED = 1
+
+    def __init__(self, tree: "DataTree"):
+        self._tree = tree
         self._process_methods = list()
         self._tether_process_methods()
@@ -17,7 +24,7 @@ def process_methods(self):
 
     def _tether_process_methods(self):
         # TODO: docstring
-        method_names = [run_group_spec.name for run_group_spec in self._tree_spec]
+        method_names = [run_group_spec.name for run_group_spec in self._tree.tree_spec]
         for name in method_names:
             distributed_method = self._build_distributed_method(name)
             setattr(self, name, distributed_method)
@@ -36,11 +43,13 @@ def _build_distributed_method(self, method_name: str) -> Callable:
             branches
         """
 
-        def distributed_method(
+        def _distributed_method(
             *args,
             stop_on_error: bool = False,
             stop_on_hook_error: bool = False,
             clear_data: Union[bool, List[str]] = True,
+            force_rerun: bool = False,
+            parallel: bool = False,
             **kwargs,
         ):
             """
@@ -53,26 +62,44 @@ def distributed_method(
                     after process execution to save memory. If boolean, whether or not
                     to clear all data attrs, if list, names of data attrs to clear
+                force_rerun: force the process to rerun even if already done
+                parallel: whether or not to parallelize
                 **kwargs:
 
             Returns:
             """
-            if not self._branch_cache.fully_loaded:
-                self._branch_cache.load_all()
-            return_vals = []
-            for branch in list(self._branch_cache.values()):
+            kwargs = {"stop_on_error": stop_on_error, "stop_on_hook_error": stop_on_hook_error, **kwargs}
+            unique_branches = self._tree.unique_branches_at_process(method_name)
+
+            def _single_kernel(branch):
+                ret = None
                 branch_method = getattr(branch.process, method_name)
-                if not branch[method_name].done:
-                    return_vals.append(
-                        branch_method(
-                            *args, stop_on_error=stop_on_error, stop_on_hook_error=stop_on_hook_error, **kwargs
-                        )
-                    )
+                if not branch[method_name].done or force_rerun:
+                    ret = branch_method(*args, **kwargs)
                 if clear_data:
                     clear_kwargs = {"all_data": True} if isinstance(clear_data, bool) else {"attrs": clear_data}
                     branch.clear_data(**clear_kwargs)
-            return return_vals
+                return ret
+
+            def _distributed_kernel_serial():
+                _ret_vals = []
+                for branch in unique_branches.values():
+                    _ret_vals.append(_single_kernel(branch))
+                return _ret_vals
+
+            def _distributed_kernel_parallel():
+                process = delayed(_single_kernel)
+                pool = Parallel(n_jobs=-1 - self._N_CPUS_EXCLUDED)
+                return pool(process(branch) for branch in unique_branches.values())
+
+            exec_scheme = "PARALLEL" if parallel else "SERIAL"
+            logging.info(f"{exec_scheme} execution of {method_name} on {len(unique_branches)} unique branches")
+            kernel = _distributed_kernel_parallel if parallel else _distributed_kernel_serial
+            ret_vals = kernel()
+            return ret_vals
+
+        _distributed_method.__name__ = method_name
+        return _distributed_method
-
-        distributed_method.__name__ = method_name
-        return distributed_method
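On the pool sizing: for negative values joblib uses `n_cpus + 1 + n_jobs` workers, so with `_N_CPUS_EXCLUDED = 1` the resulting `n_jobs=-2` means "all CPUs but one". A self-contained check:

```python
from joblib import Parallel, delayed

def square(x):
    return x * x

_N_CPUS_EXCLUDED = 1
pool = Parallel(n_jobs=-1 - _N_CPUS_EXCLUDED)  # n_jobs=-2: all CPUs but one
assert pool(delayed(square)(i) for i in range(8)) == [i * i for i in range(8)]
```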
diff --git a/dataforest/structures/cache/BranchCache.py b/dataforest/structures/cache/BranchCache.py
index 21d5085..7005710 100644
--- a/dataforest/structures/cache/BranchCache.py
+++ b/dataforest/structures/cache/BranchCache.py
@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 from typing import List, Union
@@ -7,6 +8,8 @@
 
 class BranchCache(HashCash):
+    _LOG = logging.getLogger("BranchCache")
+
     def __init__(self, root: Union[str, Path], branch_specs: List[BranchSpec], branch_class: type, *args):
         super().__init__()
         self._root = root
@@ -16,9 +19,11 @@ def __init__(self, root: Union[str, Path], branch_specs: List[BranchSpec], branc
         self.fully_loaded = False
 
     def load_all(self):
-        for spec_str in self._branch_spec_lookup:
-            _ = self[spec_str]  # force load of all items
-        self.fully_loaded = True
+        if not self.fully_loaded:
+            self._LOG.info("loading all branches")
+            for spec_str in self._branch_spec_lookup:
+                _ = self[spec_str]  # force load of all items
+            self.fully_loaded = True
 
     def update_branch_specs(self, branch_specs: List[BranchSpec]):
         self._branch_spec_lookup = {str(spec): spec for spec in branch_specs}
diff --git a/dataforest/structures/cache/PlotCache.py b/dataforest/structures/cache/PlotCache.py
new file mode 100644
index 0000000..9bf34fa
--- /dev/null
+++ b/dataforest/structures/cache/PlotCache.py
@@ -0,0 +1,3 @@
+class PlotCache:
+    def __init__(self, branch, plot_key):
+        self._plot_key = plot_key
diff --git a/dataforest/utils/ExceptionHandler.py b/dataforest/utils/ExceptionHandler.py
index e6ea8f9..8542d7f 100644
--- a/dataforest/utils/ExceptionHandler.py
+++ b/dataforest/utils/ExceptionHandler.py
@@ -1,20 +1,42 @@
+import logging
 import traceback
 
 
 class ExceptionHandler:
-    @staticmethod
-    def handle(branch, e, logfile_name, stop):
+    _LOG = logging.getLogger("ExceptionHandler")
+
+    @classmethod
+    def handle(cls, branch, e, logfile_name, stop):
+        try:
+            log_path = cls._get_log_path(branch, logfile_name)
+        except Exception as path_e:
+            cls._LOG.warning(f"{type(path_e).__name__} encountered getting logging output path for {logfile_name}")
+            if stop:
+                raise e
+        else:
+            cls._handle_write(e, log_path, stop)
+
+    @classmethod
+    def _handle_write(cls, e, log_path, stop):
         try:
-            ExceptionHandler._log_error(branch, e, logfile_name)
+            cls._write_log(e, log_path)
+            cls._LOG.info(f"Wrote log to {log_path}")
+        except Exception as write_e:
+            cls._LOG.warning(f"{type(write_e).__name__} encountered writing {log_path}")
         finally:
             if stop:
                 raise e
+            else:
+                cls._LOG.warning(f"{type(e).__name__} encountered but `stop_on_error=False`. Logs at {log_path}")
 
     @staticmethod
-    def _log_error(branch, e, logfile_name):
+    def _get_log_path(branch, logfile_name):
         log_dir = branch[branch.current_process].logs_path
         log_dir.mkdir(exist_ok=True)
-        log_path = log_dir / logfile_name
+        return log_dir / logfile_name
+
+    @staticmethod
+    def _write_log(e, log_path):
         with open(log_path, "w") as f:
-            f.write("Traceback (most recent call last):")
+            f.write("Traceback (most recent call last):\n")
             traceback.print_tb(e.__traceback__, file=f)
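The write path boils down to stdlib `traceback`; a minimal standalone version of what `_write_log` produces (log path hypothetical):

```python
import traceback

try:
    1 / 0
except ZeroDivisionError as e:
    with open("/tmp/example.err", "w") as f:  # hypothetical log path
        f.write("Traceback (most recent call last):\n")
        traceback.print_tb(e.__traceback__, file=f)
```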
diff --git a/dataforest/utils/exceptions.py b/dataforest/utils/exceptions.py
index cf3980f..61091b9 100644
--- a/dataforest/utils/exceptions.py
+++ b/dataforest/utils/exceptions.py
@@ -47,13 +47,12 @@ def __str__(self):
 class BadOperation(KeyError):
     OPERATION = None
 
-    def __init__(self, column, val):
-        self._column = column
-        self._val = val
+    def __init__(self, dict_):
+        self._dict = dict_
 
     def __str__(self):
         return (
-            f"{self.OPERATION} resulted in no rows - column: {self._column} val: {self._val}. Please note that any "
+            f"{self.OPERATION} resulted in no rows - {self._dict}. Please note that any "
             f"spaces in `spec` should be converted to underscores!"
         )
diff --git a/dataforest/utils/loaders/collectors.py b/dataforest/utils/loaders/collectors.py
index 4209f2e..43184d1 100644
--- a/dataforest/utils/loaders/collectors.py
+++ b/dataforest/utils/loaders/collectors.py
@@ -3,6 +3,8 @@
 from pathlib import Path
 from typing import Optional, Callable
 
+_PLOT_EXCLUDE_NAMES = ["plot_py", "plot_r"]
+
 
 def collect_hooks(path):
     return _collector(path, "hooks.py", _function_filter_hook)
@@ -63,7 +65,8 @@ def _function_filter_hook(func):
 
 
 def _function_filter_plot(func):
-    return func.__name__.startswith("plot_")
+    name = getattr(func, "__name__", "")
+    return name.startswith("plot_") and name not in _PLOT_EXCLUDE_NAMES
 
 
 def _function_filter_process(func):
diff --git a/dataforest/utils/loaders/config.py b/dataforest/utils/loaders/config.py
index 40cbb13..408fb08 100644
--- a/dataforest/utils/loaders/config.py
+++ b/dataforest/utils/loaders/config.py
@@ -1,24 +1,45 @@
 import json
 from pathlib import Path
-from typing import Union
+from typing import Union, AnyStr, Dict, Callable
 
 import yaml
 
+from dataforest import config as _config_module
 
-def load_config(config: Union[dict, str, Path]):
-    """Load library loaders"""
-    if isinstance(config, (str, Path)):
-        config = Path(config)
-        if config.suffix == ".json":
-            with open(str(config), "r") as f:
-                config = json.load(f)
-        elif config.suffix == ".yaml":
-            with open(str(config), "r") as f:
-                config = yaml.load(f, yaml.FullLoader)
-        else:
-            raise ValueError("If filepath is passed, must be .json or .yaml")
-    if not isinstance(config, dict):
-        raise TypeError(
-            f"Config must be either a `dict` or a path to a .json or .yaml file as either a `str` or `pathlib.Path`"
-        )
-    return config
+
+def get_config_options(config_dir: AnyStr) -> Dict[str, Path]:
+    config_paths = Path(config_dir).iterdir()
+    config_lookup = {path.stem: path for path in config_paths if not path.name.startswith("__")}
+    return config_lookup
+
+
+def get_config_loader(config_options: Dict[str, Path]) -> Callable:
+    def _load_config(config: Union[dict, str, Path]) -> dict:
+        """
+        Load the package's global configuration from a dict, a filepath, or a
+        name in `config_options`
+        """
+        if isinstance(config, (str, Path)):
+            if config in config_options:
+                config = config_options[config]
+            config = Path(config)
+            if config.suffix == ".json":
+                with open(str(config), "r") as f:
+                    config = json.load(f)
+            elif config.suffix == ".yaml":
+                with open(str(config), "r") as f:
+                    config = 
yaml.load(f, yaml.FullLoader) + else: + raise ValueError("If filepath is passed, must be .json or .yaml") + if not isinstance(config, dict): + raise TypeError( + f"Config must be either a `dict` or a path to a .json or .yaml file as either a `str` or `pathlib.Path`" + ) + return config + + return _load_config + + +_CONFIG_DIR = Path(_config_module.__file__).parent +CONFIG_OPTIONS = get_config_options(_CONFIG_DIR) +load_config = get_config_loader(CONFIG_OPTIONS) diff --git a/dataforest/utils/loaders/path.py b/dataforest/utils/loaders/path.py new file mode 100644 index 0000000..e3feae4 --- /dev/null +++ b/dataforest/utils/loaders/path.py @@ -0,0 +1,14 @@ +from importlib import import_module +from pathlib import Path +from typing import AnyStr, List + + +def get_module_paths(paths: List[AnyStr]) -> List[Path]: + def _get_module_path(path: AnyStr) -> Path: + module = import_module(str(path)) + path = Path(module.__file__) + if path.stem == "__init__": + path = path.parent + return path + + return list(map(_get_module_path, paths)) diff --git a/dataforest/utils/loaders/update_config.py b/dataforest/utils/loaders/update_config.py index afaa601..59245c5 100644 --- a/dataforest/utils/loaders/update_config.py +++ b/dataforest/utils/loaders/update_config.py @@ -1,13 +1,16 @@ from pathlib import Path -from typing import Union +from typing import Union, Callable from dataforest.config.MetaConfig import MetaConfig from dataforest.utils.loaders.config import load_config -def update_config(config: Union[dict, str, Path]): - config = load_config(config) - MetaConfig.CONFIG = config +def get_config_updater(config_loader: Callable) -> Callable: + def _update_config(config: Union[dict, str, Path]): + config = config_loader(config) + MetaConfig.CONFIG = config + + return _update_config def get_current_config() -> dict: diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py new file mode 100644 index 0000000..53badb6 --- /dev/null +++ b/dataforest/utils/plots_config.py @@ -0,0 +1,281 @@ +from copy import deepcopy +import json +from collections import OrderedDict +from typing import Dict + + +def build_process_plot_method_lookup(plot_map: dict) -> Dict[str, Dict[str, str]]: + """ + Get a lookup of processes, each containing a mapping between plot method + names in the config and the actual callable names. 
+    Format:
+        {process_name: {config_plot_name: plot_callable_name}}
+    Ex:
+        {"normalize": {"_UMIS_PER_CELL_HIST_": "plot_umis_per_cell_hist", ...}, ...}
+    """
+    process_plot_methods = {}
+    for process, plots in plot_map.items():
+        process_plot_methods[process] = {}
+        for plot_name in plots.keys():
+            try:
+                plot_method = plots[plot_name]["plot_method"]
+            except (TypeError, KeyError):
+                plot_method = _get_plot_method_from_plot_name(plot_name)
+
+            process_plot_methods[process][plot_name] = plot_method
+    return process_plot_methods
+
+
+def parse_plot_kwargs(plot_map: dict, plot_kwargs_defaults: dict):
+    """Parse each plot method's plot_kwargs per process from plot_map"""
+    all_plot_kwargs = {}
+    for process, plots in plot_map.items():
+        all_plot_kwargs[process] = {}
+
+        for plot_name in plots.keys():
+            all_plot_kwargs[process][plot_name] = {}
+
+            try:
+                plot_kwargs = plots[plot_name]["plot_kwargs"]
+            except (KeyError, TypeError):
+                plot_kwargs = _get_default_plot_kwargs(plot_kwargs_defaults)
+
+            kwargs_feed = _get_plot_kwargs_feed(plot_kwargs, plot_kwargs_defaults, plot_name)
+            kwargs_feed_mapped = _get_plot_kwargs_feed(
+                plot_kwargs, plot_kwargs_defaults, plot_name, map_to_default_kwargs=True,
+            )
+
+            for plot_kwargs_set, plot_kwargs_set_mapped in zip(kwargs_feed, kwargs_feed_mapped):
+                all_plot_kwargs[process][plot_name][_get_plot_kwargs_string(plot_kwargs_set)] = plot_kwargs_set_mapped
+
+    return all_plot_kwargs
+
+
+def parse_plot_map(plot_map: dict, plot_kwargs_defaults: dict):
+    """
+    Parse the plot file map per process from plot_map, ensuring that
+    implicitly defined plots fall back to default values for all plot_kwargs
+
+    Example:
+        {
+            "root": {
+                "_UMIS_PER_CELL_HIST_": {
+                    '{"plot_size": "default", "stratify": "default"}': "umis_per_cell_hist-plot_size:800+800-stratify:none.png"
+                }
+            }
+        }
+    """
+    all_plot_maps = {}
+    for process, plots in plot_map.items():
+        all_plot_maps[process] = {}
+
+        for plot_name in plots.keys():
+            all_plot_maps[process][plot_name] = {}
+
+            try:
+                all_plot_kwargs = plots[plot_name]["plot_kwargs"]
+            except (KeyError, TypeError):
+                all_plot_kwargs = _get_default_plot_kwargs(plot_kwargs_defaults)
+
+            kwargs_feed = _get_plot_kwargs_feed(all_plot_kwargs, plot_kwargs_defaults, plot_name)
+            kwargs_feed_mapped = _get_plot_kwargs_feed(
+                plot_kwargs=all_plot_kwargs,
+                plot_kwargs_defaults=plot_kwargs_defaults,
+                plot_name=plot_name,
+                map_to_default_kwargs=True,
+            )
+
+            for i, (plot_kwargs_set, plot_kwargs_set_mapped) in enumerate(zip(kwargs_feed, kwargs_feed_mapped)):
+                try:
+                    plot_filename = plots[plot_name]["filename"]
+                    if isinstance(plot_filename, list):
+                        plot_filename = _get_formatted_plot_filename(plot_filename[i], plot_kwargs_defaults)
+                except (KeyError, TypeError):
+                    plot_filename = _get_default_plot_filename(plot_name, plot_kwargs_set_mapped, plot_kwargs_defaults)
+
+                all_plot_maps[process][plot_name][_get_plot_kwargs_string(plot_kwargs_set)] = plot_filename
+    return all_plot_maps
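The fallback name inference (defined just below) is a plain string transform; for the usual single-underscore sentinels it is equivalent to:

```python
name = "_UMIS_PER_CELL_HIST_"
assert "plot_" + name.strip("_").lower() == "plot_umis_per_cell_hist"
```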
+def _get_plot_method_from_plot_name(plot_name):
+    """Infer plot method name from plot name, e.g. _UMIS_PER_CELL_HIST_ -> plot_umis_per_cell_hist"""
+    if plot_name[0] == "_":
+        plot_name = plot_name[1:]
+    if plot_name[-1] == "_":
+        plot_name = plot_name[:-1]
+    formatted_plot_name = plot_name.lower()
+    plot_method = "plot_" + formatted_plot_name
+
+    return plot_method
+
+
+def _unify_kwargs_opt_lens(plot_kwargs: dict, plot_kwargs_defaults: dict, plot_name: str):
+    """
+    Make all kwarg option counts equal so that we can get aligned, in-order options
+
+    Examples:
+        >>> _unify_kwargs_opt_lens(
+        >>>     {
+        >>>         "stratify": ["sample_id", "none"],
+        >>>         "plot_size": "default"
+        >>>     },
+        >>>     plot_kwargs_defaults, plot_name
+        >>> )
+        # output
+        {
+            "stratify": ["sample_id", "none"],
+            "plot_size": ["default", "default"]
+        }
+    """
+    kwargs_num_options = set()  # number of options for each kwarg
+    for key, values in plot_kwargs.items():
+        if not isinstance(values, list):
+            plot_kwargs[key] = [values]
+        else:
+            kwargs_num_options.add(len(values))
+    kwargs_num_options.discard(1)  # length-1 lists behave like singular args (stretched below)
+    try:
+        max_num_opts = max(kwargs_num_options)
+        kwargs_num_options.remove(max_num_opts)
+    except ValueError:
+        max_num_opts = 1  # no multi-option lists, just singular arguments
+    if len(kwargs_num_options) > 0:  # remaining lengths mean unequal multi-option lists
+        raise ValueError(
+            f"'{plot_name}' contains arguments with unequal numbers of options: each kwarg must provide either "
+            f"a single option or the same number of options as every other multi-option kwarg."
+        )
+
+    # fill in plot_kwargs that are not defined
+    template_plot_kwargs = _get_default_plot_kwargs(plot_kwargs_defaults)
+    for key, value in template_plot_kwargs.items():
+        if key not in plot_kwargs:
+            plot_kwargs[key] = value
+
+    for key, values in plot_kwargs.items():
+        if not isinstance(values, list):
+            plot_kwargs[key] = [values] * max_num_opts  # stretch to the same length
+        elif len(values) == 1:
+            plot_kwargs[key] = values * max_num_opts
+
+    return plot_kwargs
+
+
+def _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults):
+    """Map plot_kwargs to values defined in plot_kwargs_defaults if available"""
+    mapped_plot_kwargs = deepcopy(plot_kwargs)
+
+    for kwarg_name, kwarg_values in plot_kwargs.items():
+        if kwarg_name in plot_kwargs_defaults:
+            mapping = []
+            for val in kwarg_values:
+                try:
+                    mapping.append(plot_kwargs_defaults[kwarg_name].get(val, val))
+                except TypeError:
+                    mapping.append(val)
+
+            mapped_plot_kwargs[kwarg_name] = mapping
+
+    return mapped_plot_kwargs
+
+
+def _get_plot_kwargs_feed(plot_kwargs: dict, plot_kwargs_defaults: dict, plot_name: str, map_to_default_kwargs=False):
+    """
+    Create a feed of kwargs tuples to create multiple plots, replacing values with
+    plot_kwargs_defaults values
+
+    Example:
+        # input plot_kwargs
+        {
+            "stratify": ["sample_id", "default"],
+            "plot_size": ["default", "default"]
+        }
+
+        # output:
+        [
+            {
+                "stratify": "sample_id",
+                "plot_size": [800, 800]
+            },
+            {
+                "stratify": None,
+                "plot_size": [800, 800]
+            }
+        ]
+    """
+    plot_kwargs = _unify_kwargs_opt_lens(plot_kwargs, plot_kwargs_defaults, plot_name)
+    if map_to_default_kwargs:
+        plot_kwargs = _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults)
+    plot_kwargs_feed = [
+        dict(j) for j in zip(*[[(k, i) for i in v] for k, v in plot_kwargs.items()])
+    ]  # 1-1 mapping of plot_kwargs options
+
+    return plot_kwargs_feed
+
+
+def _get_default_plot_kwargs(plot_kwargs_defaults: dict):
+    kwargs_keys = list(plot_kwargs_defaults.keys())
+    for kwargs_key in reversed(kwargs_keys):
+        if "filename" in kwargs_key:  # ignore filename-related args (e.g., plot filename extension)
+            kwargs_keys.remove(kwargs_key)
+
+    default_plot_kwargs = dict(zip(kwargs_keys, ["default"] * len(kwargs_keys)))
+
+    return default_plot_kwargs
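To make the fan-out concrete, a small trace through the module-private feed builder (the defaults dict here is hypothetical): two `stratify` options expand into two aligned kwarg sets.

```python
feed = _get_plot_kwargs_feed(
    {"stratify": ["sample_id", "none"], "plot_size": "default"},
    plot_kwargs_defaults={"plot_size": {"default": [800, 800]}},
    plot_name="_UMIS_PER_CELL_HIST_",
)
assert [d["stratify"] for d in feed] == ["sample_id", "none"]
assert [d["plot_size"] for d in feed] == ["default", "default"]
```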
+def plot_kwargs_to_str(plot_kwargs):
+    """
+    Convert a plot_kwargs dictionary into a deterministic string, sorted by keys
+    """
+    UP = "-"
+    DOWN = ":"
+    AND = "+"
+    str_chain = ""
+
+    def _helper(dict_: dict):
+        nonlocal str_chain
+        for key in sorted(dict_):
+            val = dict_[key]
+            str_chain += str(key)
+            if val is not None:
+                str_chain += DOWN
+            if val is None:
+                pass
+            elif isinstance(val, dict):
+                _helper(val)
+            elif isinstance(val, (set, list, tuple)):
+                str_chain += AND.join(map(str, val))
+            else:
+                str_chain += str(val)
+            str_chain += UP
+
+    _helper(plot_kwargs)
+    str_chain = str_chain.strip(UP).lower()
+
+    return str_chain
+
+
+def _get_plot_kwargs_string(plot_kwargs: dict):  # TODO-QC: proper type checking
+    ord_plot_kwargs = OrderedDict(sorted(plot_kwargs.items()))
+
+    return json.dumps(ord_plot_kwargs)
+
+
+def _get_formatted_plot_filename(plot_name: str, plot_kwargs_defaults: dict):
+    filename_ext = "." + plot_kwargs_defaults.get("filename_ext", "png").lower().replace(".", "")
+
+    plot_filename = plot_name.lower()
+    if "." not in plot_name:  # if plot map doesn't have extension yet
+        plot_filename += filename_ext
+
+    return plot_filename
+
+
+def _get_default_plot_filename(plot_name: str, plot_kwargs: dict, plot_kwargs_defaults: dict):
+    """Infer plot filename from plot name, e.g. _UMIS_PER_CELL_HIST_ -> umis_per_cell_hist.png"""
+    filename_ext = "." + plot_kwargs_defaults.get("filename_ext", "png").lower().replace(".", "")
+
+    if plot_name[0] == "_":
+        plot_name = plot_name[1:]
+    if plot_name[-1] == "_":
+        plot_name = plot_name[:-1]
+    plot_filename = plot_name.lower() + "-" + plot_kwargs_to_str(plot_kwargs) + filename_ext
+
+    return plot_filename
diff --git a/dataforest/utils/tether.py b/dataforest/utils/tether.py
index 39e9916..455e570 100644
--- a/dataforest/utils/tether.py
+++ b/dataforest/utils/tether.py
@@ -12,7 +12,7 @@ def tether(obj, tether_attr, incl_methods=None, excl_methods=None):
     if incl_methods and excl_methods:
-        ValueError("Cannot specify both `incl_methods` and `excl_methods`")
+        raise ValueError("Cannot specify both `incl_methods` and `excl_methods`")
     tether_arg = getattr(obj, tether_attr)
-    method_list = get_methods(obj, incl_methods, excl_methods)
+    method_list = get_methods(obj, excl_methods, incl_methods)
     for method_name in method_list:
         method = copy_func(getattr(obj, method_name))
         tethered_method = make_tethered_method(method, tether_arg)
@@ -20,19 +20,19 @@ def get_methods(
-    obj: Any, excl_methods: Optional[Iterable[str]] = None, incl_methods: Optional[Iterable[str]] = None
+    obj: Any, excl_methods: Optional[Iterable[str]] = None, incl_methods: Optional[List[str]] = None
 ) -> List[str]:
     """Get all user defined methods of an object"""
     all_methods = filter(lambda x: not x.startswith("__"), dir(obj))
     method_list = []
+    if incl_methods:
+        return incl_methods
     for method_name in all_methods:
         if hasattr(obj, method_name):
             if callable(getattr(obj, method_name)):
                 method_list.append(method_name)
     if excl_methods:
         method_list = [m for m in method_list if m not in excl_methods]
-    elif incl_methods:
-        method_list = incl_methods
     return method_list
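The corrected precedence is easiest to see on a toy object: `incl_methods` short-circuits, otherwise all public callables minus `excl_methods` are returned.

```python
class Toy:
    def fit(self): ...
    def plot(self): ...

assert get_methods(Toy(), excl_methods=["plot"]) == ["fit"]
assert get_methods(Toy(), incl_methods=["plot"]) == ["plot"]
```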
diff --git a/docs/pull_request_template.md b/docs/pull_request_template.md
new file mode 100644
index 0000000..b87ceff
--- /dev/null
+++ b/docs/pull_request_template.md
@@ -0,0 +1,17 @@
+### Updates
+-
+
+### Testing
+- [ ] Tested locally
+- [ ] Tested notebook
+- [ ] Cleared notebook outputs
+- [ ] Docstrings and type-hinting
+
+### Fixes
+-
+
+### Future work
+-
+
+### Reviewer's tasks
+-
diff --git a/requirements/requirements.txt b/requirements.txt
similarity index 56%
rename from requirements/requirements.txt
rename to requirements.txt
index 5516096..b8a6d1a 100644
--- a/requirements/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,10 @@
+fastcore
+ipywidgets
+joblib
 matplotlib
 numpy
 pandas
-pathlib
 pyyaml
 termcolor
+typeguard
 Ipython
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 9f4e6e8..9cc8e79 100644
--- a/setup.py
+++ b/setup.py
@@ -11,13 +11,13 @@ def parse_requirements(requirements_path):
         return f.read().splitlines()
 
 
-requirements = parse_requirements("requirements/requirements.txt")
+requirements = parse_requirements("requirements.txt")
 test_requirements = parse_requirements("requirements/requirements-test.txt")
 dev_requirements = parse_requirements("requirements/requirements-dev.txt")
 
 setup(
     name="dataforest",  # Replace with your own username
-    version="0.0.1",
+    version="0.0.2",
     author="Austin McKay",
     author_email="austinmckay303@gmail.com",
     description="An interactive data science workflow manager",
@@ -28,9 +28,10 @@ def parse_requirements(requirements_path):
     # TODO: this works in "all", but not in `install_reqs` -- fix
     install_requires=requirements,
     extras_require={"all": requirements, "test": test_requirements, "dev": dev_requirements,},
+    package_data={"dataforest": ["config/default_config.yaml"]},
     classifiers=[
         "Programming Language :: Python :: 3",
-        "License :: AGPL 3.0 License",
+        "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
         "Operating System :: OS Independent",
     ],
     python_requires=">=3.6",
diff --git a/tests/test_datatree.py b/tests/test_datatree.py
index 689d607..9d7f18f 100644
--- a/tests/test_datatree.py
+++ b/tests/test_datatree.py
@@ -29,7 +29,7 @@ def tree_spec():
             },
             {"min_genes": 5, "max_genes": 5000, "min_cells": 5, "perc_mito_cutoff": 20, "method": "sctransform"},
         ],
-        "_SUBSET_": {"sample": {"_SWEEP_": ["sample_1", "sample_2"]}},
+        "_SUBSET_": {"sample_id": {"_SWEEP_": ["sample_1", "sample_2"]}},
     },
     {
         "_PROCESS_": "reduce",
diff --git a/tests/test_tether.py b/tests/test_tether.py
index 7000dac..822dad9 100644
--- a/tests/test_tether.py
+++ b/tests/test_tether.py
@@ -6,7 +6,7 @@ def __init__(self, val):
         self.val = val
 
     @staticmethod
-    def my_method(val):
+    def my_method(val, _):
         return val + 5
 
diff --git a/tests/test_tree.py b/tests/test_tree.py
index 5013b0d..9250581 100644
--- a/tests/test_tree.py
+++ b/tests/test_tree.py
@@ -17,7 +17,6 @@ def test_branch_1():
     t[t.stack] = "c"
     t.down("c")
     t[t.stack] = "d"
-    print(t.dict)
     assert t.dict == {"a": {"b": "c", "c": "d"}}
     return t
@@ -32,7 +31,6 @@ def test_branch_2():
     t.down("c")
     t[t.stack] = "d"
     t[t.stack] = "e"
-    print(t.dict)
     assert t.dict == {"a": {"b": None, "c": {"d", "e"}}}
     return t
@@ -64,10 +62,8 @@ def test_node_awareness_2(test_branch_2):
 def test_apply_leaves_1(test_branch_1):
     t = test_branch_1
     t_2 = t.apply_leaves(lambda x: 1)
-    print(t_2.dict)
 
 
 def test_apply_leaves_2(test_branch_2):
     t = test_branch_2
     t_2 = t.apply_leaves(lambda x: 1)
-    print(t_2.dict)
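Since `parse_requirements` reads the file with `read().splitlines()`, the missing trailing newline after `Ipython` is harmless; each non-empty line is one requirement:

```python
reqs = "fastcore\nipywidgets\njoblib\nmatplotlib\nnumpy\npandas\npyyaml\ntermcolor\ntypeguard\nIpython".splitlines()
assert len(reqs) == 10 and "pathlib" not in reqs  # stdlib pathlib is no longer pinned
```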