From 9088a025477974688a9b58226c3951c589de964a Mon Sep 17 00:00:00 2001 From: Munchic Date: Fri, 7 Aug 2020 21:03:47 -0300 Subject: [PATCH 01/19] function collector collect only function --- dataforest/core/DataBranch.py | 2 +- dataforest/utils/loaders/collectors.py | 4 +++- requirements/requirements.txt | 1 - 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index d80ec2f..34d844e 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -124,7 +124,7 @@ def __init__( self.spec = self._init_spec(branch_spec) self.verbose = verbose self.logger = logging.getLogger(self.__class__.__name__) - # self.plot = self.PLOT_METHODS(self) + self.plot = self.PLOT_METHODS(self) self.process = self.PROCESS_METHODS(self, self.spec) # self.hyper = HyperparameterMethods(self) self.schema = self.SCHEMA_CLASS() diff --git a/dataforest/utils/loaders/collectors.py b/dataforest/utils/loaders/collectors.py index 4209f2e..29d1494 100644 --- a/dataforest/utils/loaders/collectors.py +++ b/dataforest/utils/loaders/collectors.py @@ -63,7 +63,9 @@ def _function_filter_hook(func): def _function_filter_plot(func): - return func.__name__.startswith("plot_") + if hasattr(func, "__name__"): + return func.__name__.startswith("plot_") + # otherwise is just a variable def _function_filter_process(func): diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 5516096..5bc5c80 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,7 +1,6 @@ matplotlib numpy pandas -pathlib pyyaml termcolor Ipython \ No newline at end of file From 5b4f494b7e8088e7d1ab690cc001ce68e9070d69 Mon Sep 17 00:00:00 2001 From: Munchic Date: Mon, 10 Aug 2020 23:21:51 -0300 Subject: [PATCH 02/19] plot only requested plots + check for bad inputs --- dataforest/hooks/hooks/core/hooks.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git 
a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index c8ccdc0..e7af8fc 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -123,6 +123,15 @@ def hook_catalogue(dp): @hook def hook_generate_plots(dp: dataprocess): - plot_methods = dp.branch.plot.plot_method_lookup - for method in plot_methods.values(): - method(dp.branch) + plot_sources = dp.branch.plot.plot_method_lookup + + current_process = dp.branch.current_process + requested_plot_methods = dp.branch.plot.plot_methods[current_process] + + for method in plot_sources.values(): + if method.__name__ in requested_plot_methods: + method(dp.branch) + requested_plot_methods.remove(method.__name__) + + if len(requested_plot_methods) > 0: # if not all requested mapped to functions in plot sources + logging.warning(f"Requested plotting methods {requested_plot_methods} are invalid so they were skipped.") From 08385eed361899dabd0394be0ed821ff8eb185d9 Mon Sep 17 00:00:00 2001 From: Munchic Date: Mon, 10 Aug 2020 23:37:29 -0300 Subject: [PATCH 03/19] dumb mistake that modifies plot_methods list --- dataforest/hooks/hooks/core/hooks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index e7af8fc..2fd1dc5 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -1,6 +1,7 @@ import gc import logging from pathlib import Path +from copy import deepcopy import pandas as pd import yaml @@ -126,7 +127,7 @@ def hook_generate_plots(dp: dataprocess): plot_sources = dp.branch.plot.plot_method_lookup current_process = dp.branch.current_process - requested_plot_methods = dp.branch.plot.plot_methods[current_process] + requested_plot_methods = deepcopy(dp.branch.plot.plot_methods[current_process]) for method in plot_sources.values(): if method.__name__ in requested_plot_methods: From 6aa1f4cb936c44f9f88dbb64696c87b96a267fcd Mon Sep 17 
00:00:00 2001 From: Munchic Date: Tue, 11 Aug 2020 16:43:26 -0300 Subject: [PATCH 04/19] create new folder with from plot name --- dataforest/core/PlotMethods.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py index 09781a2..d423feb 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -35,7 +35,13 @@ def _wrap(self, method): def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): try: process_run = branch[branch.current_process] - plot_dir = process_run.plot_map[method_name].parent + plot_name = method_name.replace("plot_", "", 1) + + if plot_name in process_run.plot_map: + plot_dir = process_run.plot_map[plot_name].parent + elif method_name in process_run.plot_map: + plot_dir = process_run.plot_map[method_name].parent + plot_dir.mkdir(exist_ok=True) return method(branch, *args, **kwargs) except Exception as e: From 79c2a3d1b8c2de22d6ab5d325185fb473164b0d2 Mon Sep 17 00:00:00 2001 From: Munchic Date: Thu, 13 Aug 2020 20:40:32 -0300 Subject: [PATCH 05/19] pass plot kwargs from config to plots --- dataforest/config/MetaPlotMethods.py | 8 +++ dataforest/core/PlotMethods.py | 8 +++ dataforest/hooks/hooks/core/functions.py | 79 ++++++++++++++++++++++++ dataforest/hooks/hooks/core/hooks.py | 15 +++-- 4 files changed, 105 insertions(+), 5 deletions(-) create mode 100644 dataforest/hooks/hooks/core/functions.py diff --git a/dataforest/config/MetaPlotMethods.py b/dataforest/config/MetaPlotMethods.py index 428ccca..1c81c78 100644 --- a/dataforest/config/MetaPlotMethods.py +++ b/dataforest/config/MetaPlotMethods.py @@ -10,3 +10,11 @@ def PLOT_METHOD_LOOKUP(cls): @property def PLOT_METHODS(cls): return cls.CONFIG["plot_methods"] + + @property + def PLOT_KWARGS_DEFAULTS(cls): + return cls.CONFIG["plot_kwargs_defaults"] + + @property + def PLOT_KWARGS(cls): + return cls.CONFIG["plot_kwargs"] diff --git a/dataforest/core/PlotMethods.py 
b/dataforest/core/PlotMethods.py index d423feb..85f6d24 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -28,6 +28,14 @@ def plot_method_lookup(self): def plot_methods(self): return self.__class__.PLOT_METHODS + @property + def plot_kwargs_defaults(self): + return self.__class__.PLOT_KWARGS_DEFAULTS + + @property + def plot_kwargs(self): + return self.__class__.PLOT_KWARGS + def _wrap(self, method): """Wrap with mkdirs and logging""" diff --git a/dataforest/hooks/hooks/core/functions.py b/dataforest/hooks/hooks/core/functions.py new file mode 100644 index 0000000..e03e7d2 --- /dev/null +++ b/dataforest/hooks/hooks/core/functions.py @@ -0,0 +1,79 @@ +from copy import deepcopy + +from dataforest.hooks.dataprocess import dataprocess + + +def _unify_kwargs_opt_lens(plot_kwargs, plot_name): + """ + Make all kwarg option counts equal so that we can get aligned in order options + + Examples: + >>> _unify_kwargs_opt_lens( + >>> { + >>> "stratify": ["sample", "none"], + >>> "plot_size": "default" + >>> } + >>> ) + # output + { + "stratify": ["sample", "none"], + "plot_size": ["default", "default"] + } + """ + if "plot_size" not in plot_kwargs: + plot_kwargs["plot_size"] = "default" + + kwargs_num_options = set() # number of options for each kwarg + for key, values in plot_kwargs.items(): + if type(values) != list: + plot_kwargs[key] = [values] + else: + kwargs_num_options.add(len(values)) + try: + max_num_opts = max(kwargs_num_options) + kwargs_num_options.remove(max_num_opts) + except ValueError: + max_num_opts = 1 # means that there are no lists and just singular arguments + if len(kwargs_num_options) > 0: # check if the lists of options are equal to each other + raise ValueError( + f"plot_kwargs['{plot_name}'] contains arguments with unequal number of options, should include the same number of options where there are multiple options or a single option." 
+ ) + + for key, values in plot_kwargs.items(): + if type(values) != list: + plot_kwargs[key] = [values] * max_num_opts # stretch to the same length + elif len(values) == 1: + plot_kwargs[key] = values * max_num_opts + + return plot_kwargs + + +def _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults): + """Map kwargs options to values defined in kwargs defaults""" + for key, values in plot_kwargs.items(): + if key in plot_kwargs_defaults: + mapped_values = [] + for val in values: + opts_mapping = plot_kwargs_defaults.get(key) + if val in opts_mapping: + mapped_values.append(opts_mapping[val]) + else: + raise KeyError(f"Option '{val}' is not defined for '{key}', check `plot_kwargs_defaults`") + + plot_kwargs[key] = mapped_values + + return plot_kwargs + + +def _get_all_plot_kwargs(dp: dataprocess, plot_name): + """Creates a list of dictionaries with singular value for each kwarg from kwarg value lists""" + plot_kwargs_defaults = dp.branch.plot.plot_kwargs_defaults + all_plot_kwargs = deepcopy(dp.branch.plot.plot_kwargs) # make singular elements a list + plot_kwargs = all_plot_kwargs.get(plot_name, {"plot_size": "default", "stratify": "none"}) + + plot_kwargs = _unify_kwargs_opt_lens(plot_kwargs, plot_name) + plot_kwargs = _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults) + + return [ + dict(j) for j in zip(*[[(k, i) for i in v] for k, v in plot_kwargs.items()]) + ] # 1-1 mapping of kwargs options diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index 2fd1dc5..bf76b89 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -10,6 +10,7 @@ from dataforest.utils.catalogue import run_id_from_multi_row from dataforest.utils.exceptions import InputDataNotFound from dataforest.hooks.hook import hook +from dataforest.hooks.hooks.core.functions import _get_all_plot_kwargs @hook @@ -125,14 +126,18 @@ def hook_catalogue(dp): @hook def hook_generate_plots(dp: dataprocess): 
plot_sources = dp.branch.plot.plot_method_lookup - current_process = dp.branch.current_process requested_plot_methods = deepcopy(dp.branch.plot.plot_methods[current_process]) for method in plot_sources.values(): - if method.__name__ in requested_plot_methods: - method(dp.branch) - requested_plot_methods.remove(method.__name__) + plot_name = method.__name__ + if plot_name in requested_plot_methods: + all_kwargs = _get_all_plot_kwargs(dp, plot_name) + for kwargs in all_kwargs: + method(dp.branch, **kwargs) + requested_plot_methods.remove(plot_name) if len(requested_plot_methods) > 0: # if not all requested mapped to functions in plot sources - logging.warning(f"Requested plotting methods {requested_plot_methods} are invalid so they were skipped.") + logging.warning( + f"Requested plotting methods {requested_plot_methods} are not implemented so they were skipped." + ) From 583610ff597385d225f06647934cbbc32b61650f Mon Sep 17 00:00:00 2001 From: Munchic Date: Mon, 17 Aug 2020 18:02:46 -0300 Subject: [PATCH 06/19] new plot map is here yay --- dataforest/config/MetaPlotMethods.py | 14 +- dataforest/config/MetaProcessSchema.py | 5 +- dataforest/core/DataBranch.py | 14 +- dataforest/core/PlotMethods.py | 9 +- dataforest/core/ProcessRun.py | 8 +- dataforest/hooks/hooks/core/functions.py | 79 --------- dataforest/hooks/hooks/core/hooks.py | 7 +- dataforest/utils/plots_config.py | 205 +++++++++++++++++++++++ 8 files changed, 244 insertions(+), 97 deletions(-) delete mode 100644 dataforest/hooks/hooks/core/functions.py create mode 100644 dataforest/utils/plots_config.py diff --git a/dataforest/config/MetaPlotMethods.py b/dataforest/config/MetaPlotMethods.py index 1c81c78..31df652 100644 --- a/dataforest/config/MetaPlotMethods.py +++ b/dataforest/config/MetaPlotMethods.py @@ -1,5 +1,6 @@ from dataforest.config.MetaConfig import MetaConfig from dataforest.utils.loaders.collectors import collect_plots +from dataforest.utils.plots_config import parse_plot_methods, parse_plot_kwargs 
class MetaPlotMethods(MetaConfig): @@ -9,12 +10,19 @@ def PLOT_METHOD_LOOKUP(cls): @property def PLOT_METHODS(cls): - return cls.CONFIG["plot_methods"] + try: + plot_methods = cls.CONFIG["plot_methods"] + except KeyError: + plot_methods = parse_plot_methods(config=cls.CONFIG) + + return plot_methods @property def PLOT_KWARGS_DEFAULTS(cls): return cls.CONFIG["plot_kwargs_defaults"] @property - def PLOT_KWARGS(cls): - return cls.CONFIG["plot_kwargs"] + def PLOT_KWARGS(cls): # TODO-QC: mapping of process, plot to kwargs + plot_kwargs = parse_plot_kwargs(config=cls.CONFIG) + + return plot_kwargs diff --git a/dataforest/config/MetaProcessSchema.py b/dataforest/config/MetaProcessSchema.py index 92d3a87..594811d 100644 --- a/dataforest/config/MetaProcessSchema.py +++ b/dataforest/config/MetaProcessSchema.py @@ -4,6 +4,7 @@ from pathlib import Path from dataforest.config.MetaConfig import MetaConfig +from dataforest.utils.plots_config import parse_plot_map class MetaProcessSchema(MetaConfig): @@ -19,8 +20,8 @@ def FILE_MAP(cls): return cls["file_map"] @property - def PLOT_MAP(cls): - return cls["plot_map"] + def PLOT_MAP(cls): # TODO-QC: process plot map starting here? Make it into a class where you can fetch kwargs? + return parse_plot_map(cls.CONFIG) @property def LAYERS(cls): diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index 34d844e..3fc936a 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -272,12 +272,16 @@ def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): f"plots already present for `root` at {self['root'].plots_path}. 
To regenerate plots, delete directory" ) return - plot_kwargs = plot_kwargs if plot_kwargs else dict() + + if plot_kwargs == None: + plot_kwargs = self.plot.plot_kwargs["root"] + # plot_kwargs = plot_kwargs if plot_kwargs else dict() root_plot_methods = self.plot.plot_methods.get("root", []) - for name in root_plot_methods: - kwargs = plot_kwargs.get(name, dict()) - method = getattr(self.plot, name) - method(**kwargs) + for plot_name, plot_method in root_plot_methods.items(): + kwargs_sets = plot_kwargs.get(plot_name, dict()) + for kwargs in kwargs_sets.values(): + method = getattr(self.plot, plot_method) + method(**kwargs) def is_process_plots_exist(self, process_name: str) -> bool: return self[process_name].plots_path.exists() diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py index 85f6d24..eb017fe 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -43,12 +43,13 @@ def _wrap(self, method): def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): try: process_run = branch[branch.current_process] - plot_name = method_name.replace("plot_", "", 1) + for key, value in branch.plot.plot_methods[branch.current_process].items(): + if value == method_name: + plot_name = key # look up plot name from plot_method name if plot_name in process_run.plot_map: - plot_dir = process_run.plot_map[plot_name].parent - elif method_name in process_run.plot_map: - plot_dir = process_run.plot_map[method_name].parent + for plot_kwargs_key, plot_filename in process_run.plot_map[plot_name].items(): + plot_dir = plot_filename.parent # only need one sample dir plot_dir.mkdir(exist_ok=True) return method(branch, *args, **kwargs) diff --git a/dataforest/core/ProcessRun.py b/dataforest/core/ProcessRun.py index 4c31e63..dc99ace 100644 --- a/dataforest/core/ProcessRun.py +++ b/dataforest/core/ProcessRun.py @@ -85,7 +85,13 @@ def process_path_map(self) -> Dict[str, Path]: @property def process_plot_map(self) -> 
Dict[str, Path]: - return {file_alias: self.plots_path / self._plot_lookup[file_alias] for file_alias in self._plot_lookup} + plot_map_dict = {} + for plot_name in self._plot_lookup: + plot_map_dict[plot_name] = {} + for plot_kwargs_key, plot_filepath in self._plot_lookup[plot_name].items(): + plot_map_dict[plot_name][plot_kwargs_key] = self.plots_path / plot_filepath + + return plot_map_dict @property def path_map(self) -> Dict[str, Path]: diff --git a/dataforest/hooks/hooks/core/functions.py b/dataforest/hooks/hooks/core/functions.py deleted file mode 100644 index e03e7d2..0000000 --- a/dataforest/hooks/hooks/core/functions.py +++ /dev/null @@ -1,79 +0,0 @@ -from copy import deepcopy - -from dataforest.hooks.dataprocess import dataprocess - - -def _unify_kwargs_opt_lens(plot_kwargs, plot_name): - """ - Make all kwarg option counts equal so that we can get aligned in order options - - Examples: - >>> _unify_kwargs_opt_lens( - >>> { - >>> "stratify": ["sample", "none"], - >>> "plot_size": "default" - >>> } - >>> ) - # output - { - "stratify": ["sample", "none"], - "plot_size": ["default", "default"] - } - """ - if "plot_size" not in plot_kwargs: - plot_kwargs["plot_size"] = "default" - - kwargs_num_options = set() # number of options for each kwarg - for key, values in plot_kwargs.items(): - if type(values) != list: - plot_kwargs[key] = [values] - else: - kwargs_num_options.add(len(values)) - try: - max_num_opts = max(kwargs_num_options) - kwargs_num_options.remove(max_num_opts) - except ValueError: - max_num_opts = 1 # means that there are no lists and just singular arguments - if len(kwargs_num_options) > 0: # check if the lists of options are equal to each other - raise ValueError( - f"plot_kwargs['{plot_name}'] contains arguments with unequal number of options, should include the same number of options where there are multiple options or a single option." 
- ) - - for key, values in plot_kwargs.items(): - if type(values) != list: - plot_kwargs[key] = [values] * max_num_opts # stretch to the same length - elif len(values) == 1: - plot_kwargs[key] = values * max_num_opts - - return plot_kwargs - - -def _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults): - """Map kwargs options to values defined in kwargs defaults""" - for key, values in plot_kwargs.items(): - if key in plot_kwargs_defaults: - mapped_values = [] - for val in values: - opts_mapping = plot_kwargs_defaults.get(key) - if val in opts_mapping: - mapped_values.append(opts_mapping[val]) - else: - raise KeyError(f"Option '{val}' is not defined for '{key}', check `plot_kwargs_defaults`") - - plot_kwargs[key] = mapped_values - - return plot_kwargs - - -def _get_all_plot_kwargs(dp: dataprocess, plot_name): - """Creates a list of dictionaries with singular value for each kwarg from kwarg value lists""" - plot_kwargs_defaults = dp.branch.plot.plot_kwargs_defaults - all_plot_kwargs = deepcopy(dp.branch.plot.plot_kwargs) # make singular elements a list - plot_kwargs = all_plot_kwargs.get(plot_name, {"plot_size": "default", "stratify": "none"}) - - plot_kwargs = _unify_kwargs_opt_lens(plot_kwargs, plot_name) - plot_kwargs = _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults) - - return [ - dict(j) for j in zip(*[[(k, i) for i in v] for k, v in plot_kwargs.items()]) - ] # 1-1 mapping of kwargs options diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index bf76b89..20d81de 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -10,7 +10,6 @@ from dataforest.utils.catalogue import run_id_from_multi_row from dataforest.utils.exceptions import InputDataNotFound from dataforest.hooks.hook import hook -from dataforest.hooks.hooks.core.functions import _get_all_plot_kwargs @hook @@ -123,17 +122,19 @@ def hook_catalogue(dp): raise ValueError(f"run_id: {run_id} is not equal to 
stored: {run_id_stored} for {str(run_spec)}") +# TODO-QC: take a check here @hook def hook_generate_plots(dp: dataprocess): plot_sources = dp.branch.plot.plot_method_lookup current_process = dp.branch.current_process + all_plot_kwargs_sets = dp.branch.plot.plot_kwargs[current_process] requested_plot_methods = deepcopy(dp.branch.plot.plot_methods[current_process]) for method in plot_sources.values(): plot_name = method.__name__ if plot_name in requested_plot_methods: - all_kwargs = _get_all_plot_kwargs(dp, plot_name) - for kwargs in all_kwargs: + plot_kwargs_sets = all_plot_kwargs_sets[plot_name] + for kwargs in plot_kwargs_sets.values(): method(dp.branch, **kwargs) requested_plot_methods.remove(plot_name) diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py new file mode 100644 index 0000000..8168190 --- /dev/null +++ b/dataforest/utils/plots_config.py @@ -0,0 +1,205 @@ +from copy import deepcopy +import json +from collections import OrderedDict +from pathlib import Path + + +def parse_plot_methods(config: dict): + """Parse plot methods per process from plot_map""" + plot_map = config["plot_map"] + plot_methods = {} + for process, plots in plot_map.items(): + plot_methods[process] = {} + for plot_name in plots.keys(): + try: + plot_method = plots[plot_name]["plot_method"] + except (TypeError, KeyError): + plot_method = _get_plot_method_from_plot_name(plot_name) + + plot_methods[process][plot_name] = plot_method + + return plot_methods + + +def parse_plot_kwargs(config: dict): + """Parse plot methods kwargs per process from plot_map""" + plot_map = config["plot_map"] + all_plot_kwargs = {} + for process, plots in plot_map.items(): + all_plot_kwargs[process] = {} + + for plot_name in plots.keys(): + all_plot_kwargs[process][plot_name] = {} + + try: + plot_kwargs = plots[plot_name]["plot_kwargs"] + except (KeyError, TypeError): + plot_kwargs = _get_default_plot_kwargs(config) + + kwargs_feed = _get_plot_kwargs_feed(plot_kwargs, plot_name) + 
kwargs_feed_mapped = _get_plot_kwargs_feed( + plot_kwargs, plot_name, map_to_default_kwargs=True, plot_kwargs_defaults=config["plot_kwargs_defaults"] + ) + + for plot_kwargs_set, plot_kwargs_set_mapped in zip(kwargs_feed, kwargs_feed_mapped): + all_plot_kwargs[process][plot_name][_get_plot_kwargs_string(plot_kwargs_set)] = plot_kwargs_set_mapped + + return all_plot_kwargs + + +def parse_plot_map(config: dict): + """ + Parse plot file map per process from plot_map and ensures that + implicit definition returns a dictionary of default values for all kwargs + """ + plot_map = config["plot_map"] + all_plot_maps = {} + for process, plots in plot_map.items(): + all_plot_maps[process] = {} + + for plot_name in plots.keys(): + all_plot_maps[process][plot_name] = {} + + try: + all_plot_kwargs = plots[plot_name]["plot_kwargs"] + except (KeyError, TypeError): + all_plot_kwargs = _get_default_plot_kwargs(config) + + kwargs_feed = _get_plot_kwargs_feed(all_plot_kwargs, plot_name) + kwargs_feed_mapped = _get_plot_kwargs_feed( + all_plot_kwargs, + plot_name, + map_to_default_kwargs=True, + plot_kwargs_defaults=config["plot_kwargs_defaults"], + ) + + for i, (plot_kwargs_set, plot_kwargs_set_mapped) in enumerate(zip(kwargs_feed, kwargs_feed_mapped)): + try: + plot_filename = plots[plot_name]["filename"] + if type(plot_filename) == list: + plot_filename = Path(plot_filename[i]) + except (KeyError, TypeError): + plot_filename = _get_default_plot_filename( + plot_name, plot_kwargs_set_mapped, config["plot_kwargs_defaults"] + ) + + all_plot_maps[process][plot_name][_get_plot_kwargs_string(plot_kwargs_set)] = plot_filename + + return all_plot_maps + + +def _get_plot_method_from_plot_name(plot_name): + """Infer plot method name from plot name, e.g. 
_UMIS_PER_CELL_HIST_ -> plot_umis_per_cell_hist""" + if plot_name[0] == "_": + plot_name = plot_name[1:] + if plot_name[-1] == "_": + plot_name = plot_name[:-1] + formatted_plot_name = plot_name.lower() + plot_method = "plot_" + formatted_plot_name + + return plot_method + + +def _unify_kwargs_opt_lens(plot_kwargs, plot_name): + """ + Make all kwarg option counts equal so that we can get aligned in order options + + Examples: + >>> _unify_kwargs_opt_lens( + >>> { + >>> "stratify": ["sample", "none"], + >>> "plot_size": "default" + >>> } + >>> ) + # output + { + "stratify": ["sample", "none"], + "plot_size": ["default", "default"] + } + """ + kwargs_num_options = set() # number of options for each kwarg + for key, values in plot_kwargs.items(): + if type(values) != list: + plot_kwargs[key] = [values] + else: + kwargs_num_options.add(len(values)) + try: + max_num_opts = max(kwargs_num_options) + kwargs_num_options.remove(max_num_opts) + except ValueError: + max_num_opts = 1 # means that there are no lists and just singular arguments + if len(kwargs_num_options) > 0: # check if the lists of options are equal to each other + raise ValueError( + f"plot_kwargs['{plot_name}'] contains arguments with unequal number of options, should include the same number of options where there are multiple options or a single option." 
+ ) + + for key, values in plot_kwargs.items(): + if type(values) != list: + plot_kwargs[key] = [values] * max_num_opts # stretch to the same length + elif len(values) == 1: + plot_kwargs[key] = values * max_num_opts + + return plot_kwargs + + +def _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults): + """Map plot_kwargs to values defined in kwargs defaults if available""" + mapped_plot_kwargs = deepcopy(plot_kwargs) + + for kwarg_name, kwarg_values in plot_kwargs.items(): + if kwarg_name in plot_kwargs_defaults: + mapping = [] + for val in kwarg_values: + try: + mapping.append(plot_kwargs_defaults[kwarg_name].get(val, val)) + except TypeError: + mapping.append(val) + + mapped_plot_kwargs[kwarg_name] = mapping + + return mapped_plot_kwargs + + +def _get_plot_kwargs_feed(plot_kwargs: dict, plot_name: str, map_to_default_kwargs=False, plot_kwargs_defaults=None): + plot_kwargs = _unify_kwargs_opt_lens(plot_kwargs, plot_name) + if map_to_default_kwargs: + plot_kwargs = _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults) + plot_kwargs_feed = [ + dict(j) for j in zip(*[[(k, i) for i in v] for k, v in plot_kwargs.items()]) + ] # 1-1 mapping of kwargs options + + return plot_kwargs_feed + + +def _get_default_plot_kwargs(config: dict): + kwargs_keys = config["plot_kwargs_defaults"].keys() + default_plot_kwargs = dict(zip(kwargs_keys, ["default"] * len(kwargs_keys))) + + return default_plot_kwargs + + +def _get_filename_from_plot_kwargs(plot_filename, plot_kwargs): + suffix_chain = "" + for key, value in sorted(list(plot_kwargs.items())): + suffix_chain += f"-{key}:{value}".replace(" ", "").lower() # remove spaces + + return plot_filename + suffix_chain + + +def _get_plot_kwargs_string(plot_kwargs: dict): # TODO-QC: proper type checking + ord_plot_kwargs = OrderedDict(sorted(plot_kwargs.items())) + + return json.dumps(ord_plot_kwargs) + + +def _get_default_plot_filename(plot_name: str, plot_kwargs: dict, plot_kwargs_defaults: dict): + """Infer plot 
filename from plot name, e.g. _UMIS_PER_CELL_HIST_ -> umis_per_cell_hist.png""" + filename_ext = "." + plot_kwargs_defaults.get("filename_ext", "png").lower().replace(".", "") + + if plot_name[0] == "_": + plot_name = plot_name[1:] + if plot_name[-1] == "_": + plot_name = plot_name[:-1] + plot_filename = _get_filename_from_plot_kwargs(plot_name.lower(), plot_kwargs) + filename_ext + + return Path(plot_filename) From e28b7af86b91bedc09bb76bf5e0335d3f63aa915 Mon Sep 17 00:00:00 2001 From: Munchic Date: Mon, 17 Aug 2020 21:10:22 -0300 Subject: [PATCH 07/19] root and normalize plotting work now yay! --- dataforest/core/DataBranch.py | 7 +++++-- dataforest/core/PlotMethods.py | 7 ++++--- dataforest/hooks/hooks/core/hooks.py | 17 ++++++++++++----- dataforest/utils/plots_config.py | 9 +++++++++ 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index 3fc936a..f7a399a 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -275,12 +275,15 @@ def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): if plot_kwargs == None: plot_kwargs = self.plot.plot_kwargs["root"] - # plot_kwargs = plot_kwargs if plot_kwargs else dict() + root_plot_map = self["root"].plot_map root_plot_methods = self.plot.plot_methods.get("root", []) + for plot_name, plot_method in root_plot_methods.items(): kwargs_sets = plot_kwargs.get(plot_name, dict()) - for kwargs in kwargs_sets.values(): + for plot_kwargs_key, _kwargs in kwargs_sets.items(): method = getattr(self.plot, plot_method) + kwargs = deepcopy(_kwargs) + kwargs["plot_path"] = root_plot_map[plot_name][plot_kwargs_key] method(**kwargs) def is_process_plots_exist(self, process_name: str) -> bool: diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py index eb017fe..1296591 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -3,6 +3,7 @@ from dataforest.config.MetaPlotMethods 
import MetaPlotMethods from dataforest.utils import tether, copy_func from dataforest.utils.ExceptionHandler import ExceptionHandler +from dataforest.utils.plots_config import get_plot_name_from_plot_method class PlotMethods(metaclass=MetaPlotMethods): @@ -43,9 +44,9 @@ def _wrap(self, method): def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): try: process_run = branch[branch.current_process] - for key, value in branch.plot.plot_methods[branch.current_process].items(): - if value == method_name: - plot_name = key # look up plot name from plot_method name + plot_name = get_plot_name_from_plot_method( + branch.plot.plot_methods[branch.current_process], method_name + ) if plot_name in process_run.plot_map: for plot_kwargs_key, plot_filename in process_run.plot_map[plot_name].items(): diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index 20d81de..f671808 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -9,6 +9,7 @@ from dataforest.hooks.dataprocess import dataprocess from dataforest.utils.catalogue import run_id_from_multi_row from dataforest.utils.exceptions import InputDataNotFound +from dataforest.utils.plots_config import get_plot_name_from_plot_method from dataforest.hooks.hook import hook @@ -128,15 +129,21 @@ def hook_generate_plots(dp: dataprocess): plot_sources = dp.branch.plot.plot_method_lookup current_process = dp.branch.current_process all_plot_kwargs_sets = dp.branch.plot.plot_kwargs[current_process] - requested_plot_methods = deepcopy(dp.branch.plot.plot_methods[current_process]) + process_plot_methods = dp.branch.plot.plot_methods[current_process] + process_plot_map = dp.branch[dp.branch.current_process].plot_map + requested_plot_methods = deepcopy(process_plot_methods) for method in plot_sources.values(): - plot_name = method.__name__ - if plot_name in requested_plot_methods: + plot_method_name = method.__name__ + if plot_method_name in 
requested_plot_methods.values(): + plot_name = get_plot_name_from_plot_method(process_plot_methods, plot_method_name) plot_kwargs_sets = all_plot_kwargs_sets[plot_name] - for kwargs in plot_kwargs_sets.values(): + for plot_kwargs_key in plot_kwargs_sets.keys(): + plot_path = process_plot_map[plot_name][plot_kwargs_key] + kwargs = deepcopy(plot_kwargs_sets[plot_kwargs_key]) + kwargs["plot_path"] = plot_path method(dp.branch, **kwargs) - requested_plot_methods.remove(plot_name) + requested_plot_methods.pop(plot_name) if len(requested_plot_methods) > 0: # if not all requested mapped to functions in plot sources logging.warning( diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py index 8168190..a3ffdc7 100644 --- a/dataforest/utils/plots_config.py +++ b/dataforest/utils/plots_config.py @@ -88,6 +88,15 @@ def parse_plot_map(config: dict): return all_plot_maps +def get_plot_name_from_plot_method(process_plot_methods, plot_method_name): + """Reverse search for plot name in the config from plot method used""" + for key, value in process_plot_methods.items(): + if value == plot_method_name: + plot_name = key # look up plot name from plot_method name + + return plot_name + + def _get_plot_method_from_plot_name(plot_name): """Infer plot method name from plot name, e.g. 
_UMIS_PER_CELL_HIST_ -> plot_umis_per_cell_hist""" if plot_name[0] == "_": From 8cb55e70058464ca6c4d88d3a40f140b6dabe863 Mon Sep 17 00:00:00 2001 From: Munchic Date: Mon, 17 Aug 2020 21:53:37 -0300 Subject: [PATCH 08/19] new plot filepath --- dataforest/core/BranchSpec.py | 5 +++++ dataforest/utils/plots_config.py | 38 ++++++++++++++++++++++++++------ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/dataforest/core/BranchSpec.py b/dataforest/core/BranchSpec.py index 4bdf1e7..f5d4c31 100644 --- a/dataforest/core/BranchSpec.py +++ b/dataforest/core/BranchSpec.py @@ -79,6 +79,11 @@ def __init__(self, spec: Union[List[dict], "BranchSpec[RunSpec]"]): ) self.process_order: List[str] = [spec_item.name for spec_item in self] + @property + def shell_str(self): + """string version which can be passed via shell and loaded via json""" + return f"'{str(self)}'" + def copy(self) -> "BranchSpec": return deepcopy(self) diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py index a3ffdc7..736c50e 100644 --- a/dataforest/utils/plots_config.py +++ b/dataforest/utils/plots_config.py @@ -187,12 +187,36 @@ def _get_default_plot_kwargs(config: dict): return default_plot_kwargs -def _get_filename_from_plot_kwargs(plot_filename, plot_kwargs): - suffix_chain = "" - for key, value in sorted(list(plot_kwargs.items())): - suffix_chain += f"-{key}:{value}".replace(" ", "").lower() # remove spaces - - return plot_filename + suffix_chain +def plot_kwargs_to_str(plot_kwargs): + """ + Converts plot_kwargs dictionary into a deterministic string, sorted by keys + """ + UP = "-" + DOWN = ":" + AND = "+" + str_chain = "" + + def _helper(dict_: dict): + nonlocal str_chain + for key in sorted(dict_): + val = dict_[key] + str_chain += str(key) + if val is not None: + str_chain += DOWN + if val is None: + pass + elif isinstance(val, dict): + _helper(val) + elif isinstance(val, (set, list, tuple)): + str_chain += AND.join(map(str, val)) + else: + str_chain += str(val) 
+ str_chain += UP + + _helper(plot_kwargs) + str_chain = str_chain.strip("-").lower() + + return str_chain def _get_plot_kwargs_string(plot_kwargs: dict): # TODO-QC: proper type checking @@ -209,6 +233,6 @@ def _get_default_plot_filename(plot_name: str, plot_kwargs: dict, plot_kwargs_de plot_name = plot_name[1:] if plot_name[-1] == "_": plot_name = plot_name[:-1] - plot_filename = _get_filename_from_plot_kwargs(plot_name.lower(), plot_kwargs) + filename_ext + plot_filename = plot_name.lower() + "-" + plot_kwargs_to_str(plot_kwargs) + filename_ext return Path(plot_filename) From 298ee24c958834887b531dfac862597ce4ddec19 Mon Sep 17 00:00:00 2001 From: Munchic Date: Wed, 19 Aug 2020 12:01:19 -0300 Subject: [PATCH 09/19] ignore precursors for plot_map --- dataforest/core/ProcessRun.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataforest/core/ProcessRun.py b/dataforest/core/ProcessRun.py index dc99ace..d1a2006 100644 --- a/dataforest/core/ProcessRun.py +++ b/dataforest/core/ProcessRun.py @@ -220,7 +220,7 @@ def _build_path_map(self, incl_current: bool = False, plot_map: bool = False) -> spec = self.branch.spec precursor_lookup = spec.get_precursors_lookup(incl_current=incl_current, incl_root=True) precursors = precursor_lookup[self.process_name] - process_runs = [self.branch[process_name] for process_name in precursors] + process_runs = [self] if plot_map else [self.branch[process_name] for process_name in precursors] pr_attr = "process_plot_map" if plot_map else "process_path_map" process_path_map_list = [getattr(pr, pr_attr) for pr in process_runs] path_map = dict() From 53ec5ac5903c1f09a447513632b81f3d38aaa7d3 Mon Sep 17 00:00:00 2001 From: Munchic Date: Wed, 19 Aug 2020 14:29:29 -0300 Subject: [PATCH 10/19] use plot_kwargs_defaults as template for default --- dataforest/utils/plots_config.py | 62 +++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/dataforest/utils/plots_config.py 
b/dataforest/utils/plots_config.py index 736c50e..25b0a29 100644 --- a/dataforest/utils/plots_config.py +++ b/dataforest/utils/plots_config.py @@ -1,7 +1,6 @@ from copy import deepcopy import json from collections import OrderedDict -from pathlib import Path def parse_plot_methods(config: dict): @@ -24,6 +23,7 @@ def parse_plot_methods(config: dict): def parse_plot_kwargs(config: dict): """Parse plot methods kwargs per process from plot_map""" plot_map = config["plot_map"] + plot_kwargs_defaults = config["plot_kwargs_defaults"] all_plot_kwargs = {} for process, plots in plot_map.items(): all_plot_kwargs[process] = {} @@ -34,11 +34,11 @@ def parse_plot_kwargs(config: dict): try: plot_kwargs = plots[plot_name]["plot_kwargs"] except (KeyError, TypeError): - plot_kwargs = _get_default_plot_kwargs(config) + plot_kwargs = _get_default_plot_kwargs(plot_kwargs_defaults) - kwargs_feed = _get_plot_kwargs_feed(plot_kwargs, plot_name) + kwargs_feed = _get_plot_kwargs_feed(plot_kwargs, plot_kwargs_defaults, plot_name) kwargs_feed_mapped = _get_plot_kwargs_feed( - plot_kwargs, plot_name, map_to_default_kwargs=True, plot_kwargs_defaults=config["plot_kwargs_defaults"] + plot_kwargs, plot_kwargs_defaults, plot_name, map_to_default_kwargs=True, ) for plot_kwargs_set, plot_kwargs_set_mapped in zip(kwargs_feed, kwargs_feed_mapped): @@ -53,6 +53,7 @@ def parse_plot_map(config: dict): implicit definition returns a dictionary of default values for all kwargs """ plot_map = config["plot_map"] + plot_kwargs_defaults = config["plot_kwargs_defaults"] all_plot_maps = {} for process, plots in plot_map.items(): all_plot_maps[process] = {} @@ -63,25 +64,23 @@ def parse_plot_map(config: dict): try: all_plot_kwargs = plots[plot_name]["plot_kwargs"] except (KeyError, TypeError): - all_plot_kwargs = _get_default_plot_kwargs(config) + all_plot_kwargs = _get_default_plot_kwargs(plot_kwargs_defaults) - kwargs_feed = _get_plot_kwargs_feed(all_plot_kwargs, plot_name) + kwargs_feed = 
_get_plot_kwargs_feed(all_plot_kwargs, plot_kwargs_defaults, plot_name) kwargs_feed_mapped = _get_plot_kwargs_feed( - all_plot_kwargs, - plot_name, + plot_kwargs=all_plot_kwargs, + plot_kwargs_defaults=plot_kwargs_defaults, + plot_name=plot_name, map_to_default_kwargs=True, - plot_kwargs_defaults=config["plot_kwargs_defaults"], ) for i, (plot_kwargs_set, plot_kwargs_set_mapped) in enumerate(zip(kwargs_feed, kwargs_feed_mapped)): try: plot_filename = plots[plot_name]["filename"] if type(plot_filename) == list: - plot_filename = Path(plot_filename[i]) + plot_filename = _get_formatted_plot_filename(plot_filename[i], plot_kwargs_defaults) except (KeyError, TypeError): - plot_filename = _get_default_plot_filename( - plot_name, plot_kwargs_set_mapped, config["plot_kwargs_defaults"] - ) + plot_filename = _get_default_plot_filename(plot_name, plot_kwargs_set_mapped, plot_kwargs_defaults) all_plot_maps[process][plot_name][_get_plot_kwargs_string(plot_kwargs_set)] = plot_filename @@ -109,7 +108,7 @@ def _get_plot_method_from_plot_name(plot_name): return plot_method -def _unify_kwargs_opt_lens(plot_kwargs, plot_name): +def _unify_kwargs_opt_lens(plot_kwargs: dict, plot_kwargs_defaults: dict, plot_name: str): """ Make all kwarg option counts equal so that we can get aligned in order options @@ -118,7 +117,8 @@ def _unify_kwargs_opt_lens(plot_kwargs, plot_name): >>> { >>> "stratify": ["sample", "none"], >>> "plot_size": "default" - >>> } + >>> }, + >>> plot_kwargs_defaults, plot_name >>> ) # output { @@ -139,9 +139,15 @@ def _unify_kwargs_opt_lens(plot_kwargs, plot_name): max_num_opts = 1 # means that there are no lists and just singular arguments if len(kwargs_num_options) > 0: # check if the lists of options are equal to each other raise ValueError( - f"plot_kwargs['{plot_name}'] contains arguments with unequal number of options, should include the same number of options where there are multiple options or a single option." 
+ f"'{plot_name}' contains arguments with unequal number of options, should include the same number of options where there are multiple options or a single option." ) + # fill in plot_kwargs that are not defined + template_plot_kwargs = _get_default_plot_kwargs(plot_kwargs_defaults) + for key, value in template_plot_kwargs.items(): + if key not in plot_kwargs: + plot_kwargs[key] = value + for key, values in plot_kwargs.items(): if type(values) != list: plot_kwargs[key] = [values] * max_num_opts # stretch to the same length @@ -169,8 +175,8 @@ def _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults): return mapped_plot_kwargs -def _get_plot_kwargs_feed(plot_kwargs: dict, plot_name: str, map_to_default_kwargs=False, plot_kwargs_defaults=None): - plot_kwargs = _unify_kwargs_opt_lens(plot_kwargs, plot_name) +def _get_plot_kwargs_feed(plot_kwargs: dict, plot_kwargs_defaults: dict, plot_name: str, map_to_default_kwargs=False): + plot_kwargs = _unify_kwargs_opt_lens(plot_kwargs, plot_kwargs_defaults, plot_name) if map_to_default_kwargs: plot_kwargs = _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults) plot_kwargs_feed = [ @@ -180,8 +186,12 @@ def _get_plot_kwargs_feed(plot_kwargs: dict, plot_name: str, map_to_default_kwar return plot_kwargs_feed -def _get_default_plot_kwargs(config: dict): - kwargs_keys = config["plot_kwargs_defaults"].keys() +def _get_default_plot_kwargs(plot_kwargs_defaults: dict): + kwargs_keys = list(plot_kwargs_defaults.keys()) + for kwargs_key in reversed(kwargs_keys): + if "filename" in kwargs_key: # ignore filename-related args (e.g., plot filename extension) + kwargs_keys.remove(kwargs_key) + default_plot_kwargs = dict(zip(kwargs_keys, ["default"] * len(kwargs_keys))) return default_plot_kwargs @@ -225,6 +235,16 @@ def _get_plot_kwargs_string(plot_kwargs: dict): # TODO-QC: proper type checking return json.dumps(ord_plot_kwargs) +def _get_formatted_plot_filename(plot_name: str, plot_kwargs_defaults: dict): + filename_ext = "." 
+ plot_kwargs_defaults.get("filename_ext", "png").lower().replace(".", "") + + plot_filename = plot_name.lower() + if "." not in plot_name: # if plot map doesn't have extension yet + plot_filename += filename_ext + + return plot_filename + + def _get_default_plot_filename(plot_name: str, plot_kwargs: dict, plot_kwargs_defaults: dict): """Infer plot filename from plot name, e.g. _UMIS_PER_CELL_HIST_ -> umis_per_cell_hist.png""" filename_ext = "." + plot_kwargs_defaults.get("filename_ext", "png").lower().replace(".", "") @@ -235,4 +255,4 @@ def _get_default_plot_filename(plot_name: str, plot_kwargs: dict, plot_kwargs_de plot_name = plot_name[:-1] plot_filename = plot_name.lower() + "-" + plot_kwargs_to_str(plot_kwargs) + filename_ext - return Path(plot_filename) + return plot_filename From 1b510fcc194d8253f1b2a40f0af263c7523c2598 Mon Sep 17 00:00:00 2001 From: Munchic Date: Wed, 19 Aug 2020 15:48:31 -0300 Subject: [PATCH 11/19] ignore undefined plot_map --- dataforest/hooks/hooks/core/hooks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index f671808..2fcae5d 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -128,7 +128,11 @@ def hook_catalogue(dp): def hook_generate_plots(dp: dataprocess): plot_sources = dp.branch.plot.plot_method_lookup current_process = dp.branch.current_process - all_plot_kwargs_sets = dp.branch.plot.plot_kwargs[current_process] + try: + all_plot_kwargs_sets = dp.branch.plot.plot_kwargs[current_process] + except KeyError: + logging.warning(f"Plot map for '{current_process}' is undefined, skipping plots hook") + return process_plot_methods = dp.branch.plot.plot_methods[current_process] process_plot_map = dp.branch[dp.branch.current_process].plot_map requested_plot_methods = deepcopy(process_plot_methods) From 0b71457f2531712a7b414874dc1a2bde57530441 Mon Sep 17 00:00:00 2001 From: theaustinator Date: 
Wed, 23 Sep 2020 02:15:08 -0600 Subject: [PATCH 12/19] plotting debugging and notebook compatibility; Merge branch 'dev' into feature/qc-plots # Conflicts: # dataforest/core/DataBranch.py # dataforest/hooks/hooks/core/hooks.py --- .travis.yml | 9 ++-- dataforest/config/MetaPlotMethods.py | 6 +-- dataforest/core/BranchSpec.py | 42 +++++++------------ dataforest/core/DataBranch.py | 11 +++-- dataforest/core/Interface.py | 3 ++ dataforest/core/PlotMethods.py | 37 +++++++++++----- dataforest/core/ProcessRun.py | 2 +- dataforest/core/TreeSpec.py | 2 +- dataforest/hooks/hook.py | 2 +- dataforest/hooks/hooks/core/hooks.py | 16 +++++-- dataforest/utils/plots_config.py | 35 +++++++--------- .../requirements.txt => requirements.txt | 0 setup.py | 7 ++-- tests/test_datatree.py | 2 +- tests/test_tether.py | 2 +- 15 files changed, 95 insertions(+), 81 deletions(-) rename requirements/requirements.txt => requirements.txt (100%) diff --git a/.travis.yml b/.travis.yml index ec1f39b..e11f6f6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,15 +1,12 @@ -dist: xenial +dist: focal language: python python: - '3.6' install: - pip3 install -r requirements.txt -script: - - pytest +script: pytest deploy: provider: pypi user: theaustinator - on: - tags: true password: - secure: + secure: P5b36T+ulGokWYCGqt/JRN9l4p3HwTlQBazpWxhqz3bHQ8GhmIAYRn5xPno0R4hA0dyO0hKtdp8Vkd+RqTmdvv0HJSHvLp4HC5A9JSugChnA5CRnfbBQ8VWMxNrGGJDADaaRFyT7GCzE317c4LOpcLsCqRUnAIytNqm7VvLCL00nQmeQu2b0sYsjXHbHbQQVcWEdMlMptr+hUqGiDXRjG8/1luMvP7ZUe0IBcVcHKp2hwIjTzwqPCyNF0J92l7DEmTHVyfmPYp5ioqdQDvJ5DjN6bBNELD/uEbmq3A8zagm47m46aGc0uBiT7qpKLW4w+fYlVkph2Uvj3qRvnAzWCcRLtItkJyIj7V2ovKSs7btheUmJHM93JnQc6hRqYfdggKLBDgosvLLOab4xBzJW5K7JwD4Wa9NCs3pidJWcF3WSTm6xpTFj19uEd8m9xQ/KPn3UfOAwSTImi+7Ya/LpYMV1xMwuKyQMWqanlOfpi5CFSzBrrBC44lL/ZhXjgEfocxSvCeqm+aK2gObDq2Eymze/OVvvOrBgnilU5D3EMai5ustkj8RjnjSEFJKAJYeXwGE4Yx73Ger0AZDqvLWbn8Pf7rHfcqpSJJx7jIJUeOwhWP6vme1vpqvW8V/VuwJP6ZobkyQ7xmnl5ceNGTYPOtfzhcpw5S7WAukSzpYQKmQ= diff --git 
a/dataforest/config/MetaPlotMethods.py b/dataforest/config/MetaPlotMethods.py index 31df652..8c61ea8 100644 --- a/dataforest/config/MetaPlotMethods.py +++ b/dataforest/config/MetaPlotMethods.py @@ -1,6 +1,6 @@ from dataforest.config.MetaConfig import MetaConfig from dataforest.utils.loaders.collectors import collect_plots -from dataforest.utils.plots_config import parse_plot_methods, parse_plot_kwargs +from dataforest.utils.plots_config import build_process_plot_method_lookup, parse_plot_kwargs class MetaPlotMethods(MetaConfig): @@ -9,11 +9,11 @@ def PLOT_METHOD_LOOKUP(cls): return {k: v for source in cls.CONFIG["plot_sources"] for k, v in collect_plots(source).items()} @property - def PLOT_METHODS(cls): + def PROCESS_PLOT_METHODS(cls): try: plot_methods = cls.CONFIG["plot_methods"] except KeyError: - plot_methods = parse_plot_methods(config=cls.CONFIG) + plot_methods = build_process_plot_method_lookup(config=cls.CONFIG) return plot_methods diff --git a/dataforest/core/BranchSpec.py b/dataforest/core/BranchSpec.py index f5d4c31..f00a402 100644 --- a/dataforest/core/BranchSpec.py +++ b/dataforest/core/BranchSpec.py @@ -1,3 +1,4 @@ +import json from copy import deepcopy from typing import Union, List, Dict @@ -18,41 +19,21 @@ class BranchSpec(list): >>> # NOTE: conceptual illustration only, not real processes >>> branch_spec = [ >>> { - >>> "_PROCESS_": "normalize", - >>> "_PARAMS_": { - >>> "min_genes": 5, - >>> "max_genes": 5000, - >>> "min_cells": 5, - >>> "nfeatures": 30, - >>> "perc_mito_cutoff": 20, - >>> "method": "seurat_default", - >>> } - >>> "_SUBSET_": { - >>> "indication": {"disease_1", "disease_3"}, - >>> "collection_center": "mass_general", - >>> }, - >>> "_FILTER_": { - >>> "donor": "D115" - >>> } - >>> }, - >>> { >>> "_PROCESS_": "reduce", # dimensionality reduction >>> "_ALIAS_": "linear_dim_reduce", >>> "_PARAMS_": { >>> "algorithm": "pca", - >>> "pca_npcs": 30, + >>> "n_pcs": 30 >>> } >>> }, >>> { >>> "_PROCESS_": "reduce", >>> "_ALIAS_": 
"nonlinear_dim_reduce", - >>> "_PARAMS_": { - >>> "algorithm": "umap", - >>> "n_neighbors": 15, - >>> "min_dist": 0.1, - >>> "n_components": 2, - >>> "metric": "euclidean" - >>> } + >>> "_PARAMS_": ... + >>> }, + >>> { + >>> "_PROCESS_": "dispersity" + >>> "_PARAMS_": ... >>> } >>> ] >>> branch_spec = BranchSpec(branch_spec) @@ -68,7 +49,11 @@ class BranchSpec(list): process_order: """ - def __init__(self, spec: Union[List[dict], "BranchSpec[RunSpec]"]): + def __init__(self, spec: Union[str, List[dict], "BranchSpec[RunSpec]"]): + if isinstance(spec, str): + spec = json.loads(spec) + if not isinstance(spec, (list, tuple)): + raise ValueError("spec must be convertible to a list or subclass") super().__init__([RunSpec(item) for item in spec]) self._run_spec_lookup: Dict[str, "RunSpec"] = self._build_run_spec_lookup() self._precursors_lookup: Dict[str, List[str]] = self._build_precursors_lookup() @@ -224,3 +209,6 @@ def __setitem__(self, k, v): def __contains__(self, item): return item in self._run_spec_lookup + + def __str__(self): + return super().__str__().replace("'", '"') diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index f7a399a..a4f838d 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -18,7 +18,7 @@ from dataforest.filesystem.io import ReaderMethods from dataforest.filesystem.io.WriterMethods import WriterMethods from dataforest.utils.exceptions import BadSubset, BadFilter -from dataforest.utils.utils import update_recursive +from dataforest.utils.utils import update_recursive, label_df_partitions class DataBranch(DataBase): @@ -100,7 +100,7 @@ class DataBranch(DataBase): _METADATA_NAME: dict = NotImplementedError("Should be implemented by superclass") _COPY_KWARGS: dict = { "root": "root", - "branch_spec": "branch_spec", + "branch_spec": "spec", "verbose": "verbose", "current_process": "current_process", } @@ -120,13 +120,11 @@ def __init__( self._current_process = current_process self._remote_root 
= remote_root self.root = Path(root) - self.spec = self._init_spec(branch_spec) self.verbose = verbose self.logger = logging.getLogger(self.__class__.__name__) self.plot = self.PLOT_METHODS(self) self.process = self.PROCESS_METHODS(self, self.spec) - # self.hyper = HyperparameterMethods(self) self.schema = self.SCHEMA_CLASS() self._paths_exists = PathCache(self.root, self.spec, exists_req=True) self._paths = self._paths_exists.get_shared_memory_view(exist_req=False) @@ -394,6 +392,11 @@ def _apply_data_ops(self, process_name: str, df: Optional[pd.DataFrame] = None): df = self._do_subset(df, column, val) for column, val in filter_.items(): df = self._do_filter(df, column, val) + df.replace(" ", "_", regex=True, inplace=True) + partitions_list = self.spec.get_partition_list(process_name) + partitions = set().union(*partitions_list) + if partitions: + df = label_df_partitions(df, partitions, encodings=True) return df @staticmethod diff --git a/dataforest/core/Interface.py b/dataforest/core/Interface.py index 950f317..b96cbcd 100644 --- a/dataforest/core/Interface.py +++ b/dataforest/core/Interface.py @@ -61,6 +61,7 @@ def from_input_dirs( remote_root: Optional[Union[str, Path]] = None, root_plots: bool = True, plot_kwargs: Optional[dict] = None, + overwrite_plots: Optional[Iterable[str]] = None, **kwargs, ) -> Union["DataBranch", "CellBranch", "DataTree", "CellTree"]: """ @@ -79,6 +80,8 @@ def from_input_dirs( remote_root: root_plots: plot_kwargs: + overwrite_plots: e.g. 
["root", "cluster"] + # TODO implement overwrite plots """ if not isinstance(input_paths, (list, tuple)): input_paths = [input_paths] diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py index 1296591..9685b28 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -1,9 +1,10 @@ from functools import wraps +from pathlib import Path +from typing import Optional, Dict from dataforest.config.MetaPlotMethods import MetaPlotMethods from dataforest.utils import tether, copy_func from dataforest.utils.ExceptionHandler import ExceptionHandler -from dataforest.utils.plots_config import get_plot_name_from_plot_method class PlotMethods(metaclass=MetaPlotMethods): @@ -21,13 +22,29 @@ def __init__(self, branch: "DataBranch"): setattr(self, name, self._wrap(callable_)) tether(self, "branch") + def regenerate_plots(self, plot_map: Optional[Dict[str, str]]): + raise NotImplementedError() + @property def plot_method_lookup(self): return self.__class__.PLOT_METHOD_LOOKUP @property def plot_methods(self): - return self.__class__.PLOT_METHODS + return self.__class__.PROCESS_PLOT_METHODS + + @property + def global_plot_methods(self): + global_plot_methods = { + config_name: callable_name + for name_mapping in self.plot_methods.values() + for config_name, callable_name in name_mapping.items() + } + return global_plot_methods + + @property + def global_plot_methods_reverse(self): + return {v: k for k, v in self.global_plot_methods.items()} @property def plot_kwargs_defaults(self): @@ -44,15 +61,15 @@ def _wrap(self, method): def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): try: process_run = branch[branch.current_process] - plot_name = get_plot_name_from_plot_method( - branch.plot.plot_methods[branch.current_process], method_name - ) - + plot_name = branch.plot.global_plot_methods_reverse.get(method_name, None) if plot_name in process_run.plot_map: - for plot_kwargs_key, plot_filename in 
process_run.plot_map[plot_name].items(): - plot_dir = plot_filename.parent # only need one sample dir - - plot_dir.mkdir(exist_ok=True) + if not (plt_filename_lookup := process_run.plot_map[plot_name]): + if plt_filename_lookup: + _, plt_filepath = next(iter(plt_filename_lookup.items())) + plt_dir = plt_filepath.parent + else: + plt_dir = Path("/tmp") + plt_dir.mkdir(exist_ok=True) return method(branch, *args, **kwargs) except Exception as e: err_filename = method.__name__ diff --git a/dataforest/core/ProcessRun.py b/dataforest/core/ProcessRun.py index d1a2006..682f1ed 100644 --- a/dataforest/core/ProcessRun.py +++ b/dataforest/core/ProcessRun.py @@ -169,7 +169,7 @@ def logs(self): Prints stdout and stderr log files """ log_dir = self.path / "_logs" - log_files = log_dir.iterdir() + log_files = list(log_dir.iterdir()) stdouts = list(filter(lambda x: str(x).endswith(".out"), log_files)) stderrs = list(filter(lambda x: str(x).endswith(".err"), log_files)) if (len(stdouts) == 0) and (len(stderrs) == 0): diff --git a/dataforest/core/TreeSpec.py b/dataforest/core/TreeSpec.py index 0eac29b..90af53d 100644 --- a/dataforest/core/TreeSpec.py +++ b/dataforest/core/TreeSpec.py @@ -28,7 +28,7 @@ class TreeSpec(BranchSpec): >>> }, >>> ], >>> "_SUBSET_": { - >>> "sample": {"_SWEEP_": ["sample_1", "sample_2"]} + >>> "sample_id": {"_SWEEP_": ["sample_1", "sample_2"]} >>> }, >>> }, >>> { diff --git a/dataforest/hooks/hook.py b/dataforest/hooks/hook.py index 3744fd8..9d79523 100644 --- a/dataforest/hooks/hook.py +++ b/dataforest/hooks/hook.py @@ -1,4 +1,4 @@ -from functools import partial, update_wrapper, wraps +from functools import update_wrapper, wraps from typing import Callable, Optional, List, Tuple, Union diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index f671808..f1321ed 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -1,7 +1,8 @@ +from copy import deepcopy import gc import logging 
from pathlib import Path -from copy import deepcopy +import shutil import pandas as pd import yaml @@ -9,7 +10,6 @@ from dataforest.hooks.dataprocess import dataprocess from dataforest.utils.catalogue import run_id_from_multi_row from dataforest.utils.exceptions import InputDataNotFound -from dataforest.utils.plots_config import get_plot_name_from_plot_method from dataforest.hooks.hook import hook @@ -23,7 +23,8 @@ def hook_comparative(dp): """Sets up DataBranch for comparative analysis""" if "_PARTITION_" in dp.branch.spec: logging.warning( - "`partition` found at base level of branch_spec. It should normally be specified under an individual processes" + "`partition` found at base level of branch_spec. It should normally be specified under an individual " + "processes" ) if dp.comparative: @@ -136,7 +137,7 @@ def hook_generate_plots(dp: dataprocess): for method in plot_sources.values(): plot_method_name = method.__name__ if plot_method_name in requested_plot_methods.values(): - plot_name = get_plot_name_from_plot_method(process_plot_methods, plot_method_name) + plot_name = dp.branch.plot.global_plot_methods_reverse[plot_method_name] plot_kwargs_sets = all_plot_kwargs_sets[plot_name] for plot_kwargs_key in plot_kwargs_sets.keys(): plot_path = process_plot_map[plot_name][plot_kwargs_key] @@ -149,3 +150,10 @@ def hook_generate_plots(dp: dataprocess): logging.warning( f"Requested plotting methods {requested_plot_methods} are not implemented so they were skipped." 
) + + +@hook +def hook_clear_logs(dp: dataprocess): + logs_path = dp.branch[dp.name].logs_path + if logs_path.exists(): + shutil.rmtree(str(logs_path), ignore_errors=True) diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py index 25b0a29..11ae802 100644 --- a/dataforest/utils/plots_config.py +++ b/dataforest/utils/plots_config.py @@ -1,23 +1,30 @@ from copy import deepcopy import json from collections import OrderedDict +from typing import Dict -def parse_plot_methods(config: dict): - """Parse plot methods per process from plot_map""" +def build_process_plot_method_lookup(config: dict) -> Dict[str, Dict[str, str]]: + """ + Get a lookup of processes, each containing a mapping between plot method + names in the config and the actual callable names. + Format: + process_name[config_plot_name][plot_callable_name] + Ex: + {"normalize": {"_UMIS_PER_CELL_HIST_": "plot_umis_per_cell_hist", ...}, ...} + """ plot_map = config["plot_map"] - plot_methods = {} + process_plot_methods = {} for process, plots in plot_map.items(): - plot_methods[process] = {} + process_plot_methods[process] = {} for plot_name in plots.keys(): try: plot_method = plots[plot_name]["plot_method"] except (TypeError, KeyError): plot_method = _get_plot_method_from_plot_name(plot_name) - plot_methods[process][plot_name] = plot_method - - return plot_methods + process_plot_methods[process][plot_name] = plot_method + return process_plot_methods def parse_plot_kwargs(config: dict): @@ -83,19 +90,9 @@ def parse_plot_map(config: dict): plot_filename = _get_default_plot_filename(plot_name, plot_kwargs_set_mapped, plot_kwargs_defaults) all_plot_maps[process][plot_name][_get_plot_kwargs_string(plot_kwargs_set)] = plot_filename - return all_plot_maps -def get_plot_name_from_plot_method(process_plot_methods, plot_method_name): - """Reverse search for plot name in the config from plot method used""" - for key, value in process_plot_methods.items(): - if value == plot_method_name: - 
plot_name = key # look up plot name from plot_method name - - return plot_name - - def _get_plot_method_from_plot_name(plot_name): """Infer plot method name from plot name, e.g. _UMIS_PER_CELL_HIST_ -> plot_umis_per_cell_hist""" if plot_name[0] == "_": @@ -115,14 +112,14 @@ def _unify_kwargs_opt_lens(plot_kwargs: dict, plot_kwargs_defaults: dict, plot_n Examples: >>> _unify_kwargs_opt_lens( >>> { - >>> "stratify": ["sample", "none"], + >>> "stratify": ["sample_id", "none"], >>> "plot_size": "default" >>> }, >>> plot_kwargs_defaults, plot_name >>> ) # output { - "stratify": ["sample", "none"], + "stratify": ["sample_id", "none"], "plot_size": ["default", "default"] } """ diff --git a/requirements/requirements.txt b/requirements.txt similarity index 100% rename from requirements/requirements.txt rename to requirements.txt diff --git a/setup.py b/setup.py index 9f4e6e8..9cc8e79 100644 --- a/setup.py +++ b/setup.py @@ -11,13 +11,13 @@ def parse_requirements(requirements_path): return f.read().splitlines() -requirements = parse_requirements("requirements/requirements.txt") +requirements = parse_requirements("requirements.txt") test_requirements = parse_requirements("requirements/requirements-test.txt") dev_requirements = parse_requirements("requirements/requirements-dev.txt") setup( name="dataforest", # Replace with your own username - version="0.0.1", + version="0.0.2", author="Austin McKay", author_email="austinmckay303@gmail.com", description="An interactive data science workflow manager", @@ -28,9 +28,10 @@ def parse_requirements(requirements_path): # TODO: this works in "all", but not in `install_reqs` -- fix install_requires=requirements, extras_require={"all": requirements, "test": test_requirements, "dev": dev_requirements,}, + package_data={"dataforest": ["config/default_config.yaml"]}, classifiers=[ "Programming Language :: Python :: 3", - "License :: AGPL 3.0 License", + "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", 
"Operating System :: OS Independent", ], python_requires=">=3.6", diff --git a/tests/test_datatree.py b/tests/test_datatree.py index 689d607..9d7f18f 100644 --- a/tests/test_datatree.py +++ b/tests/test_datatree.py @@ -29,7 +29,7 @@ def tree_spec(): }, {"min_genes": 5, "max_genes": 5000, "min_cells": 5, "perc_mito_cutoff": 20, "method": "sctransform"}, ], - "_SUBSET_": {"sample": {"_SWEEP_": ["sample_1", "sample_2"]}}, + "_SUBSET_": {"sample_id": {"_SWEEP_": ["sample_1", "sample_2"]}}, }, { "_PROCESS_": "reduce", diff --git a/tests/test_tether.py b/tests/test_tether.py index 7000dac..822dad9 100644 --- a/tests/test_tether.py +++ b/tests/test_tether.py @@ -6,7 +6,7 @@ def __init__(self, val): self.val = val @staticmethod - def my_method(val): + def my_method(val, _): return val + 5 From 34f4aa97168964339e5a844784e5bad97e51a267 Mon Sep 17 00:00:00 2001 From: theaustinator Date: Fri, 25 Sep 2020 02:51:45 -0600 Subject: [PATCH 13/19] tree parallelization; plotting widgets; improved logging --- dataforest/config/MetaPlotMethods.py | 4 ++ dataforest/core/BranchSpec.py | 32 +++++++--- dataforest/core/DataBranch.py | 2 +- dataforest/core/DataTree.py | 22 +++++-- dataforest/core/Interface.py | 2 +- dataforest/core/PlotMethods.py | 49 +++++++++++++--- dataforest/core/ProcessRun.py | 7 ++- dataforest/core/RunGroupSpec.py | 2 + dataforest/core/Sweep.py | 6 +- dataforest/core/TreeSpec.py | 43 +++++++++++++- dataforest/hooks/dataprocess/dataprocess.py | 6 +- dataforest/hooks/hooks/core/hooks.py | 2 +- .../processes/core/TreeProcessMethods.py | 58 +++++++++++++++---- dataforest/structures/cache/PlotCache.py | 3 + dataforest/utils/ExceptionHandler.py | 32 ++++++++-- requirements.txt | 2 + 16 files changed, 224 insertions(+), 48 deletions(-) create mode 100644 dataforest/structures/cache/PlotCache.py diff --git a/dataforest/config/MetaPlotMethods.py b/dataforest/config/MetaPlotMethods.py index 8c61ea8..ea6174a 100644 --- a/dataforest/config/MetaPlotMethods.py +++ 
b/dataforest/config/MetaPlotMethods.py @@ -8,6 +8,10 @@ class MetaPlotMethods(MetaConfig): def PLOT_METHOD_LOOKUP(cls): return {k: v for source in cls.CONFIG["plot_sources"] for k, v in collect_plots(source).items()} + @property + def PLOT_MAP(cls): + return cls.CONFIG["plot_map"] + @property def PROCESS_PLOT_METHODS(cls): try: diff --git a/dataforest/core/BranchSpec.py b/dataforest/core/BranchSpec.py index f00a402..156b5dc 100644 --- a/dataforest/core/BranchSpec.py +++ b/dataforest/core/BranchSpec.py @@ -2,6 +2,8 @@ from copy import deepcopy from typing import Union, List, Dict +from typeguard import typechecked + from dataforest.core.RunSpec import RunSpec from dataforest.utils.exceptions import DuplicateProcessName @@ -64,6 +66,10 @@ def __init__(self, spec: Union[str, List[dict], "BranchSpec[RunSpec]"]): ) self.process_order: List[str] = [spec_item.name for spec_item in self] + @property + def processes(self): + return [run_spec["_PROCESS_"] for run_spec in self] + @property def shell_str(self): """string version which can be passed via shell and loaded via json""" @@ -194,18 +200,28 @@ def _build_precursors_lookup(self, incl_root: bool = False, incl_current: bool = current_precursors = current_precursors + [spec_item.name] return precursors - def __getitem__(self, item: Union[str, int]) -> "RunSpec": + @typechecked + def __getitem__(self, key: Union[str, int, slice]) -> Union["RunSpec", "BranchSpec"]: """Get `RunSpec` either via `int` index or `name`""" - if not isinstance(item, int): - try: - return self._run_spec_lookup[item] - except Exception as e: - raise e + if isinstance(key, str): + return self._run_spec_lookup[key] + elif isinstance(key, slice): + if isinstance(key.stop, str): + if key.start or key.step: + raise ValueError(f"Can only use stop with string slice (ex. 
[:'process_name'])") + precursors_lookup = self.get_precursors_lookup(incl_current=True) + import ipdb + + ipdb.set_trace() + precursors = precursors_lookup[key.stop] + return self.__class__([self._run_spec_lookup[process] for process in precursors]) + else: + return self.__class__(super().__getitem__(key)) else: - return super().__getitem__(item) + return super().__getitem__(key) def __setitem__(self, k, v): - raise ValueError("Cannot set items dynamically. All items must be defined at init") + raise NotImplementedError("Cannot set items dynamically. All items must be defined at init") def __contains__(self, item): return item in self._run_spec_lookup diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index a4f838d..49b6839 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -271,7 +271,7 @@ def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): ) return - if plot_kwargs == None: + if plot_kwargs is None: plot_kwargs = self.plot.plot_kwargs["root"] root_plot_map = self["root"].plot_map root_plot_methods = self.plot.plot_methods.get("root", []) diff --git a/dataforest/core/DataTree.py b/dataforest/core/DataTree.py index 1d517d7..0245947 100644 --- a/dataforest/core/DataTree.py +++ b/dataforest/core/DataTree.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Union, Optional, List, Dict @@ -9,14 +10,14 @@ class DataTree(DataBase): - BRANCH_CLASS = DataBranch + _LOG = logging.getLogger("DataTree") + _BRANCH_CLASS = DataBranch def __init__( self, root: Union[str, Path], tree_spec: Optional[List[dict]] = None, verbose: bool = False, - current_process: Optional[str] = None, remote_root: Optional[Union[str, Path]] = None, ): # TODO: add something that tells them how many of each process will be run @@ -26,17 +27,26 @@ def __init__( self.root = root self.tree_spec = self._init_spec(tree_spec) self._verbose = verbose - self._current_process = current_process + 
self._current_process = None self.remote_root = remote_root - self._branch_cache = BranchCache( - root, self.tree_spec.branch_specs, self.BRANCH_CLASS, verbose, current_process, remote_root, - ) + self._branch_cache = BranchCache(root, self.tree_spec.branch_specs, self._BRANCH_CLASS, verbose, remote_root,) self.process = TreeProcessMethods(self.tree_spec, self._branch_cache) @property def n_branches(self): return len(self.tree_spec.branch_specs) + @property + def current_process(self): + return self._current_process if self._current_process else "root" + + def goto_process(self, process_name: str): + self._LOG.info(f"loading all branches to `goto_process`") + self._branch_cache.load_all() + for branch in self._branch_cache.values(): + branch.goto_process(process_name) + self._current_process = process_name + def update_process_spec(self, process_name: str, process_spec: dict): self.tree_spec[process_name] = process_spec self.update_spec(self.tree_spec) diff --git a/dataforest/core/Interface.py b/dataforest/core/Interface.py index b96cbcd..e725e18 100644 --- a/dataforest/core/Interface.py +++ b/dataforest/core/Interface.py @@ -24,7 +24,7 @@ def load( ) -> Union["DataBranch", "CellBranch", "DataTree", "CellTree"]: # TODO: replace kwargs with explicit to make it easier for users """ - Loads `cls.TREE_CLASS` if `tree_spec` passed, otherwise `cls.BRANCH_CLASS` + Loads `cls.TREE_CLASS` if `tree_spec` passed, otherwise `cls._BRANCH_CLASS` Args: root: branch_spec: diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py index 9685b28..5280f15 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -1,8 +1,12 @@ from functools import wraps +import logging from pathlib import Path -from typing import Optional, Dict +from typing import Optional, Dict, Iterable + +from typeguard import typechecked from dataforest.config.MetaPlotMethods import MetaPlotMethods +from dataforest.structures.cache.PlotCache import PlotCache from 
dataforest.utils import tether, copy_func from dataforest.utils.ExceptionHandler import ExceptionHandler @@ -15,15 +19,35 @@ class PlotMethods(metaclass=MetaPlotMethods): """ def __init__(self, branch: "DataBranch"): + self._logger = logging.getLogger(self.__class__.__name__) self.branch = branch for name, plot_method in self.plot_method_lookup.items(): callable_ = copy_func(plot_method) callable_.__name__ = name setattr(self, name, self._wrap(callable_)) tether(self, "branch") + # self._plot_cache = {process: PlotCache(self.branch, process) for process in self.branch.spec.processes} + self._img_cache = {} - def regenerate_plots(self, plot_map: Optional[Dict[str, str]]): - raise NotImplementedError() + @typechecked + def regenerate_plots( + self, + processes: Optional[Iterable[str]], + plot_map: Optional[Dict[str, dict]], + plot_kwargs: Optional[Dict[str, dict]], + ): + plot_map = self.plot_map if not plot_map else plot_map + plot_kwargs = self.plot_kwargs if not plot_kwargs else plot_kwargs + if processes is not None: + plot_map = {proc: proc_plot_map for proc, proc_plot_map in plot_map if proc in processes} + plot_kwargs = {proc: proc_plot_kwargs for proc, proc_plot_kwargs in plot_kwargs if proc in processes} + for process, proc_plot_map in plot_map.items(): + method_names_config = tuple(proc_plot_map.keys()) + for name_config in method_names_config: + method_name = self.method_name_lookup[name_config] + kwargs = plot_kwargs[name_config] + method = getattr(self, method_name) + method(**kwargs) @property def plot_method_lookup(self): @@ -34,7 +58,11 @@ def plot_methods(self): return self.__class__.PROCESS_PLOT_METHODS @property - def global_plot_methods(self): + def method_lookup(self): + return {k: getattr(self, method_name) for k, method_name in self.method_name_lookup.items()} + + @property + def method_name_lookup(self): global_plot_methods = { config_name: callable_name for name_mapping in self.plot_methods.values() @@ -43,13 +71,18 @@ def 
global_plot_methods(self): return global_plot_methods @property - def global_plot_methods_reverse(self): - return {v: k for k, v in self.global_plot_methods.items()} + def method_key_lookup(self): + return {v: k for k, v in self.method_name_lookup.items()} @property def plot_kwargs_defaults(self): return self.__class__.PLOT_KWARGS_DEFAULTS + @property + def plot_map(self): + # TODO: rename to plot_settings for clarity and to avoid confusion w/ process run? + return self.__class__.PLOT_MAP + @property def plot_kwargs(self): return self.__class__.PLOT_KWARGS @@ -61,7 +94,7 @@ def _wrap(self, method): def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): try: process_run = branch[branch.current_process] - plot_name = branch.plot.global_plot_methods_reverse.get(method_name, None) + plot_name = branch.plot.method_key_lookup.get(method_name, None) if plot_name in process_run.plot_map: if not (plt_filename_lookup := process_run.plot_map[plot_name]): if plt_filename_lookup: @@ -76,3 +109,5 @@ def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): ExceptionHandler.handle(self.branch, e, err_filename, stop_on_error) return wrapped + + # def __getitem__(self, key): diff --git a/dataforest/core/ProcessRun.py b/dataforest/core/ProcessRun.py index 682f1ed..dfb1c3a 100644 --- a/dataforest/core/ProcessRun.py +++ b/dataforest/core/ProcessRun.py @@ -107,6 +107,7 @@ def path_map(self) -> Dict[str, Path]: @property def plot_map(self) -> Dict[str, Path]: + # TODO: confusing with new plot_map name in config - rename that to plot_settings? 
if self._plot_map is None: self._plot_map = self._build_path_map(incl_current=True, plot_map=True) return self._plot_map @@ -154,8 +155,10 @@ def success(self) -> bool: @property def failed(self) -> bool: if self.path is not None and self.path.exists(): - error_files = ["process.err", "hooks.err"] - if not set(error_files).intersection(self.logs_path.iterdir()): + error_prefixes = ["PROCESS__", "HOOKS__"] + is_error_file = lambda s: any(map(lambda prefix: s.name.startswith(prefix), error_prefixes)) + contains_error_file = any(map(is_error_file, self.logs_path.iterdir())) + if contains_error_file: return True return False diff --git a/dataforest/core/RunGroupSpec.py b/dataforest/core/RunGroupSpec.py index 6786f18..b643e4f 100644 --- a/dataforest/core/RunGroupSpec.py +++ b/dataforest/core/RunGroupSpec.py @@ -17,6 +17,7 @@ class RunGroupSpec(RunSpec): def __init__(self, dict_: Union[dict, "RunGroupSpec"]): super().__init__(**dict_) + self.sweeps = set() self.run_specs = self._build_run_specs() def _build_run_specs(self): @@ -57,6 +58,7 @@ def _map_sweeps(self): if isinstance(operation_dict, dict): for key, val in operation_dict.items(): if isinstance(val, dict) and "_SWEEP_" in val: + self.sweeps.add((self["_PROCESS_"], operation, key)) sweep_obj = val["_SWEEP_"] self[operation][key] = Sweep(operation, key, sweep_obj) diff --git a/dataforest/core/Sweep.py b/dataforest/core/Sweep.py index acf0059..40ffd5f 100644 --- a/dataforest/core/Sweep.py +++ b/dataforest/core/Sweep.py @@ -7,6 +7,10 @@ def __init__(self, operation, key, sweep_obj): self.key = key super().__init__(self._get_values(sweep_obj)) + @property + def dtype(self): + return type(self[0]) + @staticmethod def _get_values(sweep_obj): if isinstance(sweep_obj, dict): @@ -27,7 +31,7 @@ def _get_values(sweep_obj): elif isinstance(sweep_obj, (list, set, tuple)): values = list(sweep_obj) else: - raise TypeError(f"`sweep_obj` must be of types: [dict, list, set, tuple]") + raise TypeError(f"`sweep_obj` expected type: 
[dict, list, set, tuple], not {type(sweep_obj)}. {sweep_obj}") return values def __str__(self): diff --git a/dataforest/core/TreeSpec.py b/dataforest/core/TreeSpec.py index 90af53d..bb93a39 100644 --- a/dataforest/core/TreeSpec.py +++ b/dataforest/core/TreeSpec.py @@ -49,12 +49,53 @@ def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]): super(list, self).__init__() self.extend([RunGroupSpec(item) for item in tree_spec]) self.branch_specs = self._build_branch_specs() + self.sweep_dict = {x["_PROCESS_"]: x.sweeps for x in self} + self._raw = tree_spec def _build_branch_specs(self): - return list(product(*[run_group_spec.run_specs for run_group_spec in self])) + return list(map(BranchSpec, product(*[run_group_spec.run_specs for run_group_spec in self]))) def __setitem__(self, key, value): if not isinstance(key, int): idx_lookup = {run_group_spec.name: i for i, run_group_spec in enumerate(self)} key = idx_lookup[key] list.__setitem__(self, key, value) + + +class PlotSpec: + def __init__(self, tree, plot_key, use_saved=True, **kwargs): + self._tree = tree + self._branch_spec = tree.tree_spec.branch_specs[0] + self._plot_key = plot_key + self._use_saved = use_saved + self._kwargs = kwargs + self._sweeps_remaining = len(self._tree.tree_spec.sweep_dict[self._tree.current_process]) + + def update(self, process, param, value): + self._branch_spec[process]["_PARAMS_"][param] = value + + def get_updater(self, process, param): + def updater(value): + self._branch_spec[process]["_PARAMS_"][param] = value + branch = self._tree._branch_cache[str(self)] + plot_map = branch[process].plot_map + plot_path_lookup = {plot_key: next(iter(path_dict.values())) for plot_key, path_dict in plot_map.items()} + if self._plot_key in plot_path_lookup and self._use_saved: + plot_path = plot_path_lookup.get(self._plot_key) + if plot_path.exists(): + return Image(plot_path) + self._generate_plot(branch, process) + + return updater + + def _generate_plot(self, branch: 
"DataBranch", process: str): + method = branch.plot.method_lookup[self._plot_key] + # fig, ax = method(**self._kwargs) + method(**self._kwargs) + # return fig, ax + + def __str__(self): + return str(self._branch_spec) + + def __repr__(self): + return repr(self._branch_spec) diff --git a/dataforest/hooks/dataprocess/dataprocess.py b/dataforest/hooks/dataprocess/dataprocess.py index b28224b..104e32f 100644 --- a/dataforest/hooks/dataprocess/dataprocess.py +++ b/dataforest/hooks/dataprocess/dataprocess.py @@ -61,7 +61,7 @@ def wrapper(branch, run_name, stop_on_error=True, stop_on_hook_error=True, *args try: return func(self.branch, run_name, *args, **kwargs) except Exception as e: - ExceptionHandler.handle(branch, e, "process.err", stop_on_error) + ExceptionHandler.handle(branch, e, f"PROCESS_{func.__name__}.err", stop_on_error) finally: self._run_hooks(self.clean_hooks, try_all=True, stop_on_hook_error=stop_on_hook_error) @@ -83,7 +83,7 @@ def _run_hooks(self, hooks: Iterable[Callable], try_all: bool = False, stop_on_h hook_exceptions[str(hook.__name__)] = e if not try_all: e = HookException(self.name, hook_exceptions) - ExceptionHandler.handle(self.branch, e, "hooks.err", stop_on_hook_error) + ExceptionHandler.handle(self.branch, e, f"HOOK_{hook.__name__}.err", stop_on_hook_error) if hook_exceptions: e = HookException(self.name, hook_exceptions) - ExceptionHandler.handle(self.branch, e, "hooks.err", stop_on_hook_error) + ExceptionHandler.handle(self.branch, e, f"HOOK_{hook.__name__}.err", stop_on_hook_error) diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index f1321ed..cecef11 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -137,7 +137,7 @@ def hook_generate_plots(dp: dataprocess): for method in plot_sources.values(): plot_method_name = method.__name__ if plot_method_name in requested_plot_methods.values(): - plot_name = 
dp.branch.plot.global_plot_methods_reverse[plot_method_name] + plot_name = dp.branch.plot.method_key_lookup[plot_method_name] plot_kwargs_sets = all_plot_kwargs_sets[plot_name] for plot_kwargs_key in plot_kwargs_sets.keys(): plot_path = process_plot_map[plot_name][plot_kwargs_key] diff --git a/dataforest/processes/core/TreeProcessMethods.py b/dataforest/processes/core/TreeProcessMethods.py index e3220ba..bdc8c8d 100644 --- a/dataforest/processes/core/TreeProcessMethods.py +++ b/dataforest/processes/core/TreeProcessMethods.py @@ -1,10 +1,17 @@ +import logging +from multiprocessing import cpu_count from typing import Callable, List, Union +from joblib import Parallel, delayed + from dataforest.core.TreeSpec import TreeSpec from dataforest.structures.cache.BranchCache import BranchCache class TreeProcessMethods: + _N_JOBS = cpu_count() - 1 + _LOG = logging.getLogger("TreeProcessMethods") + def __init__(self, tree_spec: TreeSpec, branch_cache: BranchCache): self._tree_spec = tree_spec self._branch_cache = branch_cache @@ -36,11 +43,13 @@ def _build_distributed_method(self, method_name: str) -> Callable: branches """ - def distributed_method( + def _distributed_method( *args, stop_on_error: bool = False, stop_on_hook_error: bool = False, clear_data: Union[bool, List[str]] = True, + force_rerun: bool = False, + parallel: bool = False, **kwargs, ): """ @@ -53,26 +62,51 @@ def distributed_method( after process execution to save memory. 
If boolean, whether or not to clear all data attrs, if list, names of data attrs to clear + force_rerun: force the process to rerun even if already done + parallel: whether or not to parallelize **kwargs: Returns: """ + kwargs = {"stop_on_error": stop_on_error, "stop_on_hook_error": stop_on_hook_error, **kwargs} if not self._branch_cache.fully_loaded: self._branch_cache.load_all() - return_vals = [] - for branch in list(self._branch_cache.values()): + all_branches = list(self._branch_cache.values()) + unique_branches = {str(branch.spec[:method_name]): branch for branch in all_branches} + + def _single_kernel(branch): branch_method = getattr(branch.process, method_name) - if not branch[method_name].done: - return_vals.append( - branch_method( - *args, stop_on_error=stop_on_error, stop_on_hook_error=stop_on_hook_error, **kwargs - ) - ) + if not branch[method_name].done or force_rerun: + ret = branch_method(*args, **kwargs) if clear_data: clear_kwargs = {"all_data": True} if isinstance(clear_data, bool) else {"attrs": clear_data} branch.clear_data(**clear_kwargs) - return return_vals + return ret + + def _distributed_kernel_serial(): + _ret_vals = [] + for branch in unique_branches: + _ret_vals.append(_single_kernel(branch)) + return _ret_vals + + def _distributed_kernel_parallel(): + process = delayed(_single_kernel) + pool = Parallel(n_jobs=self._N_JOBS) + return pool(process(branch) for branch in unique_branches) + + exec_scheme = "PARALLEL" if parallel else "SERIAL" + print(exec_scheme) + self._LOG.info( + f"{exec_scheme} execution of {method_name} over {self._N_JOBS} workers on {len(unique_branches)} " + f"unique conditions" + ) + kernel = _distributed_kernel_parallel if parallel else _distributed_kernel_serial + ret_vals = kernel() + return ret_vals + + _distributed_method.__name__ = method_name + return _distributed_method - distributed_method.__name__ = method_name - return distributed_method + def _distributed_kernel_parallel(self): + raise 
NotImplementedError() diff --git a/dataforest/structures/cache/PlotCache.py b/dataforest/structures/cache/PlotCache.py new file mode 100644 index 0000000..20f10eb --- /dev/null +++ b/dataforest/structures/cache/PlotCache.py @@ -0,0 +1,3 @@ +class PlotCache: + def __init__(self, branch, process): + self._branch = branch diff --git a/dataforest/utils/ExceptionHandler.py b/dataforest/utils/ExceptionHandler.py index e6ea8f9..8542d7f 100644 --- a/dataforest/utils/ExceptionHandler.py +++ b/dataforest/utils/ExceptionHandler.py @@ -1,20 +1,42 @@ +import logging import traceback class ExceptionHandler: - @staticmethod - def handle(branch, e, logfile_name, stop): + _LOG = logging.getLogger("ExceptionHandler") + + @classmethod + def handle(cls, branch, e, logfile_name, stop): + try: + log_path = cls._get_log_path(branch, logfile_name) + except Exception as path_e: + cls._LOG.warning(f"{type(path_e).__name__} encountered getting logging output path for {logfile_name}") + if stop: + raise e + else: + cls._handle_write(e, log_path, stop) + + @classmethod + def _handle_write(cls, e, log_path, stop): try: - ExceptionHandler._log_error(branch, e, logfile_name) + cls._write_log(e, log_path) + cls._LOG.info(f"Wrote log to {log_path}") + except Exception as write_e: + cls._LOG.warning(f"{type(write_e).__name__} encountered writing {log_path}") finally: if stop: raise e + else: + cls._LOG.warning(f"{type(e).__name__} encountered but `stop_on_error=False`. 
Logs at {log_path}") @staticmethod - def _log_error(branch, e, logfile_name): + def _get_log_path(branch, logfile_name): log_dir = branch[branch.current_process].logs_path log_dir.mkdir(exist_ok=True) - log_path = log_dir / logfile_name + return log_dir / logfile_name + + @staticmethod + def _write_log(e, log_path): with open(log_path, "w") as f: f.write("Traceback (most recent call last):") traceback.print_tb(e.__traceback__, file=f) diff --git a/requirements.txt b/requirements.txt index 5bc5c80..88ef6c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ +jobplib matplotlib numpy pandas pyyaml termcolor +typeguard Ipython \ No newline at end of file From 93389f9d25eb263cfa6285ac11409ea985a7ea00 Mon Sep 17 00:00:00 2001 From: theaustinator Date: Mon, 5 Oct 2020 22:39:43 -0600 Subject: [PATCH 14/19] plot widget; ProcessTreeRun; working on parallelism; --- dataforest/config/MetaPlotMethods.py | 2 +- dataforest/config/MetaProcessSchema.py | 2 +- dataforest/core/BranchSpec.py | 7 +- dataforest/core/DataBase.py | 14 +++- dataforest/core/DataBranch.py | 13 +-- dataforest/core/DataTree.py | 14 +++- dataforest/core/Interface.py | 15 ++-- dataforest/core/PlotWidget.py | 83 +++++++++++++++++++ dataforest/core/ProcessRun.py | 8 +- dataforest/core/ProcessTreeRun.py | 38 +++++++++ dataforest/core/RunGroupSpec.py | 3 +- dataforest/core/TreeSpec.py | 52 +++--------- .../processes/core/TreeProcessMethods.py | 8 +- dataforest/structures/cache/BranchCache.py | 11 ++- dataforest/structures/cache/PlotCache.py | 4 +- dataforest/utils/plots_config.py | 8 +- requirements.txt | 3 +- tests/test_tree.py | 4 - 18 files changed, 205 insertions(+), 84 deletions(-) create mode 100644 dataforest/core/PlotWidget.py create mode 100644 dataforest/core/ProcessTreeRun.py diff --git a/dataforest/config/MetaPlotMethods.py b/dataforest/config/MetaPlotMethods.py index ea6174a..a2e3291 100644 --- a/dataforest/config/MetaPlotMethods.py +++ b/dataforest/config/MetaPlotMethods.py @@ -26,7 
+26,7 @@ def PLOT_KWARGS_DEFAULTS(cls): return cls.CONFIG["plot_kwargs_defaults"] @property - def PLOT_KWARGS(cls): # TODO-QC: mapping of process, plot to kwargs + def PLOT_KWARGS(cls): # TODO-QC: mapping of process, plot to plot_kwargs plot_kwargs = parse_plot_kwargs(config=cls.CONFIG) return plot_kwargs diff --git a/dataforest/config/MetaProcessSchema.py b/dataforest/config/MetaProcessSchema.py index 594811d..d86f68b 100644 --- a/dataforest/config/MetaProcessSchema.py +++ b/dataforest/config/MetaProcessSchema.py @@ -20,7 +20,7 @@ def FILE_MAP(cls): return cls["file_map"] @property - def PLOT_MAP(cls): # TODO-QC: process plot map starting here? Make it into a class where you can fetch kwargs? + def PLOT_MAP(cls): # TODO-QC: process plot map starting here? Make it into a class where you can fetch plot_kwargs? return parse_plot_map(cls.CONFIG) @property diff --git a/dataforest/core/BranchSpec.py b/dataforest/core/BranchSpec.py index 156b5dc..e0173a5 100644 --- a/dataforest/core/BranchSpec.py +++ b/dataforest/core/BranchSpec.py @@ -51,6 +51,8 @@ class BranchSpec(list): process_order: """ + _RUN_SPEC_CLASS = RunSpec + def __init__(self, spec: Union[str, List[dict], "BranchSpec[RunSpec]"]): if isinstance(spec, str): spec = json.loads(spec) @@ -176,7 +178,7 @@ def _get_data_operation_list(self, process_name: str, operation_name: str) -> Li def _build_run_spec_lookup(self) -> Dict[str, "RunSpec"]: """See class definition""" - run_spec_lookup = {"root": RunSpec({})} + run_spec_lookup = {"root": self._RUN_SPEC_CLASS({})} for run_spec in self: process_name = run_spec.name if process_name in run_spec_lookup: @@ -210,9 +212,6 @@ def __getitem__(self, key: Union[str, int, slice]) -> Union["RunSpec", "BranchSp if key.start or key.step: raise ValueError(f"Can only use stop with string slice (ex. 
[:'process_name'])") precursors_lookup = self.get_precursors_lookup(incl_current=True) - import ipdb - - ipdb.set_trace() precursors = precursors_lookup[key.stop] return self.__class__([self._run_spec_lookup[process] for process in precursors]) else: diff --git a/dataforest/core/DataBase.py b/dataforest/core/DataBase.py index ea241c5..e5bda84 100644 --- a/dataforest/core/DataBase.py +++ b/dataforest/core/DataBase.py @@ -1,5 +1,6 @@ +from abc import ABC from pathlib import Path -from typing import Union, Optional, List, Dict +from typing import Union, Optional, List, AnyStr from dataforest.core.PlotMethods import PlotMethods @@ -10,13 +11,18 @@ class DataBase: """ def __init__(self): + self.root = None self.plot = PlotMethods(self) + @property + def root_built(self): + return (Path(self.root) / "meta.tsv").exists() + @staticmethod def _combine_datasets( - root: Union[str, Path], - metadata: Optional[Union[str, Path]] = None, - input_paths: Optional[List[Union[str, Path]]] = None, + root: AnyStr, + metadata: Optional[AnyStr] = None, + input_paths: Optional[List[AnyStr]] = None, mode: Optional[str] = None, ): raise NotImplementedError("Must be implemented by subclass") diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index 49b6839..9fce56f 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -62,7 +62,7 @@ class DataBranch(DataBase): {reader, writer} from `{READER, WRITER}_METHODS` which was selected by `_map_default_{reader, writer}_methods` is not appropriate. Keys are `file_alias`es, and values are methods which take `filename`s - and `kwargs`. + and `plot_kwargs`. 
{READER, WRITER}_KWARGS_MAP: optional overloads for any file for which the default keyword arguments to the {reader, writer} should be @@ -345,8 +345,7 @@ def get_temp_meta_path(self: "DataBranch", process_name: str): def __getitem__(self, process_name: str) -> ProcessRun: if process_name not in self._process_runs: - if process_name in ("root", None): - process_name = "root" + process_name = "root" if process_name is None else process_name process = self.spec[process_name].process if process_name != "root" else "root" self._process_runs[process_name] = ProcessRun(self, process_name, process) return self._process_runs[process_name] @@ -389,9 +388,11 @@ def _apply_data_ops(self, process_name: str, df: Optional[pd.DataFrame] = None): df = self.meta.copy() for (subset, filter_) in zip(subset_list, filter_list): for column, val in subset.items(): - df = self._do_subset(df, column, val) + if val is not None: + df = self._do_subset(df, column, val) for column, val in filter_.items(): - df = self._do_filter(df, column, val) + if val is not None: + df = self._do_filter(df, column, val) df.replace(" ", "_", regex=True, inplace=True) partitions_list = self.spec.get_partition_list(process_name) partitions = set().union(*partitions_list) @@ -428,7 +429,7 @@ def _do_filter(df: pd.DataFrame, column: str, val: Any) -> pd.DataFrame: def _map_file_io(self) -> Tuple[Dict[str, IOCache], Dict[str, IOCache]]: """ Builds a lookup of lazy loading caches for file readers and writers, - which have implicit access to paths, methods, and kwargs for each file. + which have implicit access to paths, methods, and plot_kwargs for each file. Returns: {reader, writer}_map: Key: file_alias (e.g. 
"rna") diff --git a/dataforest/core/DataTree.py b/dataforest/core/DataTree.py index 0245947..5b82988 100644 --- a/dataforest/core/DataTree.py +++ b/dataforest/core/DataTree.py @@ -4,6 +4,8 @@ from dataforest.core.DataBase import DataBase from dataforest.core.DataBranch import DataBranch +from dataforest.core.ProcessTreeRun import ProcessTreeRun +from dataforest.core.RunGroupSpec import RunGroupSpec from dataforest.core.TreeSpec import TreeSpec from dataforest.processes.core.TreeProcessMethods import TreeProcessMethods from dataforest.structures.cache.BranchCache import BranchCache @@ -30,6 +32,7 @@ def __init__( self._current_process = None self.remote_root = remote_root self._branch_cache = BranchCache(root, self.tree_spec.branch_specs, self._BRANCH_CLASS, verbose, remote_root,) + self._process_tree_runs = dict() self.process = TreeProcessMethods(self.tree_spec, self._branch_cache) @property @@ -41,7 +44,6 @@ def current_process(self): return self._current_process if self._current_process else "root" def goto_process(self, process_name: str): - self._LOG.info(f"loading all branches to `goto_process`") self._branch_cache.load_all() for branch in self._branch_cache.values(): branch.goto_process(process_name) @@ -57,6 +59,9 @@ def update_spec(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]): self._branch_cache.update_branch_specs(self.tree_spec.branch_specs) self.process = TreeProcessMethods(self.tree_spec, self._branch_cache) + def load_all(self): + self._branch_cache.load_all() + def run_all(self, workers: int = 1, batch_queue: Optional[str] = None): return [method() for method in self.process.process_methods] @@ -65,6 +70,13 @@ def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): rand_branch = self._branch_cache[str(rand_spec)] rand_branch.create_root_plots(plot_kwargs) + def __getitem__(self, process_name: str) -> ProcessTreeRun: + if process_name not in self._process_tree_runs: + process_name = "root" if process_name is None 
else process_name + process = self.tree_spec[process_name].process if process_name != "root" else "root" + self._process_tree_runs[process_name] = ProcessTreeRun(self, process_name, process) + return self._process_tree_runs[process_name] + @staticmethod def _init_spec(tree_spec: Union[list, TreeSpec]) -> TreeSpec: if tree_spec is None: diff --git a/dataforest/core/Interface.py b/dataforest/core/Interface.py index e725e18..570bc88 100644 --- a/dataforest/core/Interface.py +++ b/dataforest/core/Interface.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union, Optional, Iterable, List +from typing import Union, Optional, Iterable, List, AnyStr import pandas as pd @@ -46,19 +46,24 @@ def load( } kwargs = cls._prune_kwargs(kwargs) interface_cls = cls._get_interface_class(kwargs) - return interface_cls(root, **kwargs) + inst = interface_cls(root, **kwargs) + if not inst.root_built: + raise ValueError( + "Root must be built once before it can be loaded. Use `from_input_dirs` or `from_sample_metadata`" + ) + return inst @classmethod def from_input_dirs( cls, - root: Union[str, Path], - input_paths: Optional[Union[str, Path, Iterable[Union[str, Path]]]] = None, + root: AnyStr, + input_paths: Optional[Union[AnyStr, Iterable[AnyStr]]] = None, mode: Optional[str] = None, branch_spec: Optional[List[dict]] = None, tree_spec: Optional[List[dict]] = None, verbose: bool = False, current_process: Optional[str] = None, - remote_root: Optional[Union[str, Path]] = None, + remote_root: Optional[AnyStr] = None, root_plots: bool = True, plot_kwargs: Optional[dict] = None, overwrite_plots: Optional[Iterable[str]] = None, diff --git a/dataforest/core/PlotWidget.py b/dataforest/core/PlotWidget.py new file mode 100644 index 0000000..ade4607 --- /dev/null +++ b/dataforest/core/PlotWidget.py @@ -0,0 +1,83 @@ +from copy import deepcopy +from typing import TYPE_CHECKING, Dict, Any + +from IPython.display import Image +import ipywidgets as widgets +from matplotlib.figure import 
Figure + +from dataforest.structures.cache.PlotCache import PlotCache + +if TYPE_CHECKING: + from dataforest.core.DataBranch import DataBranch + + +class PlotWidget: + def __init__(self, tree, plot_key, use_saved=True, **plot_kwargs): + self._tree = tree + self._branch_spec = deepcopy(tree.tree_spec.branch_specs[0]) + self._plot_key = plot_key + self._use_saved = use_saved + self._plot_kwargs = plot_kwargs + self._sweeps = self._tree.tree_spec.sweep_dict + self._plot_cache = dict() + # self.ax = plt.gca() + + def control(self): + process = self._tree.current_process + precursors = self._tree.tree_spec.get_precursors_lookup(incl_current=True)[process] + sweeps = set.union(*[self._sweeps[precursor] for precursor in precursors]) + param_sweeps = {":".join(swp[:3][::-1]): list(swp[3]) for swp in sweeps} + _kwargs = {**param_sweeps, **self._plot_kwargs} + + @widgets.interact(**_kwargs) + def _control(**kwargs: Dict[str, Any]): + for param_keys_str, value in kwargs.items(): + if param_keys_str in self._plot_kwargs: + self._plot_kwargs[param_keys_str] = value + continue + elif isinstance(value, float): + value = int(value) if int(value) == value else value + (name, operation, process) = param_keys_str.split(":") + self._branch_spec[process][operation][name] = value + # TODO: do we need to do something with kwargs? Or taken care of? 
+ [kwargs.pop(k) for k in self._plot_kwargs] + branch = self._tree._branch_cache[str(self)] + return self._get_plot(branch) + + return _control + + def _get_plot(self, branch): + plot_map = branch[self._tree.current_process].plot_map + plot_path_lookup = {plot_key: next(iter(path_dict.values())) for plot_key, path_dict in plot_map.items()} + spec = branch.spec[: branch.current_process] + cache_key = str({"spec": spec, "plot_kwargs": self._plot_kwargs}) + if self._plot_key in plot_path_lookup and self._use_saved: + plot_path = plot_path_lookup.get(self._plot_key) + if plot_path.exists(): + return Image(plot_path) + generated = False + if not (plot_obj := self._plot_cache.get(cache_key, None)): + generated = True + plot_obj = self._generate_plot(branch) + self._plot_cache[cache_key] = plot_obj + if isinstance(plot_obj, tuple) and isinstance(plot_obj[0], Figure): + if not generated: + return plot_obj[0] + elif isinstance(plot_obj, Image): + return plot_obj + else: + raise TypeError(f"Expected types (matplotlib.axes.Axes, IPython.display.Image). 
Got {type(plot_obj)}") + + def _generate_plot(self, branch: "DataBranch", **kwargs): + method = branch.plot.method_lookup[self._plot_key] + # fig, ax = method(**self._plot_kwargs) + kwargs = {**self._plot_kwargs, **kwargs} + # TODO: might be able to use ax if integrate ax.figure + return method(**kwargs) # , ax=self.ax + # return fig, ax + + def __str__(self): + return str(self._branch_spec) + + def __repr__(self): + return repr(self._branch_spec) diff --git a/dataforest/core/ProcessRun.py b/dataforest/core/ProcessRun.py index dfb1c3a..c1936a5 100644 --- a/dataforest/core/ProcessRun.py +++ b/dataforest/core/ProcessRun.py @@ -1,9 +1,9 @@ from collections import ChainMap import logging +from pathlib import Path from typing import Dict, List, TYPE_CHECKING import pandas as pd -from pathlib import Path from termcolor import cprint @@ -18,7 +18,7 @@ class ProcessRun: """ def __init__(self, branch: "DataBranch", process_name: str, process: str): - self.logger = logging.getLogger(f"ProcessRun - {process_name}") + self._LOG = logging.getLogger(f"ProcessRun - {process_name}") if process_name not in branch.spec and process_name != "root": raise ValueError(f'key "{process_name}" not in branch_spec: {branch.spec}') self.process_name = process_name @@ -136,6 +136,10 @@ def file_map_done(self) -> Dict[str, str]: @property def done(self) -> bool: + """ + Whether or not process has been executed to completion, regardless of + success + """ if self.path.exists() and not (self.path / "INCOMPLETE").exists(): if len(self.files) > 0 or self.logs_path.exists(): return True diff --git a/dataforest/core/ProcessTreeRun.py b/dataforest/core/ProcessTreeRun.py new file mode 100644 index 0000000..f1f4897 --- /dev/null +++ b/dataforest/core/ProcessTreeRun.py @@ -0,0 +1,38 @@ +import logging +from typing import TYPE_CHECKING, Callable, Set, Tuple + +from dataforest.core.DataBranch import DataBranch + +if TYPE_CHECKING: + from dataforest.core.DataTree import DataTree + + +class ProcessTreeRun: + 
def __init__(self, tree: "DataTree", process_name: str, process: str): + self._LOG = logging.getLogger(f"ProcessRun - {process_name}") + if process_name not in tree.tree_spec and process_name != "root": + raise ValueError(f'key "{process_name}" not in tree_spec: {tree.tree_spec}') + self.process_name = process_name + self.process = process + self._tree = tree + + @property + def done(self): + self._tree.load_all() + return all(map(lambda branch: branch[self.process_name].done, self._tree._branch_cache.values())) + + @property + def failed(self): + filter_ = lambda branch, process: branch[process].done and not branch[process].success + return self._filter_branches(filter_) + + @property + def success(self): + filter_ = lambda branch, process: branch[process].done and branch[process].success + return self._filter_branches(filter_) + + def _filter_branches(self, filter_: Callable) -> tuple: + self._tree.load_all() + branches = self._tree._branch_cache.values() + process = self.process_name + return tuple([branch for branch in branches if filter_(branch, process)]) diff --git a/dataforest/core/RunGroupSpec.py b/dataforest/core/RunGroupSpec.py index b643e4f..1b9b668 100644 --- a/dataforest/core/RunGroupSpec.py +++ b/dataforest/core/RunGroupSpec.py @@ -58,7 +58,8 @@ def _map_sweeps(self): if isinstance(operation_dict, dict): for key, val in operation_dict.items(): if isinstance(val, dict) and "_SWEEP_" in val: - self.sweeps.add((self["_PROCESS_"], operation, key)) + sweep_info = (self["_PROCESS_"], operation, key, tuple(val["_SWEEP_"])) + self.sweeps.add(sweep_info) sweep_obj = val["_SWEEP_"] self[operation][key] = Sweep(operation, key, sweep_obj) diff --git a/dataforest/core/TreeSpec.py b/dataforest/core/TreeSpec.py index bb93a39..2318343 100644 --- a/dataforest/core/TreeSpec.py +++ b/dataforest/core/TreeSpec.py @@ -1,5 +1,5 @@ from itertools import product -from typing import Union, List +from typing import Union, List, Dict, Any from dataforest.core.BranchSpec import 
BranchSpec from dataforest.core.RunGroupSpec import RunGroupSpec @@ -45,12 +45,23 @@ class TreeSpec(BranchSpec): >>> tree_spec = TreeSpec(tree_spec) """ + _RUN_SPEC_CLASS = RunGroupSpec + def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]): super(list, self).__init__() self.extend([RunGroupSpec(item) for item in tree_spec]) self.branch_specs = self._build_branch_specs() self.sweep_dict = {x["_PROCESS_"]: x.sweeps for x in self} + self.sweep_dict["root"] = set() + self._run_spec_lookup: Dict[str, "RunGroupSpec"] = self._build_run_spec_lookup() self._raw = tree_spec + self._run_spec_lookup: Dict[str, "RunSpec"] = self._build_run_spec_lookup() + self._precursors_lookup: Dict[str, List[str]] = self._build_precursors_lookup() + self._precursors_lookup_incl_curr: Dict[str, List[str]] = self._build_precursors_lookup(incl_current=True) + self._precursors_lookup_incl_root: Dict[str, List[str]] = self._build_precursors_lookup(incl_root=True) + self._precursors_lookup_incl_root_curr: Dict[str, List[str]] = self._build_precursors_lookup( + incl_root=True, incl_current=True + ) def _build_branch_specs(self): return list(map(BranchSpec, product(*[run_group_spec.run_specs for run_group_spec in self]))) @@ -60,42 +71,3 @@ def __setitem__(self, key, value): idx_lookup = {run_group_spec.name: i for i, run_group_spec in enumerate(self)} key = idx_lookup[key] list.__setitem__(self, key, value) - - -class PlotSpec: - def __init__(self, tree, plot_key, use_saved=True, **kwargs): - self._tree = tree - self._branch_spec = tree.tree_spec.branch_specs[0] - self._plot_key = plot_key - self._use_saved = use_saved - self._kwargs = kwargs - self._sweeps_remaining = len(self._tree.tree_spec.sweep_dict[self._tree.current_process]) - - def update(self, process, param, value): - self._branch_spec[process]["_PARAMS_"][param] = value - - def get_updater(self, process, param): - def updater(value): - self._branch_spec[process]["_PARAMS_"][param] = value - branch = 
self._tree._branch_cache[str(self)] - plot_map = branch[process].plot_map - plot_path_lookup = {plot_key: next(iter(path_dict.values())) for plot_key, path_dict in plot_map.items()} - if self._plot_key in plot_path_lookup and self._use_saved: - plot_path = plot_path_lookup.get(self._plot_key) - if plot_path.exists(): - return Image(plot_path) - self._generate_plot(branch, process) - - return updater - - def _generate_plot(self, branch: "DataBranch", process: str): - method = branch.plot.method_lookup[self._plot_key] - # fig, ax = method(**self._kwargs) - method(**self._kwargs) - # return fig, ax - - def __str__(self): - return str(self._branch_spec) - - def __repr__(self): - return repr(self._branch_spec) diff --git a/dataforest/processes/core/TreeProcessMethods.py b/dataforest/processes/core/TreeProcessMethods.py index bdc8c8d..6333cc6 100644 --- a/dataforest/processes/core/TreeProcessMethods.py +++ b/dataforest/processes/core/TreeProcessMethods.py @@ -10,7 +10,6 @@ class TreeProcessMethods: _N_JOBS = cpu_count() - 1 - _LOG = logging.getLogger("TreeProcessMethods") def __init__(self, tree_spec: TreeSpec, branch_cache: BranchCache): self._tree_spec = tree_spec @@ -86,18 +85,17 @@ def _single_kernel(branch): def _distributed_kernel_serial(): _ret_vals = [] - for branch in unique_branches: + for branch in unique_branches.values(): _ret_vals.append(_single_kernel(branch)) return _ret_vals def _distributed_kernel_parallel(): process = delayed(_single_kernel) pool = Parallel(n_jobs=self._N_JOBS) - return pool(process(branch) for branch in unique_branches) + return pool(process(branch) for branch in unique_branches.values()) exec_scheme = "PARALLEL" if parallel else "SERIAL" - print(exec_scheme) - self._LOG.info( + logging.info( f"{exec_scheme} execution of {method_name} over {self._N_JOBS} workers on {len(unique_branches)} " f"unique conditions" ) diff --git a/dataforest/structures/cache/BranchCache.py b/dataforest/structures/cache/BranchCache.py index 21d5085..7005710 
100644 --- a/dataforest/structures/cache/BranchCache.py +++ b/dataforest/structures/cache/BranchCache.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import List, Union @@ -7,6 +8,8 @@ class BranchCache(HashCash): + _LOG = logging.getLogger("BranchCache") + def __init__(self, root: Union[str, Path], branch_specs: List[BranchSpec], branch_class: type, *args): super().__init__() self._root = root @@ -16,9 +19,11 @@ def __init__(self, root: Union[str, Path], branch_specs: List[BranchSpec], branc self.fully_loaded = False def load_all(self): - for spec_str in self._branch_spec_lookup: - _ = self[spec_str] # force load of all items - self.fully_loaded = True + self._LOG.info(f"loading all branches to `goto_process`") + if not self.fully_loaded: + for spec_str in self._branch_spec_lookup: + _ = self[spec_str] # force load of all items + self.fully_loaded = True def update_branch_specs(self, branch_specs: List[BranchSpec]): self._branch_spec_lookup = {str(spec): spec for spec in branch_specs} diff --git a/dataforest/structures/cache/PlotCache.py b/dataforest/structures/cache/PlotCache.py index 20f10eb..9bf34fa 100644 --- a/dataforest/structures/cache/PlotCache.py +++ b/dataforest/structures/cache/PlotCache.py @@ -1,3 +1,3 @@ class PlotCache: - def __init__(self, branch, process): - self._branch = branch + def __init__(self, branch, plot_key): + self._plot_key = plot_key diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py index 11ae802..8092853 100644 --- a/dataforest/utils/plots_config.py +++ b/dataforest/utils/plots_config.py @@ -28,7 +28,7 @@ def build_process_plot_method_lookup(config: dict) -> Dict[str, Dict[str, str]]: def parse_plot_kwargs(config: dict): - """Parse plot methods kwargs per process from plot_map""" + """Parse plot methods plot_kwargs per process from plot_map""" plot_map = config["plot_map"] plot_kwargs_defaults = config["plot_kwargs_defaults"] all_plot_kwargs = {} @@ -57,7 +57,7 @@ def 
parse_plot_kwargs(config: dict): def parse_plot_map(config: dict): """ Parse plot file map per process from plot_map and ensures that - implicit definition returns a dictionary of default values for all kwargs + implicit definition returns a dictionary of default values for all plot_kwargs """ plot_map = config["plot_map"] plot_kwargs_defaults = config["plot_kwargs_defaults"] @@ -155,7 +155,7 @@ def _unify_kwargs_opt_lens(plot_kwargs: dict, plot_kwargs_defaults: dict, plot_n def _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults): - """Map plot_kwargs to values defined in kwargs defaults if available""" + """Map plot_kwargs to values defined in plot_kwargs defaults if available""" mapped_plot_kwargs = deepcopy(plot_kwargs) for kwarg_name, kwarg_values in plot_kwargs.items(): @@ -178,7 +178,7 @@ def _get_plot_kwargs_feed(plot_kwargs: dict, plot_kwargs_defaults: dict, plot_na plot_kwargs = _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults) plot_kwargs_feed = [ dict(j) for j in zip(*[[(k, i) for i in v] for k, v in plot_kwargs.items()]) - ] # 1-1 mapping of kwargs options + ] # 1-1 mapping of plot_kwargs options return plot_kwargs_feed diff --git a/requirements.txt b/requirements.txt index 88ef6c9..445851b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -jobplib +ipywidgets +joblib matplotlib numpy pandas diff --git a/tests/test_tree.py b/tests/test_tree.py index 5013b0d..9250581 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -17,7 +17,6 @@ def test_branch_1(): t[t.stack] = "c" t.down("c") t[t.stack] = "d" - print(t.dict) assert t.dict == {"a": {"b": "c", "c": "d"}} return t @@ -32,7 +31,6 @@ def test_branch_2(): t.down("c") t[t.stack] = "d" t[t.stack] = "e" - print(t.dict) assert t.dict == {"a": {"b": None, "c": {"d", "e"}}} return t @@ -64,10 +62,8 @@ def test_node_awareness_2(test_branch_2): def test_apply_leaves_1(test_branch_1): t = test_branch_1 t_2 = t.apply_leaves(lambda x: 1) - print(t_2.dict) def 
test_apply_leaves_2(test_branch_2): t = test_branch_2 t_2 = t.apply_leaves(lambda x: 1) - print(t_2.dict) From 2261270dd1db55b6cf8d5c6c9a03f8b3bc6e69c4 Mon Sep 17 00:00:00 2001 From: theaustinator Date: Mon, 12 Oct 2020 23:34:53 -0600 Subject: [PATCH 15/19] migrated plot wrappers from cellforest and added plot_sources to config; migrated groupby from cellforest; refactored subset and filter to allow `_MULTI_`; PlotMethods.{display, keys, methods}; PlotPreparator for stratify and faceting; improved config collectors; tether bug --- dataforest/__init__.py | 4 +- dataforest/config/MetaPlotMethods.py | 13 +- dataforest/config/MetaProcessSchema.py | 4 +- dataforest/core/BranchSpec.py | 3 +- dataforest/core/DataBranch.py | 92 ++++++++---- dataforest/core/PlotMethods.py | 67 +++++++-- dataforest/core/PlotWidget.py | 1 + dataforest/core/TreeSpec.py | 5 +- dataforest/hooks/dataprocess/dataprocess.py | 2 +- dataforest/hyperparams/Sweep.py | 2 +- dataforest/plot/PlotPreparator.py | 114 ++++++++++++++ dataforest/plot/__init__.py | 1 + dataforest/plot/wrappers.py | 142 ++++++++++++++++++ .../processes/core/TreeProcessMethods.py | 5 +- dataforest/utils/exceptions.py | 7 +- dataforest/utils/loaders/collectors.py | 7 +- dataforest/utils/loaders/config.py | 57 ++++--- dataforest/utils/loaders/path.py | 14 ++ dataforest/utils/loaders/update_config.py | 11 +- dataforest/utils/plots_config.py | 4 +- dataforest/utils/tether.py | 8 +- 21 files changed, 474 insertions(+), 89 deletions(-) create mode 100644 dataforest/plot/PlotPreparator.py create mode 100644 dataforest/plot/__init__.py create mode 100644 dataforest/plot/wrappers.py create mode 100644 dataforest/utils/loaders/path.py diff --git a/dataforest/__init__.py b/dataforest/__init__.py index a21a918..2273ab6 100644 --- a/dataforest/__init__.py +++ b/dataforest/__init__.py @@ -1,6 +1,8 @@ -from dataforest.utils.loaders.update_config import get_current_config, update_config +from dataforest.utils.loaders.config import 
CONFIG_OPTIONS, load_config as _load_config +from dataforest.utils.loaders.update_config import get_current_config, get_config_updater as _get_config_updater from dataforest.core.Interface import Interface +update_config = _get_config_updater(_load_config) load = Interface.load from_input_dirs = Interface.from_input_dirs diff --git a/dataforest/config/MetaPlotMethods.py b/dataforest/config/MetaPlotMethods.py index a2e3291..8a7164e 100644 --- a/dataforest/config/MetaPlotMethods.py +++ b/dataforest/config/MetaPlotMethods.py @@ -1,5 +1,9 @@ +from pathlib import Path +from typing import List + from dataforest.config.MetaConfig import MetaConfig from dataforest.utils.loaders.collectors import collect_plots +from dataforest.utils.loaders.path import get_module_paths from dataforest.utils.plots_config import build_process_plot_method_lookup, parse_plot_kwargs @@ -28,5 +32,12 @@ def PLOT_KWARGS_DEFAULTS(cls): @property def PLOT_KWARGS(cls): # TODO-QC: mapping of process, plot to plot_kwargs plot_kwargs = parse_plot_kwargs(config=cls.CONFIG) - return plot_kwargs + + # @property + # def R_FUNCTIONS_FILEPATH(cls) -> Path: + # return get_module_paths([cls.CONFIG["r_functions_sources"]])[0] + + @property + def R_PLOT_SOURCES(cls) -> List[Path]: + return get_module_paths(cls.CONFIG["r_plot_sources"]) diff --git a/dataforest/config/MetaProcessSchema.py b/dataforest/config/MetaProcessSchema.py index d86f68b..f1b3e68 100644 --- a/dataforest/config/MetaProcessSchema.py +++ b/dataforest/config/MetaProcessSchema.py @@ -21,7 +21,9 @@ def FILE_MAP(cls): @property def PLOT_MAP(cls): # TODO-QC: process plot map starting here? Make it into a class where you can fetch plot_kwargs? 
- return parse_plot_map(cls.CONFIG) + plot_map = cls.CONFIG.get("plot_map", dict()) + plot_kwargs_defaults = cls.CONFIG.get("plot_kwargs_defaults", dict()) + return parse_plot_map(plot_map, plot_kwargs_defaults) @property def LAYERS(cls): diff --git a/dataforest/core/BranchSpec.py b/dataforest/core/BranchSpec.py index e0173a5..8c14e52 100644 --- a/dataforest/core/BranchSpec.py +++ b/dataforest/core/BranchSpec.py @@ -4,6 +4,7 @@ from typeguard import typechecked +from dataforest.core.RunGroupSpec import RunGroupSpec from dataforest.core.RunSpec import RunSpec from dataforest.utils.exceptions import DuplicateProcessName @@ -176,7 +177,7 @@ def _get_data_operation_list(self, process_name: str, operation_name: str) -> Li operation_list.append(operation) return operation_list - def _build_run_spec_lookup(self) -> Dict[str, "RunSpec"]: + def _build_run_spec_lookup(self) -> Dict[str, Union["RunSpec", "RunGroupSpec"]]: """See class definition""" run_spec_lookup = {"root": self._RUN_SPEC_CLASS({})} for run_spec in self: diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index 9fce56f..6b38e62 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -213,6 +213,29 @@ def goto_process(self, process_name: str) -> "DataBranch": self.clear_data(path_map_changes) return self + def groupby(self, by: Union[str, list, set, tuple], **kwargs) -> Tuple[str, "DataBranch"]: + """ + Operates like a pandas group_labels, but does not return a GroupBy object, + and yields (name, DataBranch), where each DataBranch is subset according to `by`, + which corresponds to columns of `self.meta`. + This is useful for batching analysis across various conditions, where + each run requires an DataBranch. 
+ Args: + by: variables over which to group (like pandas) + **kwargs: for pandas group_labels on `self.meta` + + Yields: + name: values for DataBranch `subset` according to keys specified in `by` + branch: new DataBranch which inherits `self.spec` with additional `subset`s + from `by` + """ + if isinstance(by, (tuple, set)): + by = list(by) + for (name, df) in self.meta.groupby(by, **kwargs): + branch = self.copy() + branch.set_meta(df) + yield name, branch + def fork(self, branch_spec: Union[list, BranchSpec]) -> "DataBranch": """ "Forks" the current branch to create a copy with an altered branch_spec, but @@ -264,8 +287,10 @@ def clear_data(self, attrs: Optional[Iterable[str]] = None, all_data: bool = Fal data_attr = f"_{attr_name}" setattr(self, data_attr, None) - def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): - if self.is_process_plots_exist("root"): + def create_root_plots( + self, plot_kwargs: Optional[Dict[str, dict]] = None, overwrite: bool = False, stop_on_error: bool = False + ): + if self.is_process_plots_exist("root") and not overwrite: self.logger.info( f"plots already present for `root` at {self['root'].plots_path}. 
To regenerate plots, delete directory" ) @@ -282,7 +307,7 @@ def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): method = getattr(self.plot, plot_method) kwargs = deepcopy(_kwargs) kwargs["plot_path"] = root_plot_map[plot_name][plot_kwargs_key] - method(**kwargs) + method(stop_on_error=stop_on_error, **kwargs) def is_process_plots_exist(self, process_name: str) -> bool: return self[process_name].plots_path.exists() @@ -388,11 +413,15 @@ def _apply_data_ops(self, process_name: str, df: Optional[pd.DataFrame] = None): df = self.meta.copy() for (subset, filter_) in zip(subset_list, filter_list): for column, val in subset.items(): - if val is not None: - df = self._do_subset(df, column, val) + if "_MULTI_" in column: + df = self._do_subset(df, val) + elif val is not None: + df = self._do_subset(df, {column: val}) for column, val in filter_.items(): - if val is not None: - df = self._do_filter(df, column, val) + if "_MULTI_" in column: + df = self._do_filter(df, val) + elif val is not None: + df = self._do_filter(df, {column: val}) df.replace(" ", "_", regex=True, inplace=True) partitions_list = self.spec.get_partition_list(process_name) partitions = set().union(*partitions_list) @@ -400,32 +429,37 @@ def _apply_data_ops(self, process_name: str, df: Optional[pd.DataFrame] = None): df = label_df_partitions(df, partitions, encodings=True) return df - @staticmethod - def _do_subset(df: pd.DataFrame, column: str, val: Any) -> pd.DataFrame: - prev_df = df.copy() - if isinstance(val, (list, set)): - df = df[df[column].isin(val)] - else: - df = df[df[column] == val] - if len(df) == len(prev_df): - logging.warning(f"Subset didn't change num of rows and may be unnecessary - column: {column}, val: {val}") - elif len(df) == 0: - raise BadSubset(column, val) + @classmethod + def _do_subset(cls, df: pd.DataFrame, subset_dict: Dict[str, Any]) -> pd.DataFrame: + df_selector = cls._get_df_selector(df, subset_dict) + df = df.loc[df_selector] + if len(df) == 0: + 
raise BadSubset(subset_dict) return df - @staticmethod - def _do_filter(df: pd.DataFrame, column: str, val: Any) -> pd.DataFrame: - prev_df = df.copy() - if isinstance(val, (list, set)): - df = df[~df[column].isin(val)] - else: - df = df[df[column] != val] - if len(df) == len(prev_df): - logging.warning(f"Filter didn't change num of rows and may be unnecessary - column: {column}, val: {val}") - elif len(df) == 0: - raise BadFilter(column, val) + @classmethod + def _do_filter(cls, df: pd.DataFrame, filter_dict: Dict[str, Any]) -> pd.DataFrame: + df_selector = cls._get_df_selector(df, filter_dict) + df = df.loc[~df_selector] + if len(df) == 0: + raise BadFilter(filter_dict) return df + @staticmethod + def _get_df_selector(df: pd.DataFrame, op_dict: Dict[str, Any]) -> pd.Series: + def _check_row_eq(row: pd.Series) -> bool: + for key, val in row.iteritems(): + if isinstance(val, set): + if not op_dict[key] in val: + return False + else: + if not op_dict[key] == val: + return False + return True + + df_selector = pd.DataFrame(pd.DataFrame(df[list(op_dict)]).apply(_check_row_eq, axis=1)).all(axis=1) + return df_selector + def _map_file_io(self) -> Tuple[Dict[str, IOCache], Dict[str, IOCache]]: """ Builds a lookup of lazy loading caches for file readers and writers, diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py index 5280f15..a7aaa28 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -1,15 +1,18 @@ from functools import wraps import logging from pathlib import Path -from typing import Optional, Dict, Iterable +from typing import Optional, Dict, Iterable, TYPE_CHECKING, Set +from IPython.display import Image, display from typeguard import typechecked from dataforest.config.MetaPlotMethods import MetaPlotMethods -from dataforest.structures.cache.PlotCache import PlotCache from dataforest.utils import tether, copy_func from dataforest.utils.ExceptionHandler import ExceptionHandler +if TYPE_CHECKING: + from 
dataforest.core.DataBranch import DataBranch + class PlotMethods(metaclass=MetaPlotMethods): """ @@ -25,29 +28,37 @@ def __init__(self, branch: "DataBranch"): callable_ = copy_func(plot_method) callable_.__name__ = name setattr(self, name, self._wrap(callable_)) - tether(self, "branch") + tether(self, "branch", incl_methods=list(self.plot_method_lookup.keys())) # self._plot_cache = {process: PlotCache(self.branch, process) for process in self.branch.spec.processes} self._img_cache = {} + def display(self, process_name: str): + plots_path = self.branch[process_name].plots_path + for img_path in plots_path.iterdir(): + display(Image(img_path)) + @typechecked - def regenerate_plots( + def generate_plots( self, - processes: Optional[Iterable[str]], - plot_map: Optional[Dict[str, dict]], - plot_kwargs: Optional[Dict[str, dict]], + processes: Optional[Iterable[str]] = None, + plot_map: Optional[Dict[str, dict]] = None, + plot_kwargs: Optional[Dict[str, dict]] = None, ): plot_map = self.plot_map if not plot_map else plot_map plot_kwargs = self.plot_kwargs if not plot_kwargs else plot_kwargs if processes is not None: - plot_map = {proc: proc_plot_map for proc, proc_plot_map in plot_map if proc in processes} - plot_kwargs = {proc: proc_plot_kwargs for proc, proc_plot_kwargs in plot_kwargs if proc in processes} + plot_map = {proc: proc_plot_map for proc, proc_plot_map in plot_map.items() if proc in processes} + plot_kwargs = { + proc: proc_plot_kwargs for proc, proc_plot_kwargs in plot_kwargs.items() if proc in processes + } for process, proc_plot_map in plot_map.items(): method_names_config = tuple(proc_plot_map.keys()) for name_config in method_names_config: method_name = self.method_name_lookup[name_config] - kwargs = plot_kwargs[name_config] method = getattr(self, method_name) - method(**kwargs) + kwarg_sets = plot_kwargs[process][name_config].values() + for kwargs in kwarg_sets: + method(**kwargs) @property def plot_method_lookup(self): @@ -87,6 +98,38 @@ def 
plot_map(self): def plot_kwargs(self): return self.__class__.PLOT_KWARGS + @property + def methods(self) -> Dict[str, Set[str]]: + """ + Key: process_name at which plot becomes unlocked + Value: plot method names + """ + is_plot_method = lambda s: s.startswith("plot") and callable(getattr(self, s)) + method_names = list(filter(is_plot_method, dir(self))) + avail_dict = {} + + def _assign(method_name): + method = getattr(self, method_name) + if hasattr(method, "_requires"): + required = getattr(method, "_requires") + else: + required = "root" + if required not in avail_dict: + avail_dict[required] = set() + avail_dict[required].add(method_name) + + list(map(_assign, method_names)) + return avail_dict + + @property + def keys(self) -> Dict[str, Set[str]]: + """ + Key: process name at which plot becomes unlocked + Value: keys for plots in config + """ + convert_to_key = lambda set_: set(map(lambda s: "_" + s.upper()[5:] + "_", set_)) + return {k: convert_to_key(v) for k, v in self.methods.items()} + def _wrap(self, method): """Wrap with mkdirs and logging""" @@ -106,7 +149,7 @@ def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): return method(branch, *args, **kwargs) except Exception as e: err_filename = method.__name__ - ExceptionHandler.handle(self.branch, e, err_filename, stop_on_error) + ExceptionHandler.handle(self.branch, e, f"PLOT_{err_filename}.err", stop_on_error) return wrapped diff --git a/dataforest/core/PlotWidget.py b/dataforest/core/PlotWidget.py index ade4607..668c475 100644 --- a/dataforest/core/PlotWidget.py +++ b/dataforest/core/PlotWidget.py @@ -26,6 +26,7 @@ def control(self): process = self._tree.current_process precursors = self._tree.tree_spec.get_precursors_lookup(incl_current=True)[process] sweeps = set.union(*[self._sweeps[precursor] for precursor in precursors]) + # {"key:operation:process": value, ...} (e.g. 
{"num_pcs:_PARAMS_:normalize": 30, ...}) param_sweeps = {":".join(swp[:3][::-1]): list(swp[3]) for swp in sweeps} _kwargs = {**param_sweeps, **self._plot_kwargs} diff --git a/dataforest/core/TreeSpec.py b/dataforest/core/TreeSpec.py index 2318343..0dd1172 100644 --- a/dataforest/core/TreeSpec.py +++ b/dataforest/core/TreeSpec.py @@ -1,5 +1,5 @@ from itertools import product -from typing import Union, List, Dict, Any +from typing import Union, List, Dict, TYPE_CHECKING from dataforest.core.BranchSpec import BranchSpec from dataforest.core.RunGroupSpec import RunGroupSpec @@ -53,9 +53,8 @@ def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]): self.branch_specs = self._build_branch_specs() self.sweep_dict = {x["_PROCESS_"]: x.sweeps for x in self} self.sweep_dict["root"] = set() - self._run_spec_lookup: Dict[str, "RunGroupSpec"] = self._build_run_spec_lookup() self._raw = tree_spec - self._run_spec_lookup: Dict[str, "RunSpec"] = self._build_run_spec_lookup() + self._run_spec_lookup: Dict[str, "RunGroupSpec"] = self._build_run_spec_lookup() self._precursors_lookup: Dict[str, List[str]] = self._build_precursors_lookup() self._precursors_lookup_incl_curr: Dict[str, List[str]] = self._build_precursors_lookup(incl_current=True) self._precursors_lookup_incl_root: Dict[str, List[str]] = self._build_precursors_lookup(incl_root=True) diff --git a/dataforest/hooks/dataprocess/dataprocess.py b/dataforest/hooks/dataprocess/dataprocess.py index 104e32f..a86773b 100644 --- a/dataforest/hooks/dataprocess/dataprocess.py +++ b/dataforest/hooks/dataprocess/dataprocess.py @@ -18,7 +18,7 @@ class dataprocess(metaclass=MetaDataProcess): ✓ checks to ensure that specified input data is present - may make new tables and update them in the future - kwargs can be used to pass custom attributes + plot_kwargs can be used to pass custom attributes """ def __init__( diff --git a/dataforest/hyperparams/Sweep.py b/dataforest/hyperparams/Sweep.py index 75aafdb..e587782 100644 
--- a/dataforest/hyperparams/Sweep.py +++ b/dataforest/hyperparams/Sweep.py @@ -88,7 +88,7 @@ def plot( plot_method_kwargs = dict() if figsize is None: figsize = tuple(self.DEFAULT_SUBPLOT_SIZE * np.array(self.shape)) - fig, ax = plt.subplots(*self.shape, sharex="col", sharey="row", figsize=figsize) + fig, ax = plt.subplots(*self.shape, sharex="col", sharey="row", figsize=figsize) for (i, j) in self.indices: branch = self.branch_matrix[i, j] if branch is None: diff --git a/dataforest/plot/PlotPreparator.py b/dataforest/plot/PlotPreparator.py new file mode 100644 index 0000000..1e3ded7 --- /dev/null +++ b/dataforest/plot/PlotPreparator.py @@ -0,0 +1,114 @@ +from math import ceil +from typing import Callable, Optional, Literal, Union, List, Any, Tuple, Iterable + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from dataforest.core.DataBranch import DataBranch + + +class PlotPreparator: + DEFAULT_PLOT_RESOLUTION_PX = (500, 500) # width, height in pixels + _DEFAULT_N_COLS = 3 + + def __init__(self, branch: "DataBranch"): + self.branch_df = pd.DataFrame({"branch": [branch]}) + self.ax_arr = None + self.fig = None + self.facet_cols = None + self.strat_cols = None + + @property + def facet_vals(self): + if not self.facet_cols: + return None + return sorted(self.branch_df["facet"].unique()) + + @property + def srat_vals(self): + if not self.strat_cols: + return None + return sorted(self.branch_df["stratify"].unique()) + + def prepare(self, plot_kwargs: dict): + xlim = plot_kwargs.pop("xlim", None) + ylim = plot_kwargs.pop("ylim", None) + xscale = plot_kwargs.pop("xscale", None) + yscale = plot_kwargs.pop("yscale", None) + + @np.vectorize + def _apply_ax_kwargs(ax_: plt.Axes): + if xlim: + ax_.set_xlim(xlim) + if ylim: + ax_.set_ylim(ylim) + if xscale: + ax_.set_xscale(xscale) + if yscale: + ax_.set_yscale(yscale) + + plot_size = plot_kwargs.pop("plot_size", self.DEFAULT_PLOT_RESOLUTION_PX) + figsize = plot_kwargs.pop("figsize", None) + ax_arr 
= plot_kwargs.pop("ax", None) + fig = plot_kwargs.pop("fig", plt.gcf()) + if self.ax_arr is not None: + ax_arr = self.ax_arr + elif ax_arr is None: + fig, ax_arr = plt.subplots(1, 1) + if not isinstance(ax_arr, np.ndarray): + ax_arr = np.array([[ax_arr]]) + elif ax_arr.ndim == 1: + ax_arr = np.expand_dims(ax_arr, axis=0) + _apply_ax_kwargs(ax_arr) + dpi = fig.get_dpi() + # scale to pixel resolution, irrespective of screen resolution + fig.set_size_inches(plot_size[0] / float(dpi), plot_size[1] / float(dpi)) + if figsize: + fig.set_size_inches(*figsize) + self.fig = fig + self.ax_arr = ax_arr + + def facet(self, cols: Union[str, List[str]], n_rows: Optional[int] = None, n_cols: Optional[int] = None): + self._branch_groupby(cols, "facet") + self.facet_cols = cols + labels = self.branch_df["facet"].unique() + dim = self._get_facet_dim(labels, n_rows, n_cols) + _, self.ax_arr = plt.subplots(*dim, sharex="col", sharey="row") + + def stratify(self, cols: Union[str, List[str]], plot_kwargs): + self._branch_groupby(cols, "stratify") + self.strat_cols = cols + + def _branch_groupby(self, cols: Union[str, Iterable[str]], key_colname: Literal["facet", "stratify"]): + df = self.branch_df + df["grp"] = df["branch"].apply(lambda branch: list(branch.groupby(cols))) + df = df.explode("grp").reset_index(drop=True) + df[[key_colname, "branch"]] = pd.DataFrame(df["grp"].tolist(), index=df.index) + self.branch_df = df + + def _get_facet_dim( + self, labels: List[Any], n_rows: Optional[int] = None, n_cols: Optional[int] = None + ) -> Tuple[int, int]: + """ + Get dimensions of subplot grid for facet + Args: + labels: unique labels in facet col + n_rows: + n_cols: + + Returns: (n_rows, n_cols) + """ + # TODO: if two cols and rows/cols within range, have option to use each as an axis + if not n_rows and not n_cols: + n_cols = min(len(labels), self._DEFAULT_N_COLS) + if n_rows and n_cols and n_rows != len(labels) / n_cols: + raise ValueError( + f"If both `n_rows` and `n_cols` are 
specified, their product " + f"must be appropriate for ({len(labels)})" + ) + if n_cols: + n_rows = ceil(len(labels) / n_cols) + else: + n_cols = ceil(len(labels) / n_rows) + return n_rows, n_cols diff --git a/dataforest/plot/__init__.py b/dataforest/plot/__init__.py new file mode 100644 index 0000000..753bf3f --- /dev/null +++ b/dataforest/plot/__init__.py @@ -0,0 +1 @@ +from .wrappers import plot_py, plot_r, requires diff --git a/dataforest/plot/wrappers.py b/dataforest/plot/wrappers.py new file mode 100644 index 0000000..0b0cac1 --- /dev/null +++ b/dataforest/plot/wrappers.py @@ -0,0 +1,142 @@ +from functools import wraps +import json +import logging +from itertools import product +from typing import Tuple, Optional, AnyStr, Union + +from IPython.display import Image, display +import matplotlib +import matplotlib.pyplot as plt +import numpy as np + +from dataforest.core.DataBranch import DataBranch +from dataforest.core.PlotMethods import PlotMethods +from dataforest.plot.PlotPreparator import PlotPreparator + +_DEFAULT_BIG_PLOT_RESOLUTION_PX = (1000, 1000) # width, height in pixels +_PLOT_FILE_EXT = ".png" + + +_NONE_VARIANTS = [None, "none", "None", "NULL", "NA"] + + +def plot_py(plot_func): + @wraps(plot_func) + def wrapper( + branch: "DataBranch", + stratify: Optional[str] = None, + facet: Optional[str] = None, + plot_path: Optional[AnyStr] = None, + facet_dim: tuple = (), + **kwargs, + ) -> Union[plt.Figure, Tuple[plt.Figure, np.ndarray]]: + prep = PlotPreparator(branch) + if facet not in _NONE_VARIANTS: + kwargs["ax"] = prep.facet(facet, *facet_dim) + if plot_path is not None: + matplotlib.use("Agg") # don't plot on screen + if stratify not in _NONE_VARIANTS: + prep.stratify(stratify, kwargs) + prep.prepare(kwargs) + facet_inds = list(product(*map(range, prep.ax_arr.shape))) + for _, row in prep.branch_df.iterrows(): + ax = prep.ax_arr[0, 0] + if facet is not None: + i = prep.facet_vals.index(row["facet"]) + ax_i = facet_inds[i] + ax = prep.ax_arr[ax_i] + 
ax.set_title(row["facet"]) + if stratify is not None: + try: + kwargs["label"] = row["stratify"] + except: + import ipdb + + ipdb.set_trace() + plot_func(row["branch"], ax=ax, **kwargs) + if plot_path is not None: + logging.info(f"saving py figure to {plot_path}") + prep.fig.savefig(plot_path) + return prep.fig, prep.ax_arr + + return wrapper + + +def plot_r(plot_func): + @wraps(plot_func) + def wrapper( + branch: "DataBranch", + stratify: Optional[str] = None, + facet: Optional[str] = None, + plot_path: Optional[AnyStr] = None, + **kwargs, + ): + def _get_plot_script(): + for _plot_source in PlotMethods.R_PLOT_SOURCES: + _r_script = _plot_source / (plot_func.__name__ + ".R") + if _r_script.exists(): + return _plot_source, _r_script + + if facet is not None: + logging.warning("facet not yet implemented for R plots, but will create separate plots") + if stratify is not None: + logging.warning("stratify not yet supported for R plots") + plot_source, r_script = _get_plot_script() + plot_size = kwargs.pop("plot_size", PlotPreparator.DEFAULT_PLOT_RESOLUTION_PX) + if stratify not in _NONE_VARIANTS: + if stratify in branch.meta: # col exists in metadata + kwargs["group.by"] = stratify + else: + logging.warning(f"{plot_func.__name__} with key '{stratify}' is skipped because key is not in metadata") + return + subset_vals = [None] + if facet not in _NONE_VARIANTS: + subset_vals = sorted(branch.meta[facet].unique()) + img_arr = [] + + for val in subset_vals: + # corresponding arguments in r/plot_entry_point.R + args = [ + plot_source, # r_plot_scripts_path + branch.paths["root"], # root_dir + branch.spec.shell_str, # spec_str + facet, # subset_key + val, # subset_val + branch.current_process, # current_process + plot_path, # plot_file_path + plot_size[0], # plot_width_px + plot_size[1], # plot_height_px + json.dumps( + "kwargs = " + str(kwargs if kwargs else {}) + ), # TODO-QC: is there a better way to handle this? 
+ ] + logging.info(f"saved R figure to {plot_path}") + plot_func(branch, r_script, args) # plot_kwargs already included in args + img = Image("/tmp/plot.png") + display(img) + img_arr.append(img) + return img_arr + + return wrapper + + +# noinspection PyPep8Naming +class requires: + def __init__(self, req_process): + self._req_process = req_process + + def __call__(self, func): + func._requires = self._req_process + + @wraps(func) + def wrapper(branch: "DataBranch", *args, **kwargs): + precursors = branch.spec.get_precursors_lookup(incl_current=True)[branch.current_process] + if self._req_process not in precursors: + proc = self._req_process + raise ValueError( + f"This plot method requires a branch at `{proc}` or later. Current process run: {precursors}. If " + f"`{proc}` has already been run, please use `branch.goto_process`. Otherwise, please run `{proc}`." + ) + return func(branch, *args, **kwargs) + + return wrapper diff --git a/dataforest/processes/core/TreeProcessMethods.py b/dataforest/processes/core/TreeProcessMethods.py index 6333cc6..16bd5ba 100644 --- a/dataforest/processes/core/TreeProcessMethods.py +++ b/dataforest/processes/core/TreeProcessMethods.py @@ -1,5 +1,4 @@ import logging -from multiprocessing import cpu_count from typing import Callable, List, Union from joblib import Parallel, delayed @@ -9,7 +8,7 @@ class TreeProcessMethods: - _N_JOBS = cpu_count() - 1 + _N_CPUS_EXCLUDED = 1 def __init__(self, tree_spec: TreeSpec, branch_cache: BranchCache): self._tree_spec = tree_spec @@ -91,7 +90,7 @@ def _distributed_kernel_serial(): def _distributed_kernel_parallel(): process = delayed(_single_kernel) - pool = Parallel(n_jobs=self._N_JOBS) + pool = Parallel(n_jobs=-1 - self._N_CPUS_EXCLUDED) return pool(process(branch) for branch in unique_branches.values()) exec_scheme = "PARALLEL" if parallel else "SERIAL" diff --git a/dataforest/utils/exceptions.py b/dataforest/utils/exceptions.py index cf3980f..61091b9 100644 --- a/dataforest/utils/exceptions.py 
+++ b/dataforest/utils/exceptions.py @@ -47,13 +47,12 @@ def __str__(self): class BadOperation(KeyError): OPERATION = None - def __init__(self, column, val): - self._column = column - self._val = val + def __init__(self, dict_): + self._dict = dict_ def __str__(self): return ( - f"{self.OPERATION} resulted in no rows - column: {self._column} val: {self._val}. Please note that any " + f"{self.OPERATION} resulted in no rows - {self._dict}. Please note that any " f"spaces in `spec` should be converted to underscores!" ) diff --git a/dataforest/utils/loaders/collectors.py b/dataforest/utils/loaders/collectors.py index 29d1494..43184d1 100644 --- a/dataforest/utils/loaders/collectors.py +++ b/dataforest/utils/loaders/collectors.py @@ -3,6 +3,8 @@ from pathlib import Path from typing import Optional, Callable +_PLOT_EXCLUDE_NAMES = ["plot_py", "plot_r"] + def collect_hooks(path): return _collector(path, "hooks.py", _function_filter_hook) @@ -63,9 +65,8 @@ def _function_filter_hook(func): def _function_filter_plot(func): - if hasattr(func, "__name__"): - return func.__name__.startswith("plot_") - # otherwise is just a variable + name = getattr(func, "__name__", "") + return name.startswith("plot_") and name not in _PLOT_EXCLUDE_NAMES def _function_filter_process(func): diff --git a/dataforest/utils/loaders/config.py b/dataforest/utils/loaders/config.py index 40cbb13..408fb08 100644 --- a/dataforest/utils/loaders/config.py +++ b/dataforest/utils/loaders/config.py @@ -1,24 +1,45 @@ import json from pathlib import Path -from typing import Union +from typing import Union, AnyStr, Dict, Callable import yaml +from dataforest import config as _config_module -def load_config(config: Union[dict, str, Path]): - """Load library loaders""" - if isinstance(config, (str, Path)): - config = Path(config) - if config.suffix == ".json": - with open(str(config), "r") as f: - config = json.load(f) - elif config.suffix == ".yaml": - with open(str(config), "r") as f: - config = yaml.load(f, 
yaml.FullLoader) - else: - raise ValueError("If filepath is passed, must be .json or .yaml") - if not isinstance(config, dict): - raise TypeError( - f"Config must be either a `dict` or a path to a .json or .yaml file as either a `str` or `pathlib.Path`" - ) - return config + +def get_config_options(config_dir: AnyStr) -> Dict[str, Path]: + config_paths = Path(config_dir).iterdir() + config_lookup = {path.stem: path for path in config_paths if not path.name.startswith("__")} + return config_lookup + + +def get_config_loader(config_options: Dict[str, Path]) -> Callable: + def _load_config(config: Union[dict, str, Path]) -> dict: + """ + Global configuration for package function from dict, filepath, or name + in `config_options` + """ + if isinstance(config, (str, Path)): + if config in config_options: + config = config_options[config] + config = Path(config) + if config.suffix == ".json": + with open(str(config), "r") as f: + config = json.load(f) + elif config.suffix == ".yaml": + with open(str(config), "r") as f: + config = yaml.load(f, yaml.FullLoader) + else: + raise ValueError("If filepath is passed, must be .json or .yaml") + if not isinstance(config, dict): + raise TypeError( + f"Config must be either a `dict` or a path to a .json or .yaml file as either a `str` or `pathlib.Path`" + ) + return config + + return _load_config + + +_CONFIG_DIR = Path(_config_module.__file__).parent +CONFIG_OPTIONS = get_config_options(_CONFIG_DIR) +load_config = get_config_loader(CONFIG_OPTIONS) diff --git a/dataforest/utils/loaders/path.py b/dataforest/utils/loaders/path.py new file mode 100644 index 0000000..e3feae4 --- /dev/null +++ b/dataforest/utils/loaders/path.py @@ -0,0 +1,14 @@ +from importlib import import_module +from pathlib import Path +from typing import AnyStr, List + + +def get_module_paths(paths: List[AnyStr]) -> List[Path]: + def _get_module_path(path: AnyStr) -> Path: + module = import_module(str(path)) + path = Path(module.__file__) + if path.stem == 
"__init__": + path = path.parent + return path + + return list(map(_get_module_path, paths)) diff --git a/dataforest/utils/loaders/update_config.py b/dataforest/utils/loaders/update_config.py index afaa601..59245c5 100644 --- a/dataforest/utils/loaders/update_config.py +++ b/dataforest/utils/loaders/update_config.py @@ -1,13 +1,16 @@ from pathlib import Path -from typing import Union +from typing import Union, Callable from dataforest.config.MetaConfig import MetaConfig from dataforest.utils.loaders.config import load_config -def update_config(config: Union[dict, str, Path]): - config = load_config(config) - MetaConfig.CONFIG = config +def get_config_updater(config_loader: Callable) -> Callable: + def _update_config(config: Union[dict, str, Path]): + config = config_loader(config) + MetaConfig.CONFIG = config + + return _update_config def get_current_config() -> dict: diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py index 8092853..e1bdbe0 100644 --- a/dataforest/utils/plots_config.py +++ b/dataforest/utils/plots_config.py @@ -54,13 +54,11 @@ def parse_plot_kwargs(config: dict): return all_plot_kwargs -def parse_plot_map(config: dict): +def parse_plot_map(plot_map: dict, plot_kwargs_defaults: dict): """ Parse plot file map per process from plot_map and ensures that implicit definition returns a dictionary of default values for all plot_kwargs """ - plot_map = config["plot_map"] - plot_kwargs_defaults = config["plot_kwargs_defaults"] all_plot_maps = {} for process, plots in plot_map.items(): all_plot_maps[process] = {} diff --git a/dataforest/utils/tether.py b/dataforest/utils/tether.py index 39e9916..455e570 100644 --- a/dataforest/utils/tether.py +++ b/dataforest/utils/tether.py @@ -12,7 +12,7 @@ def tether(obj, tether_attr, incl_methods=None, excl_methods=None): if incl_methods and excl_methods: ValueError("Cannot specify both `incl_methods` and `excl_methods`") tether_arg = getattr(obj, tether_attr) - method_list = get_methods(obj, 
incl_methods, excl_methods) + method_list = get_methods(obj, excl_methods, incl_methods) for method_name in method_list: method = copy_func(getattr(obj, method_name)) tethered_method = make_tethered_method(method, tether_arg) @@ -20,19 +20,19 @@ def tether(obj, tether_attr, incl_methods=None, excl_methods=None): def get_methods( - obj: Any, excl_methods: Optional[Iterable[str]] = None, incl_methods: Optional[Iterable[str]] = None + obj: Any, excl_methods: Optional[Iterable[str]] = None, incl_methods: Optional[List[str]] = None ) -> List[str]: """Get all user defined methods of an object""" all_methods = filter(lambda x: not x.startswith("__"), dir(obj)) method_list = [] + if incl_methods: + return incl_methods for method_name in all_methods: if hasattr(obj, method_name): if callable(getattr(obj, method_name)): method_list.append(method_name) if excl_methods: method_list = [m for m in method_list if m not in excl_methods] - elif incl_methods: - method_list = incl_methods return method_list From dae6cd4e5ccf5d448aea34293fbc27bad7f1187a Mon Sep 17 00:00:00 2001 From: theaustinator Date: Tue, 13 Oct 2020 22:56:04 -0600 Subject: [PATCH 16/19] plot tweaks --- dataforest/plot/PlotPreparator.py | 1 + dataforest/plot/wrappers.py | 24 ++++++++++-------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/dataforest/plot/PlotPreparator.py b/dataforest/plot/PlotPreparator.py index 1e3ded7..2ae5420 100644 --- a/dataforest/plot/PlotPreparator.py +++ b/dataforest/plot/PlotPreparator.py @@ -10,6 +10,7 @@ class PlotPreparator: DEFAULT_PLOT_RESOLUTION_PX = (500, 500) # width, height in pixels + NONE_VARIANTS = [None, "none", "None", "NULL", "NA"] _DEFAULT_N_COLS = 3 def __init__(self, branch: "DataBranch"): diff --git a/dataforest/plot/wrappers.py b/dataforest/plot/wrappers.py index 0b0cac1..8c85faf 100644 --- a/dataforest/plot/wrappers.py +++ b/dataforest/plot/wrappers.py @@ -17,9 +17,6 @@ _PLOT_FILE_EXT = ".png" -_NONE_VARIANTS = [None, "none", "None", "NULL", 
"NA"] - - def plot_py(plot_func): @wraps(plot_func) def wrapper( @@ -31,12 +28,14 @@ def wrapper( **kwargs, ) -> Union[plt.Figure, Tuple[plt.Figure, np.ndarray]]: prep = PlotPreparator(branch) - if facet not in _NONE_VARIANTS: + stratify = None if stratify in prep.NONE_VARIANTS else stratify + facet = None if facet in prep.NONE_VARIANTS else facet + if facet is not None: kwargs["ax"] = prep.facet(facet, *facet_dim) + if stratify is not None: + prep.stratify(stratify, kwargs) if plot_path is not None: matplotlib.use("Agg") # don't plot on screen - if stratify not in _NONE_VARIANTS: - prep.stratify(stratify, kwargs) prep.prepare(kwargs) facet_inds = list(product(*map(range, prep.ax_arr.shape))) for _, row in prep.branch_df.iterrows(): @@ -47,12 +46,7 @@ def wrapper( ax = prep.ax_arr[ax_i] ax.set_title(row["facet"]) if stratify is not None: - try: - kwargs["label"] = row["stratify"] - except: - import ipdb - - ipdb.set_trace() + kwargs["label"] = row["stratify"] plot_func(row["branch"], ax=ax, **kwargs) if plot_path is not None: logging.info(f"saving py figure to {plot_path}") @@ -77,20 +71,22 @@ def _get_plot_script(): if _r_script.exists(): return _plot_source, _r_script + stratify = None if stratify in PlotPreparator.NONE_VARIANTS else stratify + facet = None if facet in PlotPreparator.NONE_VARIANTS else facet if facet is not None: logging.warning("facet not yet implemented for R plots, but will create separate plots") if stratify is not None: logging.warning("stratify not yet supported for R plots") plot_source, r_script = _get_plot_script() plot_size = kwargs.pop("plot_size", PlotPreparator.DEFAULT_PLOT_RESOLUTION_PX) - if stratify not in _NONE_VARIANTS: + if stratify is not None: if stratify in branch.meta: # col exists in metadata kwargs["group.by"] = stratify else: logging.warning(f"{plot_func.__name__} with key '{stratify}' is skipped because key is not in metadata") return subset_vals = [None] - if facet not in _NONE_VARIANTS: + if facet is not None: 
subset_vals = sorted(branch.meta[facet].unique()) img_arr = [] From 94477f19e6fa4a62c8310db6419abf2ca990429d Mon Sep 17 00:00:00 2001 From: theaustinator Date: Thu, 15 Oct 2020 00:51:33 -0600 Subject: [PATCH 17/19] twigs and twigs widget; sundry plotting bugs --- dataforest/config/MetaPlotMethods.py | 10 ++- dataforest/core/DataBranch.py | 22 ----- dataforest/core/DataTree.py | 34 +++++--- dataforest/core/Interface.py | 4 +- dataforest/core/PlotMethods.py | 30 ++++++- dataforest/core/PlotTreeMethods.py | 6 ++ dataforest/core/PlotWidget.py | 53 +++++++++--- dataforest/core/ProcessRun.py | 85 ++++++++++++------- dataforest/core/TreeSpec.py | 39 +++++++-- dataforest/hooks/hooks/core/hooks.py | 14 ++- dataforest/plot/wrappers.py | 25 ++++-- .../processes/core/TreeProcessMethods.py | 5 +- dataforest/utils/plots_config.py | 7 +- 13 files changed, 219 insertions(+), 115 deletions(-) create mode 100644 dataforest/core/PlotTreeMethods.py diff --git a/dataforest/config/MetaPlotMethods.py b/dataforest/config/MetaPlotMethods.py index 8a7164e..68dfcc0 100644 --- a/dataforest/config/MetaPlotMethods.py +++ b/dataforest/config/MetaPlotMethods.py @@ -14,24 +14,26 @@ def PLOT_METHOD_LOOKUP(cls): @property def PLOT_MAP(cls): - return cls.CONFIG["plot_map"] + return cls.CONFIG.get("plot_map", dict()) @property def PROCESS_PLOT_METHODS(cls): try: plot_methods = cls.CONFIG["plot_methods"] except KeyError: - plot_methods = build_process_plot_method_lookup(config=cls.CONFIG) + plot_methods = build_process_plot_method_lookup(cls.CONFIG.get("plot_map", dict())) return plot_methods @property def PLOT_KWARGS_DEFAULTS(cls): - return cls.CONFIG["plot_kwargs_defaults"] + return cls.CONFIG.get("plot_kwargs_defaults", dict()) @property def PLOT_KWARGS(cls): # TODO-QC: mapping of process, plot to plot_kwargs - plot_kwargs = parse_plot_kwargs(config=cls.CONFIG) + plot_map = cls.CONFIG.get("plot_map", dict()) + plot_kwargs_defaults = cls.CONFIG.get("plot_kwargs_defaults", dict()) + plot_kwargs = 
parse_plot_kwargs(plot_map, plot_kwargs_defaults) return plot_kwargs # @property diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index 6b38e62..292a826 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -287,28 +287,6 @@ def clear_data(self, attrs: Optional[Iterable[str]] = None, all_data: bool = Fal data_attr = f"_{attr_name}" setattr(self, data_attr, None) - def create_root_plots( - self, plot_kwargs: Optional[Dict[str, dict]] = None, overwrite: bool = False, stop_on_error: bool = False - ): - if self.is_process_plots_exist("root") and not overwrite: - self.logger.info( - f"plots already present for `root` at {self['root'].plots_path}. To regenerate plots, delete directory" - ) - return - - if plot_kwargs is None: - plot_kwargs = self.plot.plot_kwargs["root"] - root_plot_map = self["root"].plot_map - root_plot_methods = self.plot.plot_methods.get("root", []) - - for plot_name, plot_method in root_plot_methods.items(): - kwargs_sets = plot_kwargs.get(plot_name, dict()) - for plot_kwargs_key, _kwargs in kwargs_sets.items(): - method = getattr(self.plot, plot_method) - kwargs = deepcopy(_kwargs) - kwargs["plot_path"] = root_plot_map[plot_name][plot_kwargs_key] - method(stop_on_error=stop_on_error, **kwargs) - def is_process_plots_exist(self, process_name: str) -> bool: return self[process_name].plots_path.exists() diff --git a/dataforest/core/DataTree.py b/dataforest/core/DataTree.py index 5b82988..cd4c0bd 100644 --- a/dataforest/core/DataTree.py +++ b/dataforest/core/DataTree.py @@ -19,6 +19,7 @@ def __init__( self, root: Union[str, Path], tree_spec: Optional[List[dict]] = None, + twigs: Optional[List[dict]] = None, verbose: bool = False, remote_root: Optional[Union[str, Path]] = None, ): @@ -27,7 +28,8 @@ def __init__( # it finishes. 
That way, we can super().__init__() self.root = root - self.tree_spec = self._init_spec(tree_spec) + self.tree_spec = self._init_spec(tree_spec, twigs) + self._twigs = twigs self._verbose = verbose self._current_process = None self.remote_root = remote_root @@ -43,22 +45,26 @@ def n_branches(self): def current_process(self): return self._current_process if self._current_process else "root" + @property + def has_sweeps(self) -> bool: + """Whether or not any sweeps are in the spec""" + return bool(set.union(*self.tree_spec.sweep_dict.values())) + + @property + def has_twigs(self): + return bool(self._twigs) + + @property + def branches(self): + self._branch_cache.load_all() + return list(self._branch_cache.values()) + def goto_process(self, process_name: str): self._branch_cache.load_all() for branch in self._branch_cache.values(): branch.goto_process(process_name) self._current_process = process_name - def update_process_spec(self, process_name: str, process_spec: dict): - self.tree_spec[process_name] = process_spec - self.update_spec(self.tree_spec) - - def update_spec(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]): - tree_spec = list(tree_spec) - self.tree_spec = self._init_spec(tree_spec) - self._branch_cache.update_branch_specs(self.tree_spec.branch_specs) - self.process = TreeProcessMethods(self.tree_spec, self._branch_cache) - def load_all(self): self._branch_cache.load_all() @@ -68,7 +74,7 @@ def run_all(self, workers: int = 1, batch_queue: Optional[str] = None): def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): rand_spec = self.tree_spec.branch_specs[0] rand_branch = self._branch_cache[str(rand_spec)] - rand_branch.create_root_plots(plot_kwargs) + rand_branch._generate_root_plots(plot_kwargs) def __getitem__(self, process_name: str) -> ProcessTreeRun: if process_name not in self._process_tree_runs: @@ -78,9 +84,9 @@ def __getitem__(self, process_name: str) -> ProcessTreeRun: return 
self._process_tree_runs[process_name] @staticmethod - def _init_spec(tree_spec: Union[list, TreeSpec]) -> TreeSpec: + def _init_spec(tree_spec: Union[list, TreeSpec], twigs: Optional[List[dict]]) -> TreeSpec: if tree_spec is None: tree_spec = list() if not isinstance(tree_spec, TreeSpec): - tree_spec = TreeSpec(tree_spec) + tree_spec = TreeSpec(tree_spec, twigs) return tree_spec diff --git a/dataforest/core/Interface.py b/dataforest/core/Interface.py index 570bc88..f302ee5 100644 --- a/dataforest/core/Interface.py +++ b/dataforest/core/Interface.py @@ -104,7 +104,7 @@ def from_input_dirs( kwargs = {**additional_kwargs, **kwargs} inst = interface_cls(root, **kwargs) if root_plots: - inst.create_root_plots(plot_kwargs) + inst.plot.generate_plots("root", plot_kwargs) return inst @classmethod @@ -166,7 +166,7 @@ def from_sample_metadata( kwargs = {**additional_kwargs, **kwargs} inst = interface_cls(root, **kwargs) if root_plots: - inst.create_root_plots(plot_kwargs) + inst.plot.generate_plots("root", plot_kwargs) return inst @staticmethod diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py index a7aaa28..3e8cf73 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -1,3 +1,4 @@ +from copy import deepcopy from functools import wraps import logging from pathlib import Path @@ -32,9 +33,10 @@ def __init__(self, branch: "DataBranch"): # self._plot_cache = {process: PlotCache(self.branch, process) for process in self.branch.spec.processes} self._img_cache = {} - def display(self, process_name: str): + def show(self, process_name: str): plots_path = self.branch[process_name].plots_path for img_path in plots_path.iterdir(): + display(str(img_path)) display(Image(img_path)) @typechecked @@ -138,8 +140,8 @@ def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): try: process_run = branch[branch.current_process] plot_name = branch.plot.method_key_lookup.get(method_name, None) - if plot_name in 
process_run.plot_map: - if not (plt_filename_lookup := process_run.plot_map[plot_name]): + if plot_name in process_run._plot_map: + if not (plt_filename_lookup := process_run._plot_map[plot_name]): if plt_filename_lookup: _, plt_filepath = next(iter(plt_filename_lookup.items())) plt_dir = plt_filepath.parent @@ -153,4 +155,24 @@ def wrapped(branch, method_name, *args, stop_on_error: bool = False, **kwargs): return wrapped - # def __getitem__(self, key): + def _generate_root_plots( + self, plot_kwargs: Optional[Dict[str, dict]] = None, overwrite: bool = False, stop_on_error: bool = False + ): + if self.branch.is_process_plots_exist("root") and not overwrite: + logging.info( + f"plots already present for `root` at {self.branch['root'].plots_path}. To regenerate plots, delete dir" + ) + return + + if plot_kwargs is None: + plot_kwargs = self.plot_kwargs["root"] + root_plot_map = self.branch["root"]._plot_map + root_plot_methods = self.plot_methods.get("root", []) + + for plot_name, plot_method in root_plot_methods.items(): + kwargs_sets = plot_kwargs.get(plot_name, dict()) + for plot_kwargs_key, _kwargs in kwargs_sets.items(): + method = getattr(self, plot_method) + kwargs = deepcopy(_kwargs) + kwargs["plot_path"] = root_plot_map[plot_name][plot_kwargs_key] + method(stop_on_error=stop_on_error, **kwargs) diff --git a/dataforest/core/PlotTreeMethods.py b/dataforest/core/PlotTreeMethods.py new file mode 100644 index 0000000..c5eae6e --- /dev/null +++ b/dataforest/core/PlotTreeMethods.py @@ -0,0 +1,6 @@ +from dataforest.core.PlotMethods import PlotMethods + + +class PlotTreeMethods(PlotMethods): + def __init__(self, tree): + self._tree = tree diff --git a/dataforest/core/PlotWidget.py b/dataforest/core/PlotWidget.py index 668c475..b6d1098 100644 --- a/dataforest/core/PlotWidget.py +++ b/dataforest/core/PlotWidget.py @@ -1,3 +1,4 @@ +import logging from copy import deepcopy from typing import TYPE_CHECKING, Dict, Any @@ -5,14 +6,21 @@ import ipywidgets as widgets from 
matplotlib.figure import Figure -from dataforest.structures.cache.PlotCache import PlotCache - if TYPE_CHECKING: from dataforest.core.DataBranch import DataBranch + from dataforest.core.DataTree import DataTree class PlotWidget: - def __init__(self, tree, plot_key, use_saved=True, **plot_kwargs): + def __init__(self, tree: "DataTree", plot_key: str, use_saved: bool = True, **plot_kwargs): + """ + + Args: + tree: + plot_key: key format of plot name (e.g. "_UMIS_VS_PERC_HSP_SCAT_") + use_saved: use saved plots or regenerate + **plot_kwargs: + """ self._tree = tree self._branch_spec = deepcopy(tree.tree_spec.branch_specs[0]) self._plot_key = plot_key @@ -22,7 +30,17 @@ def __init__(self, tree, plot_key, use_saved=True, **plot_kwargs): self._plot_cache = dict() # self.ax = plt.gca() - def control(self): + def control(self, **kwargs): + if self._tree.has_sweeps and self._tree.has_twigs: + logging.warning( + "PlotWidget doesn't support trees with both sweeps and twigs. Please use one or the other. Defaulting to sweeps" + ) + if self._tree.has_sweeps: + return self._sweeps_control(**kwargs) + else: + return self._twigs_control(**kwargs) + + def _sweeps_control(self, **plotter_kwargs): process = self._tree.current_process precursors = self._tree.tree_spec.get_precursors_lookup(incl_current=True)[process] sweeps = set.union(*[self._sweeps[precursor] for precursor in precursors]) @@ -43,12 +61,23 @@ def _control(**kwargs: Dict[str, Any]): # TODO: do we need to do something with kwargs? Or taken care of? 
[kwargs.pop(k) for k in self._plot_kwargs] branch = self._tree._branch_cache[str(self)] - return self._get_plot(branch) + return self._get_plot(branch, **plotter_kwargs) + + return _control + + def _twigs_control(self, **plotter_kwargs): + twig_specs = self._tree.tree_spec.twig_specs + + @widgets.interact(twig_str=list(twig_specs.keys()), **self._plot_kwargs) + def _control(twig_str: str): + spec = twig_specs[twig_str] + branch = self._tree._branch_cache[str(spec)] + self._get_plot(branch, **plotter_kwargs) return _control - def _get_plot(self, branch): - plot_map = branch[self._tree.current_process].plot_map + def _get_plot(self, branch, **kwargs): + plot_map = branch[self._tree.current_process]._plot_map plot_path_lookup = {plot_key: next(iter(path_dict.values())) for plot_key, path_dict in plot_map.items()} spec = branch.spec[: branch.current_process] cache_key = str({"spec": spec, "plot_kwargs": self._plot_kwargs}) @@ -57,9 +86,10 @@ def _get_plot(self, branch): if plot_path.exists(): return Image(plot_path) generated = False - if not (plot_obj := self._plot_cache.get(cache_key, None)): + plot_obj = self._plot_cache.get(cache_key, None) + if not plot_obj: generated = True - plot_obj = self._generate_plot(branch) + plot_obj = self._generate_plot(branch, **kwargs) self._plot_cache[cache_key] = plot_obj if isinstance(plot_obj, tuple) and isinstance(plot_obj[0], Figure): if not generated: @@ -71,11 +101,8 @@ def _get_plot(self, branch): def _generate_plot(self, branch: "DataBranch", **kwargs): method = branch.plot.method_lookup[self._plot_key] - # fig, ax = method(**self._plot_kwargs) kwargs = {**self._plot_kwargs, **kwargs} - # TODO: might be able to use ax if integrate ax.figure - return method(**kwargs) # , ax=self.ax - # return fig, ax + return method(**kwargs) def __str__(self): return str(self._branch_spec) diff --git a/dataforest/core/ProcessRun.py b/dataforest/core/ProcessRun.py index c1936a5..63c5380 100644 --- a/dataforest/core/ProcessRun.py +++ 
b/dataforest/core/ProcessRun.py @@ -28,7 +28,7 @@ def __init__(self, branch: "DataBranch", process_name: str, process: str): self._file_lookup = self._build_file_lookup() self._plot_lookup = self._build_plot_lookup() self._path_map = None - self._plot_map = None + self._plot_map_cache = None self._path_map_prior = None @property @@ -46,6 +46,14 @@ def path(self) -> Path: """Path to directory containing processes output files and logs""" return self.branch.paths[self.process_name] + @property + def plots(self): + return self.branch.plot.show(self.process_name) + + @property + def plot_methods(self): + return self.branch.plot.methods[self.process] + @property def logs_path(self) -> Path: return self.path / "_logs" @@ -84,14 +92,12 @@ def process_path_map(self) -> Dict[str, Path]: return {file_alias: self.path / self._file_lookup[file_alias] for file_alias in self._file_lookup} @property - def process_plot_map(self) -> Dict[str, Path]: - plot_map_dict = {} - for plot_name in self._plot_lookup: - plot_map_dict[plot_name] = {} - for plot_kwargs_key, plot_filepath in self._plot_lookup[plot_name].items(): - plot_map_dict[plot_name][plot_kwargs_key] = self.plots_path / plot_filepath - - return plot_map_dict + def plot_lookup(self) -> Dict[str, List[Path]]: + """ + Key: plot key (e.g. "_UMIS_PER_CELL_HIST_") + Value: list of plot paths of the type specified by key + """ + return {plot_key: list(plot_path_map.values()) for plot_key, plot_path_map in self._plot_map.items()} @property def path_map(self) -> Dict[str, Path]: @@ -105,13 +111,6 @@ def path_map(self) -> Dict[str, Path]: self._path_map = self._build_path_map(incl_current=True) return self._path_map - @property - def plot_map(self) -> Dict[str, Path]: - # TODO: confusing with new plot_map name in config - rename that to plot_settings? 
- if self._plot_map is None: - self._plot_map = self._build_path_map(incl_current=True, plot_map=True) - return self._plot_map - @property def path_map_prior(self) -> Dict[str, Path]: """Like `path_map`, but for excluding the current process""" @@ -175,27 +174,29 @@ def logs(self): """ Prints stdout and stderr log files """ - log_dir = self.path / "_logs" - log_files = list(log_dir.iterdir()) - stdouts = list(filter(lambda x: str(x).endswith(".out"), log_files)) - stderrs = list(filter(lambda x: str(x).endswith(".err"), log_files)) - if (len(stdouts) == 0) and (len(stderrs) == 0): - raise ValueError(f"No logs for processes: {self.process_name}") - for stdout in stdouts: - name = str(stdout.name).split(".out")[0] - cprint(f"STDOUT: {name}", "cyan", "on_grey") - with open(str(stdout), "r") as f: - print(f.read()) - for stderr in stderrs: - name = str(stderr.name).split(".err")[0] - cprint(f"STDERR: {name}", "magenta", "on_grey") - with open(str(stderr), "r") as f: - print(f.read()) + self._print_logs() def subprocess_runs(self, process_name: str) -> pd.DataFrame: """DataFrame of branch_spec info for all runs of a given subprocess""" raise NotImplementedError() + @property + def _process_plot_map(self) -> Dict[str, Path]: + plot_map_dict = {} + for plot_name in self._plot_lookup: + plot_map_dict[plot_name] = {} + for plot_kwargs_key, plot_filepath in self._plot_lookup[plot_name].items(): + plot_map_dict[plot_name][plot_kwargs_key] = self.plots_path / plot_filepath + + return plot_map_dict + + @property + def _plot_map(self) -> Dict[str, Path]: + # TODO: confusing with new plot_map name in config - rename that to plot_settings? 
+ if self._plot_map_cache is None: + self._plot_map_cache = self._build_path_map(incl_current=True, plot_map=True) + return self._plot_map_cache + def _build_layers_files(self) -> Dict[str, str]: """ Builds {file_alias: filename} lookup for additional layers specified @@ -228,13 +229,31 @@ def _build_path_map(self, incl_current: bool = False, plot_map: bool = False) -> precursor_lookup = spec.get_precursors_lookup(incl_current=incl_current, incl_root=True) precursors = precursor_lookup[self.process_name] process_runs = [self] if plot_map else [self.branch[process_name] for process_name in precursors] - pr_attr = "process_plot_map" if plot_map else "process_path_map" + pr_attr = "_process_plot_map" if plot_map else "process_path_map" process_path_map_list = [getattr(pr, pr_attr) for pr in process_runs] path_map = dict() for process_path_map in process_path_map_list: path_map.update(process_path_map) return path_map + def _print_logs(self): + log_dir = self.path / "_logs" + log_files = list(log_dir.iterdir()) + stdouts = list(filter(lambda x: str(x).endswith(".out"), log_files)) + stderrs = list(filter(lambda x: str(x).endswith(".err"), log_files)) + if (len(stdouts) == 0) and (len(stderrs) == 0): + raise ValueError(f"No logs for processes: {self.process_name}") + for stdout in stdouts: + name = str(stdout.name).split(".out")[0] + cprint(f"STDOUT: {name}", "cyan", "on_grey") + with open(str(stdout), "r") as f: + print(f.read()) + for stderr in stderrs: + name = str(stderr.name).split(".err")[0] + cprint(f"STDERR: {name}", "magenta", "on_grey") + with open(str(stderr), "r") as f: + print(f.read()) + def __repr__(self): repr_ = super().__repr__()[:-1] # remove closing bracket to append repr_ += f" process: {self.process}; process_name: {self.process_name}; done: {self.done}>" diff --git a/dataforest/core/TreeSpec.py b/dataforest/core/TreeSpec.py index 0dd1172..5d13fdf 100644 --- a/dataforest/core/TreeSpec.py +++ b/dataforest/core/TreeSpec.py @@ -1,9 +1,13 @@ +from copy 
import deepcopy from itertools import product -from typing import Union, List, Dict, TYPE_CHECKING +from typing import Union, List, Dict, TYPE_CHECKING, Optional from dataforest.core.BranchSpec import BranchSpec from dataforest.core.RunGroupSpec import RunGroupSpec +if TYPE_CHECKING: + from dataforest.core.RunSpec import RunSpec + class TreeSpec(BranchSpec): """ @@ -47,10 +51,11 @@ class TreeSpec(BranchSpec): _RUN_SPEC_CLASS = RunGroupSpec - def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]): + def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"], twigs: Optional[List[dict]]): super(list, self).__init__() self.extend([RunGroupSpec(item) for item in tree_spec]) - self.branch_specs = self._build_branch_specs() + self.twig_specs = dict() + self.branch_specs = self._build_branch_specs(twigs) self.sweep_dict = {x["_PROCESS_"]: x.sweeps for x in self} self.sweep_dict["root"] = set() self._raw = tree_spec @@ -62,8 +67,32 @@ def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"]): incl_root=True, incl_current=True ) - def _build_branch_specs(self): - return list(map(BranchSpec, product(*[run_group_spec.run_specs for run_group_spec in self]))) + def _build_branch_specs(self, twigs: Optional[List[dict]]): + specs = list(map(BranchSpec, product(*[run_group_spec.run_specs for run_group_spec in self]))) + template = specs[0] + specs.extend(self._add_twigs(template, twigs)) + return specs + + def _add_twigs(self, template: "RunSpec", twigs: Optional[List[Union[tuple, list]]]): + specs = list() + self.twig_specs["base"] = template + if twigs is None: + return specs + for twig in twigs: + spec_ = deepcopy(template) + if isinstance(twig, tuple): + twig = [twig] + for mod in twig: + val = mod[-1] + final_key = mod[-2] + accessors = mod[:-2] + scope = spec_ + for key in accessors: + scope = scope[key] + scope[final_key] = val + specs.append(spec_) + self.twig_specs[str(twig)] = spec_ + return specs def __setitem__(self, 
key, value): if not isinstance(key, int): diff --git a/dataforest/hooks/hooks/core/hooks.py b/dataforest/hooks/hooks/core/hooks.py index cecef11..4bf07fa 100644 --- a/dataforest/hooks/hooks/core/hooks.py +++ b/dataforest/hooks/hooks/core/hooks.py @@ -129,11 +129,12 @@ def hook_catalogue(dp): def hook_generate_plots(dp: dataprocess): plot_sources = dp.branch.plot.plot_method_lookup current_process = dp.branch.current_process - all_plot_kwargs_sets = dp.branch.plot.plot_kwargs[current_process] - process_plot_methods = dp.branch.plot.plot_methods[current_process] - process_plot_map = dp.branch[dp.branch.current_process].plot_map + all_plot_kwargs_sets = dp.branch.plot.plot_kwargs.get(current_process, dict()) + process_plot_methods = dp.branch.plot.plot_methods.get(current_process, dict()) + process_plot_map = dp.branch[current_process]._plot_map requested_plot_methods = deepcopy(process_plot_methods) + exceptions = [] for method in plot_sources.values(): plot_method_name = method.__name__ if plot_method_name in requested_plot_methods.values(): @@ -143,13 +144,18 @@ def hook_generate_plots(dp: dataprocess): plot_path = process_plot_map[plot_name][plot_kwargs_key] kwargs = deepcopy(plot_kwargs_sets[plot_kwargs_key]) kwargs["plot_path"] = plot_path - method(dp.branch, **kwargs) + try: + method(dp.branch, show=False, **kwargs) + except Exception as e: + exceptions.append(e) requested_plot_methods.pop(plot_name) if len(requested_plot_methods) > 0: # if not all requested mapped to functions in plot sources logging.warning( f"Requested plotting methods {requested_plot_methods} are not implemented so they were skipped." 
) + for e in exceptions: + raise e @hook diff --git a/dataforest/plot/wrappers.py b/dataforest/plot/wrappers.py index 8c85faf..2bbc20c 100644 --- a/dataforest/plot/wrappers.py +++ b/dataforest/plot/wrappers.py @@ -1,6 +1,7 @@ from functools import wraps import json import logging +from pathlib import Path from itertools import product from typing import Tuple, Optional, AnyStr, Union @@ -25,6 +26,7 @@ def wrapper( facet: Optional[str] = None, plot_path: Optional[AnyStr] = None, facet_dim: tuple = (), + show: bool = True, **kwargs, ) -> Union[plt.Figure, Tuple[plt.Figure, np.ndarray]]: prep = PlotPreparator(branch) @@ -48,9 +50,13 @@ def wrapper( if stratify is not None: kwargs["label"] = row["stratify"] plot_func(row["branch"], ax=ax, **kwargs) + if stratify is not None: + ax.legend() if plot_path is not None: logging.info(f"saving py figure to {plot_path}") prep.fig.savefig(plot_path) + if show: + display(prep.fig) return prep.fig, prep.ax_arr return wrapper @@ -63,6 +69,7 @@ def wrapper( stratify: Optional[str] = None, facet: Optional[str] = None, plot_path: Optional[AnyStr] = None, + show: bool = True, **kwargs, ): def _get_plot_script(): @@ -73,10 +80,12 @@ def _get_plot_script(): stratify = None if stratify in PlotPreparator.NONE_VARIANTS else stratify facet = None if facet in PlotPreparator.NONE_VARIANTS else facet + plot_path = plot_path if plot_path else "/tmp/plot.png" + if facet is not None: - logging.warning("facet not yet implemented for R plots, but will create separate plots") + logging.warning(f"{plot_func.__name__} facet not yet implemented for R plots") if stratify is not None: - logging.warning("stratify not yet supported for R plots") + logging.warning(f"{plot_func.__name__} stratify not yet supported for R plots") plot_source, r_script = _get_plot_script() plot_size = kwargs.pop("plot_size", PlotPreparator.DEFAULT_PLOT_RESOLUTION_PX) if stratify is not None: @@ -108,9 +117,13 @@ def _get_plot_script(): ] logging.info(f"saved R figure to 
{plot_path}") plot_func(branch, r_script, args) # plot_kwargs already included in args - img = Image("/tmp/plot.png") - display(img) - img_arr.append(img) + img = Image(plot_path) + if not Path(plot_path).exists(): + logging.warning(f"{plot_func.__name__} output file not found after execution for {val}") + else: + if show: + display(img) + img_arr.append(img) return img_arr return wrapper @@ -126,6 +139,8 @@ def __call__(self, func): @wraps(func) def wrapper(branch: "DataBranch", *args, **kwargs): + if not branch[self._req_process].success: + raise ValueError(f"`{func.__name__}` requires a complete and successful process: `{self._req_process}`") precursors = branch.spec.get_precursors_lookup(incl_current=True)[branch.current_process] if self._req_process not in precursors: proc = self._req_process diff --git a/dataforest/processes/core/TreeProcessMethods.py b/dataforest/processes/core/TreeProcessMethods.py index 16bd5ba..790498a 100644 --- a/dataforest/processes/core/TreeProcessMethods.py +++ b/dataforest/processes/core/TreeProcessMethods.py @@ -94,10 +94,7 @@ def _distributed_kernel_parallel(): return pool(process(branch) for branch in unique_branches.values()) exec_scheme = "PARALLEL" if parallel else "SERIAL" - logging.info( - f"{exec_scheme} execution of {method_name} over {self._N_JOBS} workers on {len(unique_branches)} " - f"unique conditions" - ) + logging.info(f"{exec_scheme} execution of {method_name} on {len(unique_branches)} unique branches") kernel = _distributed_kernel_parallel if parallel else _distributed_kernel_serial ret_vals = kernel() return ret_vals diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py index e1bdbe0..f5be14d 100644 --- a/dataforest/utils/plots_config.py +++ b/dataforest/utils/plots_config.py @@ -4,7 +4,7 @@ from typing import Dict -def build_process_plot_method_lookup(config: dict) -> Dict[str, Dict[str, str]]: +def build_process_plot_method_lookup(plot_map: dict) -> Dict[str, Dict[str, str]]: """ Get 
a lookup of processes, each containing a mapping between plot method names in the config and the actual callable names. @@ -13,7 +13,6 @@ def build_process_plot_method_lookup(config: dict) -> Dict[str, Dict[str, str]]: Ex: {"normalize": {"_UMIS_PER_CELL_HIST_": "plot_umis_per_cell_hist", ...}, ...} """ - plot_map = config["plot_map"] process_plot_methods = {} for process, plots in plot_map.items(): process_plot_methods[process] = {} @@ -27,10 +26,8 @@ def build_process_plot_method_lookup(config: dict) -> Dict[str, Dict[str, str]]: return process_plot_methods -def parse_plot_kwargs(config: dict): +def parse_plot_kwargs(plot_map: dict, plot_kwargs_defaults: dict): """Parse plot methods plot_kwargs per process from plot_map""" - plot_map = config["plot_map"] - plot_kwargs_defaults = config["plot_kwargs_defaults"] all_plot_kwargs = {} for process, plots in plot_map.items(): all_plot_kwargs[process] = {} From 2a72c28d937cafdacd0a06e294567e6ee90cd352 Mon Sep 17 00:00:00 2001 From: Munchic Date: Tue, 20 Oct 2020 22:18:12 -0700 Subject: [PATCH 18/19] docstrings on plot helper functions --- dataforest/utils/plots_config.py | 44 +++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/dataforest/utils/plots_config.py b/dataforest/utils/plots_config.py index 8092853..9319a0f 100644 --- a/dataforest/utils/plots_config.py +++ b/dataforest/utils/plots_config.py @@ -28,7 +28,17 @@ def build_process_plot_method_lookup(config: dict) -> Dict[str, Dict[str, str]]: def parse_plot_kwargs(config: dict): - """Parse plot methods plot_kwargs per process from plot_map""" + """ + Parse plot methods plot_kwargs per process from plot_map + + Example: + { + "root": { + "_UMIS_PER_CELL_HIST_": { + '{"plot_size": "default", "stratify": "default"}': {"stratify": "none", "plot_size": [800, 800]} + } + } + """ plot_map = config["plot_map"] plot_kwargs_defaults = config["plot_kwargs_defaults"] all_plot_kwargs = {} @@ -58,6 +68,14 @@ def parse_plot_map(config: 
dict): """ Parse plot file map per process from plot_map and ensures that implicit definition returns a dictionary of default values for all plot_kwargs + + Example: + { + "root": { + "_UMIS_PER_CELL_HIST_": { + '{"plot_size": "default", "stratify": "default"}': "umis_per_cell_hist-plot_size:800+800-stratify:none.png" + } + } """ plot_map = config["plot_map"] plot_kwargs_defaults = config["plot_kwargs_defaults"] @@ -173,6 +191,30 @@ def _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults): def _get_plot_kwargs_feed(plot_kwargs: dict, plot_kwargs_defaults: dict, plot_name: str, map_to_default_kwargs=False): + """ + Create a feeding of kwargs tuples to create multiple plots, replacing values with + plot_kwargs_defaults values + + Example: + # input plot_kwargs + { + "stratify": ["sample_id", "default"], + "plot_size": ["default", "default"] + } + + # output: + [ + { + "stratify": "sample_id", + "plot_size": [800, 800] + }, + { + "stratify": None, + "plot_size": [800, 800] + } + ] + """ + plot_kwargs = _unify_kwargs_opt_lens(plot_kwargs, plot_kwargs_defaults, plot_name) if map_to_default_kwargs: plot_kwargs = _map_kwargs_opts_to_values(plot_kwargs, plot_kwargs_defaults) From c77f979948311b0354a0928a261deda14061bdc4 Mon Sep 17 00:00:00 2001 From: theaustinator Date: Sat, 24 Oct 2020 23:51:38 -0600 Subject: [PATCH 19/19] Unified widget for twigs and sweeps; metadata operations for branch and tree -- so dope --- dataforest/core/DataBase.py | 4 +- dataforest/core/DataBranch.py | 8 +- dataforest/core/DataTree.py | 18 +++- dataforest/core/Interface.py | 5 +- dataforest/core/PlotMethods.py | 34 ++++--- dataforest/core/PlotTreeMethods.py | 12 +++ dataforest/core/RunGroupSpec.py | 18 ++-- dataforest/core/TreeDataFrame.py | 73 +++++++++++++++ dataforest/core/TreeSpec.py | 56 +++++++----- dataforest/{core => plot}/PlotWidget.py | 88 ++++++++++--------- .../processes/core/TreeProcessMethods.py | 17 ++-- requirements.txt | 1 + 12 files changed, 228 insertions(+), 106 
deletions(-) create mode 100644 dataforest/core/TreeDataFrame.py rename dataforest/{core => plot}/PlotWidget.py (53%) diff --git a/dataforest/core/DataBase.py b/dataforest/core/DataBase.py index e5bda84..c1929d4 100644 --- a/dataforest/core/DataBase.py +++ b/dataforest/core/DataBase.py @@ -10,9 +10,11 @@ class DataBase: Mixin for `DataTree`, `DataBranch`, and derived class """ + _PLOT_METHODS = PlotMethods + def __init__(self): self.root = None - self.plot = PlotMethods(self) + self.plot = self._PLOT_METHODS(self) @property def root_built(self): diff --git a/dataforest/core/DataBranch.py b/dataforest/core/DataBranch.py index 292a826..427a6ad 100644 --- a/dataforest/core/DataBranch.py +++ b/dataforest/core/DataBranch.py @@ -5,6 +5,8 @@ import pandas as pd from pathlib import Path +from typeguard import typechecked + from dataforest.core.DataBase import DataBase from dataforest.structures.cache.PathCache import PathCache from dataforest.structures.cache.IOCache import IOCache @@ -87,7 +89,6 @@ class DataBranch(DataBase): _writer_method_map: """ - PLOT_METHODS: Type = PlotMethods PROCESS_METHODS: Type = ProcessMethods SCHEMA_CLASS: Type = ProcessSchema READER_METHODS: Type = ReaderMethods @@ -123,7 +124,6 @@ def __init__( self.spec = self._init_spec(branch_spec) self.verbose = verbose self.logger = logging.getLogger(self.__class__.__name__) - self.plot = self.PLOT_METHODS(self) self.process = self.PROCESS_METHODS(self, self.spec) self.schema = self.SCHEMA_CLASS() self._paths_exists = PathCache(self.root, self.spec, exists_req=True) @@ -188,6 +188,10 @@ def meta(self) -> pd.DataFrame: self._meta = self._get_meta(self.current_process) return self._meta + def add_metadata(self, df: pd.DataFrame): + # TODO: need to write to keep persistent? 
+ self.set_meta(pd.concat([self.meta, df], axis=1)) + def goto_process(self, process_name: str) -> "DataBranch": """ Updates the state of the `DataBranch` object to reflect the diff --git a/dataforest/core/DataTree.py b/dataforest/core/DataTree.py index cd4c0bd..c77fbcc 100644 --- a/dataforest/core/DataTree.py +++ b/dataforest/core/DataTree.py @@ -4,8 +4,9 @@ from dataforest.core.DataBase import DataBase from dataforest.core.DataBranch import DataBranch +from dataforest.core.PlotTreeMethods import PlotTreeMethods from dataforest.core.ProcessTreeRun import ProcessTreeRun -from dataforest.core.RunGroupSpec import RunGroupSpec +from dataforest.core.TreeDataFrame import DataFrameList from dataforest.core.TreeSpec import TreeSpec from dataforest.processes.core.TreeProcessMethods import TreeProcessMethods from dataforest.structures.cache.BranchCache import BranchCache @@ -14,6 +15,7 @@ class DataTree(DataBase): _LOG = logging.getLogger("DataTree") _BRANCH_CLASS = DataBranch + _PLOT_METHODS = PlotTreeMethods def __init__( self, @@ -35,7 +37,11 @@ def __init__( self.remote_root = remote_root self._branch_cache = BranchCache(root, self.tree_spec.branch_specs, self._BRANCH_CLASS, verbose, remote_root,) self._process_tree_runs = dict() - self.process = TreeProcessMethods(self.tree_spec, self._branch_cache) + self.process = TreeProcessMethods(self) + + @property + def meta(self): + return DataFrameList([branch.meta for branch in self.branches]) @property def n_branches(self): @@ -71,6 +77,14 @@ def load_all(self): def run_all(self, workers: int = 1, batch_queue: Optional[str] = None): return [method() for method in self.process.process_methods] + def unique_branches_at_process(self, process_name: str) -> Dict[str, "DataBranch"]: + """ + Gets a subset of branches representing those unique up to the specified + process. From two branches which only become distinguished after + `process_name`, just one will be selected. 
+ """ + return {str(branch.spec[:process_name]): branch for branch in self.branches} + def create_root_plots(self, plot_kwargs: Optional[Dict[str, dict]] = None): rand_spec = self.tree_spec.branch_specs[0] rand_branch = self._branch_cache[str(rand_spec)] diff --git a/dataforest/core/Interface.py b/dataforest/core/Interface.py index f302ee5..c7f9f59 100644 --- a/dataforest/core/Interface.py +++ b/dataforest/core/Interface.py @@ -1,3 +1,4 @@ +from copy import deepcopy from pathlib import Path from typing import Union, Optional, Iterable, List, AnyStr @@ -37,8 +38,8 @@ def load( """ kwargs = { - "branch_spec": branch_spec, - "tree_spec": tree_spec, + "branch_spec": deepcopy(branch_spec), + "tree_spec": deepcopy(tree_spec), "verbose": verbose, "current_process": current_process, "remote_root": remote_root, diff --git a/dataforest/core/PlotMethods.py b/dataforest/core/PlotMethods.py index 3e8cf73..333293c 100644 --- a/dataforest/core/PlotMethods.py +++ b/dataforest/core/PlotMethods.py @@ -23,14 +23,12 @@ class PlotMethods(metaclass=MetaPlotMethods): """ def __init__(self, branch: "DataBranch"): - self._logger = logging.getLogger(self.__class__.__name__) self.branch = branch for name, plot_method in self.plot_method_lookup.items(): callable_ = copy_func(plot_method) callable_.__name__ = name setattr(self, name, self._wrap(callable_)) tether(self, "branch", incl_methods=list(self.plot_method_lookup.keys())) - # self._plot_cache = {process: PlotCache(self.branch, process) for process in self.branch.spec.processes} self._img_cache = {} def show(self, process_name: str): @@ -72,20 +70,7 @@ def plot_methods(self): @property def method_lookup(self): - return {k: getattr(self, method_name) for k, method_name in self.method_name_lookup.items()} - - @property - def method_name_lookup(self): - global_plot_methods = { - config_name: callable_name - for name_mapping in self.plot_methods.values() - for config_name, callable_name in name_mapping.items() - } - return 
global_plot_methods - - @property - def method_key_lookup(self): - return {v: k for k, v in self.method_name_lookup.items()} + return {k: getattr(self, method_name) for k, method_name in self.key_method_lookup.items()} @property def plot_kwargs_defaults(self): @@ -123,14 +108,27 @@ def _assign(method_name): list(map(_assign, method_names)) return avail_dict + @property + def key_method_lookup(self): + """ + Key: method key in format of config (e.g. "_UMAP_EMBEDDINGS_SCAT_") + Value: method name (e.g. "plot_umap_embeddings_scat") + """ + convert_to_key = lambda s: "_" + s.upper()[5:] + "_" + return {convert_to_key(x): x for k, v in self.methods.items() for x in v} + + @property + def method_key_lookup(self): + """inverted `key_method_lookup`""" + return {v: k for k, v in self.key_method_lookup.items()} + @property def keys(self) -> Dict[str, Set[str]]: """ Key: process name at which plot becomes unlocked Value: keys for plots in config """ - convert_to_key = lambda set_: set(map(lambda s: "_" + s.upper()[5:] + "_", set_)) - return {k: convert_to_key(v) for k, v in self.methods.items()} + return {k: set(map(lambda x: self.method_key_lookup[x], v)) for k, v in self.methods.items()} def _wrap(self, method): """Wrap with mkdirs and logging""" diff --git a/dataforest/core/PlotTreeMethods.py b/dataforest/core/PlotTreeMethods.py index c5eae6e..5cebedb 100644 --- a/dataforest/core/PlotTreeMethods.py +++ b/dataforest/core/PlotTreeMethods.py @@ -1,6 +1,18 @@ from dataforest.core.PlotMethods import PlotMethods +from dataforest.plot.PlotWidget import PlotWidget class PlotTreeMethods(PlotMethods): def __init__(self, tree): self._tree = tree + for method_name in self.plot_method_lookup.keys(): + setattr(self, method_name, self._widget_wrap(method_name)) + + def _widget_wrap(self, method_name): + def wrap(**kwargs): + plot_key = self.method_key_lookup[method_name] + widget = PlotWidget(self._tree, plot_key, **kwargs) + return widget.build_control(show=True, stop_on_error=True) + + 
wrap.__name__ = method_name + return wrap diff --git a/dataforest/core/RunGroupSpec.py b/dataforest/core/RunGroupSpec.py index 1b9b668..88ad8f8 100644 --- a/dataforest/core/RunGroupSpec.py +++ b/dataforest/core/RunGroupSpec.py @@ -1,5 +1,5 @@ from itertools import product -from typing import Union, Iterable, List +from typing import Union, Iterable, List, Tuple, Any from dataforest.core.RunSpec import RunSpec from dataforest.core.Sweep import Sweep @@ -33,11 +33,12 @@ def _build_run_specs(self): run_specs: list of `RunSpec`s representing all combinations of values from sweeps """ - sub_groups = self._expand_sweeps(self, (list, set, tuple)) + sub_groups, combos = self._expand_sweeps(self, (list, set, tuple)) # if there were no operations in array format to be expanded if isinstance(sub_groups, dict): self._map_sweeps() - self_expand = {k: self._expand_sweeps(v, Sweep) for k, v in self.items() if isinstance(v, dict)} + # expand _SWEEP_ keys + self_expand = {k: self._expand_sweeps(v, Sweep)[0] for k, v in self.items() if isinstance(v, dict)} self_expand.update({k: v for k, v in self.items() if not isinstance(v, dict)}) # if any operations are still in array format if any(map(lambda x: isinstance(x, list), self_expand.values())): @@ -64,7 +65,8 @@ def _map_sweeps(self): self[operation][key] = Sweep(operation, key, sweep_obj) @staticmethod - def _expand_sweeps(dict_: dict, types: Union[type, Iterable[type]]) -> Union[dict, List[dict]]: + # TODO: type hint combos + def _expand_sweeps(dict_: dict, types: Union[type, Iterable[type]]) -> Tuple[Union[dict, List[dict]], Any]: """ Expand sweeps into a list of dicts representing all possible combinations. 
Args: @@ -73,7 +75,9 @@ def _expand_sweeps(dict_: dict, types: Union[type, Iterable[type]]) -> Union[dic """ sweeps_part = order_dict({k: v for k, v in dict_.items() if isinstance(v, types)}) static_part = {k: v for k, v in dict_.items() if k not in sweeps_part} - combos = product(*sweeps_part.values()) + combos = list(product(*sweeps_part.values())) dicts = [{**static_part, **dict(zip(sweeps_part.keys(), combo))} for combo in combos] - dicts = dicts[0] if len(dicts) == 1 else dicts - return dicts + if len(dicts) == 1: + dicts = dicts[0] + combos = combos[0] + return dicts, combos diff --git a/dataforest/core/TreeDataFrame.py b/dataforest/core/TreeDataFrame.py new file mode 100644 index 0000000..53a1327 --- /dev/null +++ b/dataforest/core/TreeDataFrame.py @@ -0,0 +1,73 @@ +from functools import wraps +from typing import TYPE_CHECKING, Callable, List, Iterable, Union + +import pandas as pd + +from dataforest.structures.cache.BranchCache import BranchCache + +if TYPE_CHECKING: + from dataforest.core.DataTree import DataTree + + +class DataFrameList(list): + TETHER_EXCLUDE = {"__class__", "__init__", "__weakref__", "__dict__", "__getitem__", "__setitem__"} + + def __init__(self, df_list: Iterable[Union[pd.DataFrame, pd.Series]]): + super().__init__(list(df_list)) + self._elem_class = self._get_elem_class(self) + self._tether_df_methods() + + def get_elem(self, i: int): + """ + Get the df or series from the list-like structure since getitem is + overloaded by the pandas method + Args: + i: index in list-like structure + """ + return list.__getitem__(self, i) + + @staticmethod + def _get_elem_class(container): + if all(isinstance(x, pd.DataFrame) for x in container): + return pd.DataFrame + elif all(isinstance(x, pd.Series) for x in container): + return pd.Series + + def _tether_df_methods(self): + if self._elem_class is None: + return + names = set(dir(self._elem_class)).difference(self.TETHER_EXCLUDE) + for name in names: + distributed_method = 
self._build_distributed_method(name) + setattr(self, name, distributed_method) + self._getitem = self._build_distributed_method("__getitem__") + self._setitem = self._build_distributed_method("__setitem__") + + def _build_distributed_method(self, method_name) -> Callable: + df_method = getattr(self._elem_class, method_name) + + @wraps(df_method) + def _distributed_method(*args, **kwargs): + def _split_args(i, arg): + if isinstance(arg, self.__class__): + return arg.get_elem(i) + return arg + + def _single_kernel(i, df): + args_ = [_split_args(i, arg) for arg in args] + kwargs_ = {k: _split_args(i, v) for k, v in kwargs.items()} + return df_method(df, *args_, **kwargs_) + + def _distributed_kernel(): + ret = [_single_kernel(*x) for x in enumerate(self)] + return self.__class__(ret) + + return _distributed_kernel() + + return _distributed_method + + def __getitem__(self, item): + return self._getitem(item) + + def __setitem__(self, key, value): + return self._setitem(key, value) diff --git a/dataforest/core/TreeSpec.py b/dataforest/core/TreeSpec.py index 5d13fdf..2abdff4 100644 --- a/dataforest/core/TreeSpec.py +++ b/dataforest/core/TreeSpec.py @@ -6,6 +6,7 @@ from dataforest.core.RunGroupSpec import RunGroupSpec if TYPE_CHECKING: + from dataforest.core.DataBranch import DataBranch from dataforest.core.RunSpec import RunSpec @@ -54,7 +55,10 @@ class TreeSpec(BranchSpec): def __init__(self, tree_spec: Union[List[dict], "TreeSpec[RunGroupSpec]"], twigs: Optional[List[dict]]): super(list, self).__init__() self.extend([RunGroupSpec(item) for item in tree_spec]) - self.twig_specs = dict() + self.twig_lookup = {str(twig): twig for twig in twigs} + # TODO: abstract to method + if self.twig_lookup: + self.twig_lookup = {"base": [], **self.twig_lookup} self.branch_specs = self._build_branch_specs(twigs) self.sweep_dict = {x["_PROCESS_"]: x.sweeps for x in self} self.sweep_dict["root"] = set() @@ -67,31 +71,39 @@ def __init__(self, tree_spec: Union[List[dict], 
"TreeSpec[RunGroupSpec]"], twigs incl_root=True, incl_current=True ) - def _build_branch_specs(self, twigs: Optional[List[dict]]): - specs = list(map(BranchSpec, product(*[run_group_spec.run_specs for run_group_spec in self]))) - template = specs[0] - specs.extend(self._add_twigs(template, twigs)) - return specs + @staticmethod + def add_twig(template: "BranchSpec", twig: Union[tuple, list]): + spec = deepcopy(template) + if isinstance(twig, tuple): + twig = [twig] + for mod in twig: + val = mod[-1] + final_key = mod[-2] + accessors = mod[:-2] + scope = spec + for key in accessors: + scope = scope[key] + scope[final_key] = val + return spec - def _add_twigs(self, template: "RunSpec", twigs: Optional[List[Union[tuple, list]]]): - specs = list() - self.twig_specs["base"] = template + @staticmethod + def _apply_add_twigs(specs: List["BranchSpec"], twigs: Optional[Union[tuple, list]]): if twigs is None: return specs + specs = [spec for x in specs for spec in TreeSpec._add_twigs(x, twigs)] + return specs + + @staticmethod + def _add_twigs(template: "BranchSpec", twigs: List[Union[tuple, list]]): + specs = list() for twig in twigs: - spec_ = deepcopy(template) - if isinstance(twig, tuple): - twig = [twig] - for mod in twig: - val = mod[-1] - final_key = mod[-2] - accessors = mod[:-2] - scope = spec_ - for key in accessors: - scope = scope[key] - scope[final_key] = val - specs.append(spec_) - self.twig_specs[str(twig)] = spec_ + spec = TreeSpec.add_twig(template, twig) + specs.append(spec) + return specs + + def _build_branch_specs(self, twigs: Optional[List[dict]]): + specs = list(map(BranchSpec, product(*[run_group_spec.run_specs for run_group_spec in self]))) + specs.extend(self._apply_add_twigs(specs, twigs)) return specs def __setitem__(self, key, value): diff --git a/dataforest/core/PlotWidget.py b/dataforest/plot/PlotWidget.py similarity index 53% rename from dataforest/core/PlotWidget.py rename to dataforest/plot/PlotWidget.py index b6d1098..b9748b7 100644 --- 
a/dataforest/core/PlotWidget.py +++ b/dataforest/plot/PlotWidget.py @@ -1,11 +1,13 @@ import logging from copy import deepcopy -from typing import TYPE_CHECKING, Dict, Any +from typing import TYPE_CHECKING, Dict, Any, Optional -from IPython.display import Image +from IPython.display import Image, display import ipywidgets as widgets from matplotlib.figure import Figure +from dataforest.core.TreeSpec import TreeSpec + if TYPE_CHECKING: from dataforest.core.DataBranch import DataBranch from dataforest.core.DataTree import DataTree @@ -21,35 +23,40 @@ def __init__(self, tree: "DataTree", plot_key: str, use_saved: bool = True, **pl use_saved: use saved plots or regenerate **plot_kwargs: """ + self.branch = None self._tree = tree - self._branch_spec = deepcopy(tree.tree_spec.branch_specs[0]) + self._branch_spec_template = deepcopy(tree.tree_spec.branch_specs[0]) + self._branch_spec = deepcopy(self._branch_spec_template) self._plot_key = plot_key self._use_saved = use_saved + self._bypass_kwargs = plot_kwargs.pop("bypass_kwargs", dict()) self._plot_kwargs = plot_kwargs self._sweeps = self._tree.tree_spec.sweep_dict self._plot_cache = dict() - # self.ax = plt.gca() - def control(self, **kwargs): - if self._tree.has_sweeps and self._tree.has_twigs: - logging.warning( - "PlotWidget doesn't support trees with both sweeps and twigs. Please use one or the other. Defaulting to sweeps" - ) - if self._tree.has_sweeps: - return self._sweeps_control(**kwargs) - else: - return self._twigs_control(**kwargs) + def build_control(self, **plotter_kwargs): + twig_lookup = self._tree.tree_spec.twig_lookup + _kwargs = {} - def _sweeps_control(self, **plotter_kwargs): - process = self._tree.current_process - precursors = self._tree.tree_spec.get_precursors_lookup(incl_current=True)[process] - sweeps = set.union(*[self._sweeps[precursor] for precursor in precursors]) - # {"key:operation:process": value, ...} (e.g. 
{"num_pcs:_PARAMS_:normalize": 30, ...}) - param_sweeps = {":".join(swp[:3][::-1]): list(swp[3]) for swp in sweeps} - _kwargs = {**param_sweeps, **self._plot_kwargs} + def _prep_sweeps_kwargs(): + process = self._tree.current_process + precursors = self._tree.tree_spec.get_precursors_lookup(incl_current=True)[process] + sweeps = set().union(*[self._sweeps[precursor] for precursor in precursors]) + # {"key:operation:process": value, ...} (e.g. {"num_pcs:_PARAMS_:normalize": 30, ...}) + param_sweeps = {":".join(swp[:3][::-1]): list(swp[3]) for swp in sweeps} + _kwargs.update({**param_sweeps, **self._plot_kwargs}) + + if self._tree.has_sweeps: + _prep_sweeps_kwargs() + if self._tree.has_twigs: + _kwargs["twig_str"] = list(twig_lookup.keys()) @widgets.interact(**_kwargs) + # TODO: this seems to be slower now that the spec is recalculated every time -- is it replotted? def _control(**kwargs: Dict[str, Any]): + spec = deepcopy(self._branch_spec_template) + twig_str = kwargs.pop("twig_str", None) + # add sweeps for param_keys_str, value in kwargs.items(): if param_keys_str in self._plot_kwargs: self._plot_kwargs[param_keys_str] = value @@ -57,26 +64,21 @@ def _control(**kwargs: Dict[str, Any]): elif isinstance(value, float): value = int(value) if int(value) == value else value (name, operation, process) = param_keys_str.split(":") - self._branch_spec[process][operation][name] = value + spec[process][operation][name] = value # TODO: do we need to do something with kwargs? Or taken care of? 
- [kwargs.pop(k) for k in self._plot_kwargs] - branch = self._tree._branch_cache[str(self)] - return self._get_plot(branch, **plotter_kwargs) - - return _control - - def _twigs_control(self, **plotter_kwargs): - twig_specs = self._tree.tree_spec.twig_specs - - @widgets.interact(twig_str=list(twig_specs.keys()), **self._plot_kwargs) - def _control(twig_str: str): - spec = twig_specs[twig_str] - branch = self._tree._branch_cache[str(spec)] - self._get_plot(branch, **plotter_kwargs) + [print(k, kwargs.pop(k, None)) for k in self._plot_kwargs] + # add twigs + if twig_str: + twig = twig_lookup[twig_str] + spec = TreeSpec.add_twig(spec, twig) + # get branch + self.branch = self._tree._branch_cache[str(spec)] + return self._get_plot(self.branch, **plotter_kwargs) return _control def _get_plot(self, branch, **kwargs): + kwargs = {**self._bypass_kwargs, **kwargs} plot_map = branch[self._tree.current_process]._plot_map plot_path_lookup = {plot_key: next(iter(path_dict.values())) for plot_key, path_dict in plot_map.items()} spec = branch.spec[: branch.current_process] @@ -96,16 +98,16 @@ def _get_plot(self, branch, **kwargs): return plot_obj[0] elif isinstance(plot_obj, Image): return plot_obj + elif isinstance(plot_obj, list) and all(isinstance(x, Image) for x in plot_obj): + for img in plot_obj: + display(img) + return plot_obj else: - raise TypeError(f"Expected types (matplotlib.axes.Axes, IPython.display.Image). Got {type(plot_obj)}") + raise TypeError( + f"Expected types (matplotlib.axes.Axes, IPython.display.Image, List[Ipython.display.Image]). 
Got {type(plot_obj)}" + ) def _generate_plot(self, branch: "DataBranch", **kwargs): method = branch.plot.method_lookup[self._plot_key] - kwargs = {**self._plot_kwargs, **kwargs} + kwargs = {**self._plot_kwargs, **kwargs, "show": False} return method(**kwargs) - - def __str__(self): - return str(self._branch_spec) - - def __repr__(self): - return repr(self._branch_spec) diff --git a/dataforest/processes/core/TreeProcessMethods.py b/dataforest/processes/core/TreeProcessMethods.py index 790498a..7ed4441 100644 --- a/dataforest/processes/core/TreeProcessMethods.py +++ b/dataforest/processes/core/TreeProcessMethods.py @@ -1,18 +1,20 @@ import logging -from typing import Callable, List, Union +from typing import Callable, List, Union, TYPE_CHECKING from joblib import Parallel, delayed from dataforest.core.TreeSpec import TreeSpec from dataforest.structures.cache.BranchCache import BranchCache +if TYPE_CHECKING: + from dataforest.core.DataTree import DataTree + class TreeProcessMethods: _N_CPUS_EXCLUDED = 1 - def __init__(self, tree_spec: TreeSpec, branch_cache: BranchCache): - self._tree_spec = tree_spec - self._branch_cache = branch_cache + def __init__(self, tree: "DataTree"): + self._tree = tree self._process_methods = list() self._tether_process_methods() @@ -22,7 +24,7 @@ def process_methods(self): def _tether_process_methods(self): # TODO: docstring - method_names = [run_group_spec.name for run_group_spec in self._tree_spec] + method_names = [run_group_spec.name for run_group_spec in self._tree.tree_spec] for name in method_names: distributed_method = self._build_distributed_method(name) setattr(self, name, distributed_method) @@ -68,10 +70,7 @@ def _distributed_method( """ kwargs = {"stop_on_error": stop_on_error, "stop_on_hook_error": stop_on_hook_error, **kwargs} - if not self._branch_cache.fully_loaded: - self._branch_cache.load_all() - all_branches = list(self._branch_cache.values()) - unique_branches = {str(branch.spec[:method_name]): branch for branch in 
all_branches} + unique_branches = self._tree.unique_branches_at_process(method_name) def _single_kernel(branch): branch_method = getattr(branch.process, method_name) diff --git a/requirements.txt b/requirements.txt index 445851b..b8a6d1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +fastcore ipywidgets joblib matplotlib