From 9d5f63e2e798415f7d4cf5e6cb089b794e3b9ae4 Mon Sep 17 00:00:00 2001 From: Matthew Carbone Date: Fri, 12 Aug 2022 12:35:14 -0400 Subject: [PATCH 1/5] Update everything up to and including interpolate --- aimm_post_processing/operations.py | 156 +++++++++++++---------------- 1 file changed, 68 insertions(+), 88 deletions(-) diff --git a/aimm_post_processing/operations.py b/aimm_post_processing/operations.py index 98a1383..e755918 100644 --- a/aimm_post_processing/operations.py +++ b/aimm_post_processing/operations.py @@ -1,6 +1,6 @@ """Module for housing post-processing operations.""" -from abc import ABC +from abc import ABC, abstractmethod from datetime import datetime from uuid import uuid4 @@ -10,9 +10,11 @@ from scipy.interpolate import InterpolatedUnivariateSpline from sklearn.linear_model import LinearRegression from sklearn.metrics import mean_squared_error + + from aimm_post_processing import utils -from copy import deepcopy -from abc import ABC, abstractmethod +from tiled.client.dataframe import DataFrameClient + class Operator(MSONable, ABC): """Base operator class. Tracks everything required through a combination @@ -22,93 +24,72 @@ class Operator(MSONable, ABC): .. important:: The __call__ method must be derived for every operator. In particular, - this operator should take as arguments at least one other data point - (node). + this operator should take as arguments at least one data point (node). """ - def __init__( # Initialize common parameters for all operators - self, - x_column="energy", - y_columns=["mu"], - ): - self.x_column = x_column - self.y_columns = y_columns - self.operator_id = str(uuid4()) # UID for the defined operator. + @abstractmethod + def _process_data(self): + ... + @abstractmethod + def _process_metadata(self): + ... - def __call__(self, dataDict): - # meke a copy, otherwise python will make modification to input dataDict instead. - copy_dataDict = deepcopy(dataDict) + @abstractmethod + def __call__(self, client): + ... 
- new_metadata = self._process_metadata(copy_dataDict["metadata"]) - new_df = self._process_data(copy_dataDict["data"]) - return {"data": new_df, "metadata": new_metadata} +class UnaryOperator(Operator): + """Specialized operator class which takes only a single input. This input + must be of instance :class:`DataFrameClient`.""" def _process_metadata(self, metadata): - """Preliminary pre-processing of the dictionary object that contains data and metadata. - Takes the:class:`dict` object as input and returns the untouched data in addition to an - augmented metadata dictionary. + """Processing of the metadata dictionary object. Takes the + :class:`dict` object as input and returns a modified + dictionary with the following changes: + + 1. metadata["_tiled"]["uid"] is replaced with a new uuid string. + 2. metadata["post_processing"] is created with keys that indicate + the current state of the class, the parent ids Parameters - --------- - dataDict : dict - The data dictionary that contains data and metadata - local_kwargs : tuple - A tuple (usually `locals().items()` where it is called) of local variables in operator - space. + ---------- + metadata : dict + The metadata dictionary accessed via ``df_client.metadata``. - Notes - ----- - 1. Measurement `id` is suspiciously `_id` that sits under `sample` in the metadata. - Need to check if this hierchy is universal for all data. + Returns + ------- + dict + The new metadata object for the post-processed child. """ - # parents are the uid of the last processed data, or the original sample id otherwise. 
- try: - parent_id = metadata["post_processing"]["id"] - except: - parent_id = metadata['_tiled']['uid'] - dt = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") - metadata["post_processing"] = { - "id": str(uuid4()), - "parent": parent_id, - "operator": self.as_dict(), - "kwargs": self.__dict__, - "datetime": f"{dt} UTC", + return { + "_tiled": {"uid": str(uuid4())}, # Assign a new uid + "post_processing": { + "parents": [metadata["_tiled"]["uid"]], + "operator": self.as_dict(), + "kwargs": self.__dict__, + "datetime": f"{dt} UTC", + } } - return metadata - - @abstractmethod - def _process_data(self, df) -> dict: - """User must override this method. - """ - raise NotImplementedError - - -class Pull(Operator): - def __init__(self): - super().__init__() - - def __call__(self, dfClient): - """ This operator does nothing but return the data/metadata dictionary for a given - `tiled.client.dataframe.DataFrameClient`. - """ - metadata = deepcopy(dict(dfClient.metadata)) - new_metadata = self._process_metadata(metadata) - df = deepcopy(dfClient.read()) - new_df = self._process_data(df) - - return {"data": new_df, "metadata": new_metadata} - - def _process_data(self, df): - return df + def __call__(self, df_client): + if not isinstance(df_client, DataFrameClient): + raise ValueError( + f"df_client is of type {type(df_client)} but should be of " + "type DataFrameClient" + ) + return { + "data": self._process_data(df_client.read()), + "metadata": self._process_metadata(df_client.metadata) + } -class Identity(Operator): - """The identity operation. Does nothing. Used for testing purposes.""" +class Identity(UnaryOperator): + """The identity operation. Does nothing. Primarily used for testing + purposes.""" def __init__(self): super().__init__() @@ -117,22 +98,23 @@ def _process_data(self, df): """ Parameters ---------- - dataDict : pandas.DataFrame - The data is a :class:`pd.DataFrame` + df : pandas.DataFrame + The dataframe object to process. 
Returns ------- - dataDict : dict - Same as input + pandas.DataFrame """ + return df -class StandardizeGrid(Operator): +class StandardizeGrid(UnaryOperator): """Interpolates specified columns onto a common grid.""" def __init__( self, + *, x0, xf, nx, @@ -151,24 +133,22 @@ def __init__( The upper bound of the grid to interpolate onto. nx : int The number of interpolation points. - interpolated_univariate_spline_kwargs : TYPE, optional + interpolated_univariate_spline_kwargs : dict, optional Keyword arguments to be passed to the - :class:`InterpolatedUnivariateSpline` class. See + :class:`InterpolatedUnivariateSpline`. See [here](https://docs.scipy.org/doc/scipy/reference/generated/ scipy.interpolate.InterpolatedUnivariateSpline.html) for the - documentation on this class. Default is {}. + documentation on this class. x_column : str, optional References a single column in the DataFrameClient (this is the - "x-axis"). Default is "energy". + "x-axis"). y_columns : list, optional References a list of columns in the DataFrameClient (these are the - "y-axes"). Default is ["mu"]. - - Returns: - ------- - An instance of StandardGrid operator. + "y-axes"). 
""" - super().__init__(x_column, y_columns) + + self.x_column = x_column + self.y_columns = y_columns self.x0 = x0 self.xf = xf self.nx = nx @@ -412,4 +392,4 @@ def _process_data(self, df): class PreNormalize(Operator): """ - """ \ No newline at end of file + """ From a118c58ea74937d76bc4f1f4c5e252f887ea08d0 Mon Sep 17 00:00:00 2001 From: Matthew Carbone Date: Fri, 12 Aug 2022 12:49:56 -0400 Subject: [PATCH 2/5] Refactor RemoveBackground --- aimm_post_processing/operations.py | 253 ++++++++++++----------------- 1 file changed, 108 insertions(+), 145 deletions(-) diff --git a/aimm_post_processing/operations.py b/aimm_post_processing/operations.py index e755918..a5e2282 100644 --- a/aimm_post_processing/operations.py +++ b/aimm_post_processing/operations.py @@ -9,7 +9,8 @@ import pandas as pd from scipy.interpolate import InterpolatedUnivariateSpline from sklearn.linear_model import LinearRegression -from sklearn.metrics import mean_squared_error + +# from sklearn.metrics import mean_squared_error from aimm_post_processing import utils @@ -37,7 +38,7 @@ def _process_metadata(self): @abstractmethod def __call__(self, client): - ... + ... class UnaryOperator(Operator): @@ -48,11 +49,11 @@ def _process_metadata(self, metadata): """Processing of the metadata dictionary object. Takes the :class:`dict` object as input and returns a modified dictionary with the following changes: - + 1. metadata["_tiled"]["uid"] is replaced with a new uuid string. 2. 
metadata["post_processing"] is created with keys that indicate the current state of the class, the parent ids - + Parameters ---------- metadata : dict @@ -72,7 +73,7 @@ def _process_metadata(self, metadata): "operator": self.as_dict(), "kwargs": self.__dict__, "datetime": f"{dt} UTC", - } + }, } def __call__(self, df_client): @@ -83,7 +84,7 @@ def __call__(self, df_client): ) return { "data": self._process_data(df_client.read()), - "metadata": self._process_metadata(df_client.metadata) + "metadata": self._process_metadata(df_client.metadata), } @@ -110,7 +111,29 @@ def _process_data(self, df): class StandardizeGrid(UnaryOperator): - """Interpolates specified columns onto a common grid.""" + """Interpolates specified columns onto a common grid. + + Parameters + ---------- + x0 : float + The lower bound of the grid to interpolate onto. + xf : float + The upper bound of the grid to interpolate onto. + nx : int + The number of interpolation points. + interpolated_univariate_spline_kwargs : dict, optional + Keyword arguments to be passed to the + :class:`InterpolatedUnivariateSpline`. See + [here](https://docs.scipy.org/doc/scipy/reference/generated/ + scipy.interpolate.InterpolatedUnivariateSpline.html) for the + documentation on this class. + x_column : str, optional + References a single column in the DataFrameClient (this is the + "x-axis"). + y_columns : list, optional + References a list of columns in the DataFrameClient (these are the + "y-axes"). + """ def __init__( self, @@ -120,42 +143,18 @@ def __init__( nx, interpolated_univariate_spline_kwargs=dict(), x_column="energy", - y_columns=["mu"] + y_columns=["mu"], ): - """Interpolates the provided DataFrameClient onto a grid as specified - by the provided parameters. - - Parameters - ---------- - x0 : float - The lower bound of the grid to interpolate onto. - xf : float - The upper bound of the grid to interpolate onto. - nx : int - The number of interpolation points. 
- interpolated_univariate_spline_kwargs : dict, optional - Keyword arguments to be passed to the - :class:`InterpolatedUnivariateSpline`. See - [here](https://docs.scipy.org/doc/scipy/reference/generated/ - scipy.interpolate.InterpolatedUnivariateSpline.html) for the - documentation on this class. - x_column : str, optional - References a single column in the DataFrameClient (this is the - "x-axis"). - y_columns : list, optional - References a list of columns in the DataFrameClient (these are the - "y-axes"). - """ - - self.x_column = x_column - self.y_columns = y_columns self.x0 = x0 self.xf = xf self.nx = nx - self.interpolated_univariate_spline_kwargs = interpolated_univariate_spline_kwargs - - def _process_data(self, df): + self.interpolated_univariate_spline_kwargs = ( + interpolated_univariate_spline_kwargs + ) + self.x_column = x_column + self.y_columns = y_columns + def _process_data(self, df): """Takes in a dictionary of the data amd metadata. The data is a :class:`pd.DataFrame`, and the metadata is itself a dictionary. Returns the same dictionary with processed data and metadata. @@ -175,100 +174,69 @@ def _process_data(self, df): class RemoveBackground(Operator): - """Fit the pre-edge region to a victoreen function and subtract it from the spectrum. + """Fit the pre-edge region to a Victoreen function and subtract it from the spectrum. + + Parameters + ---------- + x0 : float + The lower bound of energy range on which the background is fitted. + xf : float + The upper bound of energy range on which the background is fitted. + x_column : str, optional + References a single column in the DataFrameClient (this is the + "x-axis"). + y_columns : list, optional + References a list of columns in the DataFrameClient (these are the + "y-axes"). + victoreen_order : int + The order of Victoreen function. The selected data is fitted to + Victoreen pre-edge function (in which one fits a line to μ(E)*E^n for + some value of n. 
""" + def __init__( - self, - *, - x0, - xf, - x_column="energy", - y_columns=["mu"], - victoreen_order=0 + self, *, x0, xf, x_column="energy", y_columns=["mu"], victoreen_order=0 ): - """Subtract background from data. - Fit the pre-edge data to a line with slope, and subtract slope info from data. - - Parameters - ---------- - dfClient : tiled.client.dataframe.DataFrameClient - x0 : float - The lower bound of energy range on which the background is fitted. - xf : flaot - The upper bound of energy range on which the background is fitted. - x_column : str, optional - References a single column in the DataFrameClient (this is the - "x-axis"). Default is "energy". - y_columns : list, optional - References a list of columns in the DataFrameClient (these are the - "y-axes"). Default is ["mu"]. - victoreen_order : int - The order of Victoreen function. The selected data is fitted to Victoreen pre-edge - function (in which one fits a line to μ(E)*E^n for some value of n. Default is 0, - which is a linear fit. - - Returns - ------- - An instance of RemoveBackground operator - """ - super().__init__(x_column, y_columns) self.x0 = x0 self.xf = xf self.victoreen_order = victoreen_order + self.x_column = x_column + self.y_columns = y_columns def _process_data(self, df): """ - Takes in a dictionary of the data amd metadata. The data is a - :class:`pd.DataFrame`, and the metadata is itself a dictionary. - Returns the same dictionary with processed data and metadata. - Notes ----- - `LinearRegression().fit()` takes 2-D arrays as input. This can be explored - for batch processing of multiple spectra + `LinearRegression().fit()` takes 2-D arrays as input. 
This can be + explored for batch processing of multiple spectra """ - bg_data = df.loc[(df[self.x_column] >= self.x0) * (df[self.x_column] < self.xf)] + bg_data = df.loc[ + (df[self.x_column] >= self.x0) * (df[self.x_column] < self.xf) + ] new_data = {self.x_column: df[self.x_column]} for column in self.y_columns: - y = bg_data[column] * bg_data[self.x_column]**self.victoreen_order + y = bg_data[column] * bg_data[self.x_column] ** self.victoreen_order reg = LinearRegression().fit( - bg_data[self.x_column].to_numpy().reshape(-1,1), - y.to_numpy().reshape(-1,1) + bg_data[self.x_column].to_numpy().reshape(-1, 1), + y.to_numpy().reshape(-1, 1), + ) + background = reg.predict( + df[self.x_column].to_numpy().reshape(-1, 1) + ) + new_data[column] = ( + df.loc[:, column].to_numpy() - background.flatten() ) - background = reg.predict(df[self.x_column].to_numpy().reshape(-1,1)) - new_data[column] = df.loc[:,column].to_numpy() - background.flatten() return pd.DataFrame(new_data) -# class Normalize(Operator): -# """ -# """ -# def __init__( -# self, -# x_column="energy", -# y_columns=["mu"] -# ): -# super().__init__(x_column, y_columns) - -# def _process_data(self, df): -# xas_ds = XASDataSet(name="Shift XANES", energy=grid, mu=dd) -# xas_ds.norm1 = norm1 # update atribute for force_normalization -# xas_ds.normalize_force() # force the normalization again with updated atribute - - class StandardizeIntensity(Operator): - """ Scale the intensity so they vary in similar range. - """ + """Scale the intensity so they vary in similar range.""" + def __init__( - self, - *, - x0 = None, - xf = None, - x_column="energy", - y_columns=["mu"] + self, *, x0=None, xf=None, x_column="energy", y_columns=["mu"] ): """Align the intensity to the mean of a selected range, and scale the intensity up to standard deviation. @@ -277,10 +245,10 @@ def __init__( ---------- dfClient : tiled.client.dataframe.DataFrameClient x0 : float - The lower bound of energy range for which the mean is calculated. 
If None, the first + The lower bound of energy range for which the mean is calculated. If None, the first point in the energy grid is used. Default is None. yf : float - The upper bound of energy range for which the mean is calculated. If None, the last + The upper bound of energy range for which the mean is calculated. If None, the last point in the energy grid is used. Default is None. x_column : str, optional References a single column in the DataFrameClient (this is the @@ -288,7 +256,7 @@ def __init__( y_columns : list, optional References a list of columns in the DataFrameClient (these are the "y-axes"). Default is ["mu"]. - + Returns ------- An instance of StandardizeIntensity operator @@ -306,42 +274,39 @@ def _process_data(self, df): """ grid = df.loc[:, self.x_column] - if self.x0 is None: self.x0 = grid[0] - if self.xf is None: self.xf = grid[-1] + if self.x0 is None: + self.x0 = grid[0] + if self.xf is None: + self.xf = grid[-1] assert self.x0 < self.xf, "Invalid range, make sure x0 < xf" select_mean_range = (grid > self.x0) & (grid < self.xf) - + new_data = {self.x_column: df[self.x_column]} for column in self.y_columns: mu = df.loc[:, column] mu_mean = mu[select_mean_range].mean() mu_std = mu.std() - new_data.update({column: (mu-mu_mean)/mu_std}) - + new_data.update({column: (mu - mu_mean) / mu_std}) + return pd.DataFrame(new_data) class Smooth(Operator): """Return the simple moving average of spectra with a rolling window. - Parameters - ---------- - - window : float, in eV. - The rolling window in eV over which the average intensity is taken. - x_column : str, optional - References a single column in the DataFrameClient (this is the - "x-axis"). Default is "energy". - y_columns : list, optional - References a list of columns in the DataFrameClient (these are the - "y-axes"). Default is ["mu"]. + Parameters + ---------- + + window : float, in eV. + The rolling window in eV over which the average intensity is taken. 
+ x_column : str, optional + References a single column in the DataFrameClient (this is the + "x-axis"). Default is "energy". + y_columns : list, optional + References a list of columns in the DataFrameClient (these are the + "y-axes"). Default is ["mu"]. """ - def __init__( - self, - *, - window=10, - x_column='energy', - y_columns=['mu'] - ): + + def __init__(self, *, window=10, x_column="energy", y_columns=["mu"]): super().__init__(x_column, y_columns) self.window = window @@ -350,29 +315,29 @@ def _apply(self, df): Takes in a dictionary of the data amd metadata. The data is a :class:`pd.DataFrame`, and the metadata is itself a dictionary. Returns the same dictionary with processed data and metadata. - + Returns: -------- dict - A dictionary of the data and metadata. The data is a :class:`pd.DataFrame`, + A dictionary of the data and metadata. The data is a :class:`pd.DataFrame`, and the metadata is itself a dictionary. """ - - grid = df.loc[:,self.x_column] + + grid = df.loc[:, self.x_column] new_data = {self.x_column: df[self.x_column]} for column in self.y_columns: y = df.loc[:, column] y_smooth = utils.simple_moving_average(grid, y, window=self.window) new_data.update({column: y_smooth}) - mse = mean_squared_error(y, y_smooth) - n2s = mse / y_smooth.std() + # mse = mean_squared_error(y, y_smooth) + # n2s = mse / y_smooth.std() return pd.DataFrame(new_data) class Classify(Operator): - """ Label the spectrum as "good", "noisy" or "discard" based on the quality of the spectrum. 
- """ + """Label the spectrum as "good", "noisy" or "discard" based on the quality of the spectrum.""" + def __init__(self, classifier): super().__init__() self.classifier = classifier @@ -389,7 +354,5 @@ def _process_data(self, df): return df - class PreNormalize(Operator): - """ - """ + """ """ From 11d6bfa55da838c5ed3eaa333a2c522230d2c51f Mon Sep 17 00:00:00 2001 From: Matthew Carbone Date: Fri, 12 Aug 2022 14:06:20 -0400 Subject: [PATCH 3/5] Finalize refactor of operations.py --- aimm_post_processing/operations.py | 109 ++++++++++++----------------- 1 file changed, 45 insertions(+), 64 deletions(-) diff --git a/aimm_post_processing/operations.py b/aimm_post_processing/operations.py index a5e2282..4666c13 100644 --- a/aimm_post_processing/operations.py +++ b/aimm_post_processing/operations.py @@ -10,9 +10,6 @@ from scipy.interpolate import InterpolatedUnivariateSpline from sklearn.linear_model import LinearRegression -# from sklearn.metrics import mean_squared_error - - from aimm_post_processing import utils from tiled.client.dataframe import DataFrameClient @@ -173,8 +170,9 @@ def _process_data(self, df): return pd.DataFrame(new_data) -class RemoveBackground(Operator): - """Fit the pre-edge region to a Victoreen function and subtract it from the spectrum. +class RemoveBackground(UnaryOperator): + """Fit the pre-edge region to a Victoreen function and subtract it from the + spectrum. Parameters ---------- @@ -190,8 +188,8 @@ class RemoveBackground(Operator): "y-axes"). victoreen_order : int The order of Victoreen function. The selected data is fitted to - Victoreen pre-edge function (in which one fits a line to μ(E)*E^n for - some value of n. + Victoreen pre-edge function (in which one fits a line to + :math:`E^n \\mu(E)` for some value of :math:`n`. """ def __init__( @@ -207,7 +205,7 @@ def _process_data(self, df): """ Notes ----- - `LinearRegression().fit()` takes 2-D arrays as input. This can be + ``LinearRegression().fit()`` takes 2-D arrays as input. 
This can be explored for batch processing of multiple spectra """ @@ -232,38 +230,32 @@ def _process_data(self, df): return pd.DataFrame(new_data) -class StandardizeIntensity(Operator): - """Scale the intensity so they vary in similar range.""" - - def __init__( - self, *, x0=None, xf=None, x_column="energy", y_columns=["mu"] - ): - """Align the intensity to the mean of a selected range, and scale the intensity up to standard - deviation. +class StandardizeIntensity(UnaryOperator): + """Scale the intensity so they vary in similar range. Specifically, aligns + the intensity to the mean of a selected range, and scale the intensity up + to standard deviation. - Parameters - ---------- - dfClient : tiled.client.dataframe.DataFrameClient - x0 : float - The lower bound of energy range for which the mean is calculated. If None, the first - point in the energy grid is used. Default is None. - yf : float - The upper bound of energy range for which the mean is calculated. If None, the last - point in the energy grid is used. Default is None. - x_column : str, optional - References a single column in the DataFrameClient (this is the - "x-axis"). Default is "energy". - y_columns : list, optional - References a list of columns in the DataFrameClient (these are the - "y-axes"). Default is ["mu"]. + Parameters + ---------- + x0 : float + The lower bound of energy range for which the mean is calculated. If None, the first + point in the energy grid is used. Default is None. + yf : float + The upper bound of energy range for which the mean is calculated. If None, the last + point in the energy grid is used. Default is None. + x_column : str, optional + References a single column in the DataFrameClient (this is the + "x-axis"). Default is "energy". + y_columns : list, optional + References a list of columns in the DataFrameClient (these are the + "y-axes"). Default is ["mu"]. 
+ """ - Returns - ------- - An instance of StandardizeIntensity operator - """ - super().__init__(x_column, y_columns) + def __init__(self, *, x0, xf, x_column="energy", y_columns=["mu"]): self.x0 = x0 self.xf = xf + self.x_column = x_column + self.y_columns = y_columns def _process_data(self, df): """ @@ -291,26 +283,27 @@ def _process_data(self, df): return pd.DataFrame(new_data) -class Smooth(Operator): +class Smooth(UnaryOperator): """Return the simple moving average of spectra with a rolling window. + Parameters ---------- - window : float, in eV. The rolling window in eV over which the average intensity is taken. x_column : str, optional - References a single column in the DataFrameClient (this is the - "x-axis"). Default is "energy". + References a single column in the DataFrameClient (this is the + "x-axis"). y_columns : list, optional References a list of columns in the DataFrameClient (these are the - "y-axes"). Default is ["mu"]. + "y-axes"). """ - def __init__(self, *, window=10, x_column="energy", y_columns=["mu"]): - super().__init__(x_column, y_columns) + def __init__(self, *, window=10.0, x_column="energy", y_columns=["mu"]): self.window = window + self.x_column = x_column + self.y_columns = y_columns - def _apply(self, df): + def _process_data(self, df): """ Takes in a dictionary of the data amd metadata. The data is a :class:`pd.DataFrame`, and the metadata is itself a dictionary. 
@@ -329,30 +322,18 @@ def _apply(self, df): y = df.loc[:, column] y_smooth = utils.simple_moving_average(grid, y, window=self.window) new_data.update({column: y_smooth}) - # mse = mean_squared_error(y, y_smooth) - # n2s = mse / y_smooth.std() return pd.DataFrame(new_data) -class Classify(Operator): - """Label the spectrum as "good", "noisy" or "discard" based on the quality of the spectrum.""" - - def __init__(self, classifier): - super().__init__() - self.classifier = classifier - - def _process_data(self, df): - """ - Parameters - ---------- - dfClient : tiled.client.dataframe.DataFrameClient - classifier : Callable - The classifier that takes in the spectrum and output a label. +# TODO +class Classify(UnaryOperator): + """Label the spectrum as "good", "noisy" or "discard" based on the quality + of the spectrum.""" - """ - return df + ... -class PreNormalize(Operator): - """ """ +# TODO +class PreNormalize(UnaryOperator): + ... From de85d5aff49bb7545f768967389c22d919404edd Mon Sep 17 00:00:00 2001 From: Matthew Carbone Date: Fri, 12 Aug 2022 15:30:53 -0400 Subject: [PATCH 4/5] Remove outdated client, add batch functionality to UnaryOperator --- aimm_post_processing/client.py | 130 ----------------------------- aimm_post_processing/operations.py | 44 ++++++++-- 2 files changed, 36 insertions(+), 138 deletions(-) delete mode 100644 aimm_post_processing/client.py diff --git a/aimm_post_processing/client.py b/aimm_post_processing/client.py deleted file mode 100644 index 95b7d18..0000000 --- a/aimm_post_processing/client.py +++ /dev/null @@ -1,130 +0,0 @@ -from tiled.client import from_uri -from tiled.query_registration import register -from dataclasses import dataclass - -import collections -import json - - -@register(name="raw_mongo", overwrite=True) -@dataclass -class RawMongo: - """Run a MongoDB query against a given collection.""" - - query: str # We cannot put a dict in a URL, so this a JSON str. 
- - def __init__(self, query): - if isinstance(query, collections.abc.Mapping): - query = json.dumps(query) - self.query = query - - -################## - - -# TODO: figure out a better way to document the return values for this function -def search_sets_in_child( - root, child_names, search_symbol=None, search_edge=None -): - """Walks down a branch of nodes to get a specified child node and use a - search criteria. - - Parameters - ---------- - root : tiled.client.node.Node - The parent node that the client will use to start going down using - the specified branch in child_names. - child_names : list of tiled.client.node.Node - List of subsequent child nodes. They must be sorted in the same way - they were created in the tree. - search_symbol : str, optional - Search criteria for the element symbol (aka the element, e.g. "Cu"). - Default is None. - search_edge : str, optional - Search criteria for the spectroscopy edge (e.g. the K-edge, "K", or the - L3 edge "L3"). Default is None. - - Returns - ------- - tiled.client.node.Node - A structure representing an element in the tree. If the element is the - parent node or a node in the middle of the tree, it returns - tiled.client.node.Node. If it reaches the last node of a branch, - it returns a client structure that represents the type of data - structure that it contains; e.g. - tiled.client.dataframe.DataFrameClient or - tiled.client.array.ArrayClient. 
- """ - - if isinstance(child_names, list): - child_names = tuple(child_names) - - # Search subsequent child nodes along the tree - child_node = root - for child in child_names: - child_node = child_node[child] - - if search_symbol is None and search_edge is None: - return child_node - - if search_symbol is not None: - child_node = child_node.search(symbol(search_symbol)) - - if search_edge is not None: - child_node = child_node.search(edge(search_edge)) - - return child_node - - -def symbol(symbol): - """Wrapper method to generate a RawMongo query to search for an element - symbol - - Parameters - ---------- - symbol : str - Query parameter used to search for an element symbol. - - Returns - ------- - tiled.client.node.Node - If a match was found, it returns a client node that includes all the - child nodes containing each individual dataset. - """ - - return RawMongo({"metadata.element.symbol": symbol}) - - -def edge(edge): - """Wrapper method to generate a RawMongo query to search for an element - edge. - - Parameters - ---------- - edge : str - The parent node that the client will use to start going down using the - specified branch in child_names. - - Returns - ------- - tiled.client.node.Node - If a match was found, it returns a client node that includes all the - child nodes containing each individual dataset. - """ - - return RawMongo({"metadata.element.edge": edge}) - - -if __name__ == "__main__": - - # Example Code - client = from_uri("https://aimm.lbl.gov/api") - child = ["NCM", "BM_NCMA"] - result_nodes = search_sets_in_child( - client, child, search_symbol="Ni", search_edge="L3" - ) - - # TODO: Test multiple cases by passing a list with multiple paths. - # What is the best data structure to use as a container for the results? 
- # children = [['NCM', 'BM_NCMA'], - # ['NCM', 'BM_NCM622']] diff --git a/aimm_post_processing/operations.py b/aimm_post_processing/operations.py index 4666c13..c0999a3 100644 --- a/aimm_post_processing/operations.py +++ b/aimm_post_processing/operations.py @@ -12,6 +12,7 @@ from aimm_post_processing import utils from tiled.client.dataframe import DataFrameClient +from tiled.client.node import Node class Operator(MSONable, ABC): @@ -40,7 +41,16 @@ def __call__(self, client): class UnaryOperator(Operator): """Specialized operator class which takes only a single input. This input - must be of instance :class:`DataFrameClient`.""" + must be of instance :class:`DataFrameClient`. + + Particularly, the operator object's ``__call__`` method can be executed on + either a :class:`DataFrameClient` or :class:`Node` object. If run on the + :class:`DataFrameClient`, the operator call will return a single dictionary + with the keys "data" and "metadata", as one would expect. If the input is + of type :class:`Node`, then an attempt is made to iterate through all + children of that node, executing the operator on each instance + individually. A list of dictionaries is then returned. + """ def _process_metadata(self, metadata): """Processing of the metadata dictionary object. 
Takes the @@ -73,16 +83,34 @@ def _process_metadata(self, metadata): }, } - def __call__(self, df_client): - if not isinstance(df_client, DataFrameClient): + def _call_on_client(self, client): + return { + "data": self._process_data(client.read()), + "metadata": self._process_metadata(client.metadata), + } + + def __call__(self, client): + if isinstance(client, DataFrameClient): + return self._call_on_client(client) + + elif isinstance(client, Node): + # Apply the operator to each of the instances in the node + values = [value for value in client.values()] + if not all([isinstance(v, DataFrameClient) for v in values]): + raise RuntimeError( + "Provided client when iterated on has produced entries " + "that are not DataFrameClient objects. This is likely " + "due to passing a query such as " + 'df_client = c["edge"]["K"], when a query like ' + 'df_client = c["edge"]["K"]["uid"] is required' + ) + return list(map(self._call_on_client, values)) + + else: raise ValueError( - f"df_client is of type {type(df_client)} but should be of " + f"client is of type {type(client)} but should be of " "type DataFrameClient" ) - return { - "data": self._process_data(df_client.read()), - "metadata": self._process_metadata(df_client.metadata), - } class Identity(UnaryOperator): From 470f10866098236018d9f5ed1abb8ede0c625091 Mon Sep 17 00:00:00 2001 From: Matthew Carbone Date: Fri, 12 Aug 2022 16:09:53 -0400 Subject: [PATCH 5/5] Add basic tutorial --- 00_basic_tutorial.ipynb | 405 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 405 insertions(+) create mode 100644 00_basic_tutorial.ipynb diff --git a/00_basic_tutorial.ipynb b/00_basic_tutorial.ipynb new file mode 100644 index 0000000..4407d41 --- /dev/null +++ b/00_basic_tutorial.ipynb @@ -0,0 +1,405 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "8a39bdeb-5d2f-4857-acfe-b9678ba25ae1", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%config 
Completer.use_jedi = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f35f534b-9692-4ecd-94c9-24efc0e2dd0b", + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "from tiled.client import from_uri\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa1941aa-87af-4023-a63b-95f8d96b437e", + "metadata": {}, + "outputs": [], + "source": [ + "mpl.rcParams['mathtext.fontset'] = 'stix'\n", + "mpl.rcParams['font.family'] = 'STIXGeneral'\n", + "mpl.rcParams['text.usetex'] = True\n", + "plt.rc('xtick', labelsize=12)\n", + "plt.rc('ytick', labelsize=12)\n", + "plt.rc('axes', labelsize=12)\n", + "mpl.rcParams['figure.dpi'] = 300" + ] + }, + { + "cell_type": "markdown", + "id": "43445174-ee13-47ac-9cb1-26482e1739eb", + "metadata": {}, + "source": [ + "# Basic Tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "f5260357-c5a8-41fe-8b5d-f728d96b5963", + "metadata": {}, + "source": [ + "The [AIMM post-processing pipeline](https://github.com/AI-multimodal/aimm-post-processing) is built around the `Operator` object. The `Operator`'s job is to take a `client`-like object and execute a post-processing operation on it. The specific type of operation is defined by the operator. All metadata/provenance is tracked." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ceb1b1f8-f109-4d05-a98d-64b377514934", + "metadata": {}, + "outputs": [], + "source": [ + "from aimm_post_processing import operations" + ] + }, + { + "cell_type": "markdown", + "id": "bc4a0fb0-7da9-4ea5-80ec-b339a6b85fe9", + "metadata": {}, + "source": [ + "Connect to the `tiled` client. This one is the [aimmdb](https://github.com/AI-multimodal/aimmdb) hosted at [aimm.lbl.gov](https://aimm.lbl.gov/api). Note that my API key is stored in an environment variable, `TILED_API_KEY`. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3e40949-1d2c-41bd-b207-43ccde4c0f0b", + "metadata": {}, + "outputs": [], + "source": [ + "CLIENT = from_uri(\"https://aimm.lbl.gov/api\")" + ] + }, + { + "cell_type": "markdown", + "id": "c7ae73e7-d693-4c9b-a253-b13e299a3d05", + "metadata": {}, + "source": [ + "## Unary operators" + ] + }, + { + "cell_type": "markdown", + "id": "33d644b3-4233-49b0-90f6-339ad3e8db56", + "metadata": {}, + "source": [ + "A [unary operator](https://en.wikipedia.org/wiki/Unary_operation) takes a single input. This input specifically refers to the fact that these operators only act on a single data point (meaning a `DataFrameClient`) at a time. We'll provide some examples here.\n", + "\n", + "First, lets get a single `DataFrameClient` object:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23cb9f6a-47a7-41ae-9b33-9796cceb8f5a", + "metadata": {}, + "outputs": [], + "source": [ + "df_client = CLIENT[\"dataset\"][\"newville\"][\"edge\"][\"K\"][\"element\"][\"Co\"][\"uid\"][\"Bt5hUbgkfzR\"]\n", + "type(df_client)" + ] + }, + { + "cell_type": "markdown", + "id": "eb29a693-fedf-4369-b12b-3eb0afd4e30f", + "metadata": {}, + "source": [ + "### The identity" + ] + }, + { + "cell_type": "markdown", + "id": "6963e452-9cb0-4ac8-95d1-3c44a736ae7f", + "metadata": {}, + "source": [ + "The simplest operation we can perform is nothing! Let's see what it does. First, feel free to print the output of the `df_client` so you can see what's contained. 
Using the `read()` method will allow you to access the actual data, and the `metadata` property will allow you to access the metadata:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56a35197-ea9b-4164-82fa-18576486ba97", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "df_client.read() # is a pandas.DataFrame\n", + "df_client.metadata # is a python dictionary" + ] + }, + { + "cell_type": "markdown", + "id": "3e06093f-e79a-42ea-940e-5a008cfdcab9", + "metadata": {}, + "source": [ + "The identity operator is instantiated and then run on the `df_client`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad8a8f5e-72f7-404c-87c8-c82ef5394c3a", + "metadata": {}, + "outputs": [], + "source": [ + "op = operations.Identity()\n", + "result = op(df_client)" + ] + }, + { + "cell_type": "markdown", + "id": "24e0e9a1-e808-4e56-b9c1-ced7d0461d04", + "metadata": {}, + "source": [ + "Every result of any operator will be a dictionary with two keys: `\"data\"` and `\"metadata\"`, which correspond to the results of `read()` and `metadata` above. The data is the correspondingly modified `pandas.DataFrame` object (which in the case of the identity, is of course the same as what we started with). The metadata is custom created for a derived, post-processed object.\n", + "\n", + "First, let's check that the original and \"post-processed\" data are the same." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a66a114-f2d9-49a4-b3d1-f27bf18a1c34", + "metadata": {}, + "outputs": [], + "source": [ + "assert (df_client.read() == result[\"data\"]).all().all()" + ] + }, + { + "cell_type": "markdown", + "id": "a590df1c-602d-4c7f-8d5d-661930116b92", + "metadata": {}, + "source": [ + "Next, the metadata:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4707f59f-fd44-4414-a565-931fb226c06a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "result[\"metadata\"]" + ] + }, + { + "cell_type": "markdown", + "id": "7d5b78e3-f9dd-4712-8409-5c64c006ad50", + "metadata": {}, + "source": [ + "First, a new unique id is assigned. Second, given this is a derived quantity, the previous original metadata is now gone in place of a `post_processing` key. This key contains every bit of information needed for provenance, including the parents (which is just one in the case of a unary operator), the operator details (including code version), any keyword arguments used during instantiation, and the datetime at which the opration was run. We use the [MSONable](https://pythonhosted.org/monty/_modules/monty/json.html) library to take care of most of this for us.\n", + "\n", + "We can compare against the original metadata to see the differences." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc1bd8b7-5038-4b1f-aebe-0b60762986e1", + "metadata": {}, + "outputs": [], + "source": [ + "df_client.metadata" + ] + }, + { + "cell_type": "markdown", + "id": "92af26f9-2c72-4882-b672-7f7ea7fcb5aa", + "metadata": {}, + "source": [ + "### Standardizing the grids" + ] + }, + { + "cell_type": "markdown", + "id": "3b36155c-2066-4d78-a8a9-df73b2bfdece", + "metadata": {}, + "source": [ + "Often times (and especially for e.g. machine learning applications) we need to interpolate our spectral data onto a common grid. We can do this easily with the `StandardizeGrid` unary operator." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "831f8b6f-499f-471c-8ab9-498a720aaee6", + "metadata": {}, + "outputs": [], + "source": [ + "op = operations.StandardizeGrid(x0=7550.0, xf=8900.0, nx=100, x_column=\"energy\", y_columns=[\"itrans\"])\n", + "result = op(df_client)" + ] + }, + { + "cell_type": "markdown", + "id": "1e684792-589e-4857-b5d4-9e5d30a278d6", + "metadata": {}, + "source": [ + "Here's a visualization of what it's done:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b239efd-1c9b-4dee-85d2-639d6827dde4", + "metadata": {}, + "outputs": [], + "source": [ + "d0 = df_client.read()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a70e7bae-cdb9-4b4e-9b2f-f73872f8d177", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n", + "ax.plot(d0[\"energy\"], d0[\"itrans\"], 'k-')\n", + "ax.plot(result[\"data\"][\"energy\"], result[\"data\"][\"itrans\"], 'r-')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "16a49a37-e025-458b-90b7-ec79c9834b74", + "metadata": {}, + "source": [ + "### Batch processing" + ] + }, + { + "cell_type": "markdown", + "id": "34b4fe25-53af-4ca8-9455-ae7cce45d0a7", + "metadata": {}, + "source": [ + "While a unary operator acts on only a single input, there are cases where we might wish to apply the same operator to a list of `client`-like objects. The operator `__call__` can handle this. For example, consider the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a782fd6-8340-424d-9464-88b7a90ae4a7", + "metadata": {}, + "outputs": [], + "source": [ + "node_client = CLIENT[\"edge\"][\"L3\"][\"uid\"]\n", + "node_client" + ] + }, + { + "cell_type": "markdown", + "id": "db48f974-0fcb-4379-bbad-427ee529883d", + "metadata": {}, + "source": [ + "Currently, there are 23 entries with `L3` edge keys in the entire database. 
Let's apply the identity operator to this `Node`; the operator will be executed on each entry individually.