diff --git a/00_basic_tutorial.ipynb b/00_basic_tutorial.ipynb
new file mode 100644
index 0000000..4407d41
--- /dev/null
+++ b/00_basic_tutorial.ipynb
@@ -0,0 +1,405 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8a39bdeb-5d2f-4857-acfe-b9678ba25ae1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "%config Completer.use_jedi = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f35f534b-9692-4ecd-94c9-24efc0e2dd0b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pprint import pprint\n",
+    "\n",
+    "from tiled.client import from_uri\n",
+    "import matplotlib.pyplot as plt\n",
+    "import matplotlib as mpl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fa1941aa-87af-4023-a63b-95f8d96b437e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mpl.rcParams['mathtext.fontset'] = 'stix'\n",
+    "mpl.rcParams['font.family'] = 'STIXGeneral'\n",
+    "mpl.rcParams['text.usetex'] = True\n",
+    "plt.rc('xtick', labelsize=12)\n",
+    "plt.rc('ytick', labelsize=12)\n",
+    "plt.rc('axes', labelsize=12)\n",
+    "mpl.rcParams['figure.dpi'] = 300"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "43445174-ee13-47ac-9cb1-26482e1739eb",
+   "metadata": {},
+   "source": [
+    "# Basic Tutorial"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f5260357-c5a8-41fe-8b5d-f728d96b5963",
+   "metadata": {},
+   "source": [
+    "The [AIMM post-processing pipeline](https://github.com/AI-multimodal/aimm-post-processing) is built around the `Operator` object. An `Operator`'s job is to take a `client`-like object and execute a post-processing operation on it; the specific type of operation is defined by the operator. All metadata and provenance are tracked."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ceb1b1f8-f109-4d05-a98d-64b377514934",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from aimm_post_processing import operations"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bc4a0fb0-7da9-4ea5-80ec-b339a6b85fe9",
+   "metadata": {},
+   "source": [
+    "Connect to the `tiled` client. This one is the [aimmdb](https://github.com/AI-multimodal/aimmdb) hosted at [aimm.lbl.gov](https://aimm.lbl.gov/api). Note that my API key is stored in an environment variable, `TILED_API_KEY`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f3e40949-1d2c-41bd-b207-43ccde4c0f0b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "CLIENT = from_uri(\"https://aimm.lbl.gov/api\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c7ae73e7-d693-4c9b-a253-b13e299a3d05",
+   "metadata": {},
+   "source": [
+    "## Unary operators"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "33d644b3-4233-49b0-90f6-339ad3e8db56",
+   "metadata": {},
+   "source": [
+    "A [unary operator](https://en.wikipedia.org/wiki/Unary_operation) takes a single input. In practice, this means these operators act on a single data point (a `DataFrameClient`) at a time. 
We'll provide some examples here.\n",
+    "\n",
+    "First, let's get a single `DataFrameClient` object:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "23cb9f6a-47a7-41ae-9b33-9796cceb8f5a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_client = CLIENT[\"dataset\"][\"newville\"][\"edge\"][\"K\"][\"element\"][\"Co\"][\"uid\"][\"Bt5hUbgkfzR\"]\n",
+    "type(df_client)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "eb29a693-fedf-4369-b12b-3eb0afd4e30f",
+   "metadata": {},
+   "source": [
+    "### The identity"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6963e452-9cb0-4ac8-95d1-3c44a736ae7f",
+   "metadata": {},
+   "source": [
+    "The simplest operation we can perform is nothing! Let's see what it does. First, feel free to print the output of `df_client` so you can see what's contained. The `read()` method gives you access to the actual data, and the `metadata` property gives you access to the metadata:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "56a35197-ea9b-4164-82fa-18576486ba97",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "df_client.read()  # is a pandas.DataFrame\n",
+    "df_client.metadata  # is a python dictionary"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3e06093f-e79a-42ea-940e-5a008cfdcab9",
+   "metadata": {},
+   "source": [
+    "The identity operator is instantiated and then run on the `df_client`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad8a8f5e-72f7-404c-87c8-c82ef5394c3a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "op = operations.Identity()\n",
+    "result = op(df_client)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "24e0e9a1-e808-4e56-b9c1-ced7d0461d04",
+   "metadata": {},
+   "source": [
+    "Every result of any operator will be a dictionary with two keys: `\"data\"` and `\"metadata\"`, which correspond to the results of `read()` and `metadata` above. The data is the correspondingly modified `pandas.DataFrame` object (which, in the case of the identity, is of course the same as what we started with). New metadata is created for the derived, post-processed object.\n",
+    "\n",
+    "First, let's check that the original and \"post-processed\" data are the same."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3a66a114-f2d9-49a4-b3d1-f27bf18a1c34",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "assert (df_client.read() == result[\"data\"]).all().all()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a590df1c-602d-4c7f-8d5d-661930116b92",
+   "metadata": {},
+   "source": [
+    "Next, the metadata:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4707f59f-fd44-4414-a565-931fb226c06a",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "result[\"metadata\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7d5b78e3-f9dd-4712-8409-5c64c006ad50",
+   "metadata": {},
+   "source": [
+    "First, a new unique id is assigned. Second, since this is a derived quantity, the original metadata is replaced by a `post_processing` key. This key contains every bit of information needed for provenance, including the parents (just one in the case of a unary operator), the operator details (including code version), any keyword arguments used during instantiation, and the datetime at which the operation was run.\n",
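+    "\n",
+    "For illustration, the `post_processing` entry has roughly the following shape (the values below are made up for the sketch, not actual output):\n",
+    "\n",
+    "```python\n",
+    "{\n",
+    "    \"parents\": [\"<uid of the parent entry>\"],\n",
+    "    \"operator\": {...},  # the MSONable-serialized operator, including code version\n",
+    "    \"kwargs\": {...},  # keyword arguments used at instantiation\n",
+    "    \"datetime\": \"2022-01-01 00:00:00 UTC\",\n",
+    "}\n",
+    "```\n",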
+    "\n",
+    "We use the [MSONable](https://pythonhosted.org/monty/_modules/monty/json.html) library to take care of most of this for us.\n",
+    "\n",
+    "We can compare against the original metadata to see the differences."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dc1bd8b7-5038-4b1f-aebe-0b60762986e1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_client.metadata"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "92af26f9-2c72-4882-b672-7f7ea7fcb5aa",
+   "metadata": {},
+   "source": [
+    "### Standardizing the grids"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3b36155c-2066-4d78-a8a9-df73b2bfdece",
+   "metadata": {},
+   "source": [
+    "Oftentimes (especially for machine learning applications) we need to interpolate our spectral data onto a common grid. We can do this easily with the `StandardizeGrid` unary operator."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "831f8b6f-499f-471c-8ab9-498a720aaee6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "op = operations.StandardizeGrid(x0=7550.0, xf=8900.0, nx=100, x_column=\"energy\", y_columns=[\"itrans\"])\n",
+    "result = op(df_client)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1e684792-589e-4857-b5d4-9e5d30a278d6",
+   "metadata": {},
+   "source": [
+    "Here's a visualization of what it's done:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b239efd-1c9b-4dee-85d2-639d6827dde4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "d0 = df_client.read()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a70e7bae-cdb9-4b4e-9b2f-f73872f8d177",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n",
+    "ax.plot(d0[\"energy\"], d0[\"itrans\"], 'k-')\n",
+    "ax.plot(result[\"data\"][\"energy\"], result[\"data\"][\"itrans\"], 'r-')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "16a49a37-e025-458b-90b7-ec79c9834b74",
+   "metadata": {},
+   "source": [
+    "### Batch processing"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "34b4fe25-53af-4ca8-9455-ae7cce45d0a7",
+   "metadata": {},
+   "source": [
+    "While a unary operator acts on only a single input, there are cases where we might wish to apply the same operator to a list of `client`-like objects. The operator's `__call__` method can handle this. For example, consider the following:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9a782fd6-8340-424d-9464-88b7a90ae4a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "node_client = CLIENT[\"edge\"][\"L3\"][\"uid\"]\n",
+    "node_client"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "db48f974-0fcb-4379-bbad-427ee529883d",
+   "metadata": {},
+   "source": [
+    "Currently, there are 23 entries with `L3` edge keys in the entire database. Let's apply the identity to this `Node`, which will run the operator on each entry individually."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "80abb7ea-bec8-467c-af8e-29828cdeea01",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "op = operations.Identity()\n",
+    "result = op(node_client)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e1f20507-510b-4801-9b7b-2a53e74a6ebf",
+   "metadata": {},
+   "source": [
+    "The first of these results corresponds to the first entry above, the second to the second, and so on.\n",
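+    "\n",
+    "As a quick sanity check, we can line the batch results up against the node's entries. A minimal sketch (it assumes the node yields its uids in the same order as the returned list):\n",
+    "\n",
+    "```python\n",
+    "for uid, res in zip(node_client, result):\n",
+    "    # each res is a dictionary with \"data\" and \"metadata\" keys\n",
+    "    print(uid, res[\"metadata\"][\"post_processing\"][\"parents\"])\n",
+    "```"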
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea77f2ae-c17f-41e8-b57b-e98149e5435e", + "metadata": {}, + "outputs": [], + "source": [ + "result[0][\"metadata\"]" + ] + }, + { + "cell_type": "markdown", + "id": "6d75b36d-11be-4cdc-aca4-e2c633d86003", + "metadata": {}, + "source": [ + "Note as well that `__call__` will attempt to intelligently detect if you provided it with the incorrect type of node. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a12f3e79-4dcd-49c2-9ff1-a8f69c530eee", + "metadata": {}, + "outputs": [], + "source": [ + "node_client = CLIENT[\"edge\"][\"L3\"]\n", + "node_client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc19d5e3-14c4-4694-958e-0aae19c9c4fc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "op = operations.Identity()\n", + "op(node_client)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "toc-autonumbering": true + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/aimm_post_processing/client.py b/aimm_post_processing/client.py deleted file mode 100644 index 95b7d18..0000000 --- a/aimm_post_processing/client.py +++ /dev/null @@ -1,130 +0,0 @@ -from tiled.client import from_uri -from tiled.query_registration import register -from dataclasses import dataclass - -import collections -import json - - -@register(name="raw_mongo", overwrite=True) -@dataclass -class RawMongo: - """Run a MongoDB query against a given collection.""" - - query: str # We cannot put a dict in a URL, so this a JSON str. - - def __init__(self, query): - if isinstance(query, collections.abc.Mapping): - query = json.dumps(query) - self.query = query - - -################## - - -# TODO: figure out a better way to document the return values for this function -def search_sets_in_child( - root, child_names, search_symbol=None, search_edge=None -): - """Walks down a branch of nodes to get a specified child node and use a - search criteria. - - Parameters - ---------- - root : tiled.client.node.Node - The parent node that the client will use to start going down using - the specified branch in child_names. - child_names : list of tiled.client.node.Node - List of subsequent child nodes. They must be sorted in the same way - they were created in the tree. - search_symbol : str, optional - Search criteria for the element symbol (aka the element, e.g. "Cu"). - Default is None. - search_edge : str, optional - Search criteria for the spectroscopy edge (e.g. the K-edge, "K", or the - L3 edge "L3"). Default is None. - - Returns - ------- - tiled.client.node.Node - A structure representing an element in the tree. If the element is the - parent node or a node in the middle of the tree, it returns - tiled.client.node.Node. If it reaches the last node of a branch, - it returns a client structure that represents the type of data - structure that it contains; e.g. - tiled.client.dataframe.DataFrameClient or - tiled.client.array.ArrayClient. 
- """ - - if isinstance(child_names, list): - child_names = tuple(child_names) - - # Search subsequent child nodes along the tree - child_node = root - for child in child_names: - child_node = child_node[child] - - if search_symbol is None and search_edge is None: - return child_node - - if search_symbol is not None: - child_node = child_node.search(symbol(search_symbol)) - - if search_edge is not None: - child_node = child_node.search(edge(search_edge)) - - return child_node - - -def symbol(symbol): - """Wrapper method to generate a RawMongo query to search for an element - symbol - - Parameters - ---------- - symbol : str - Query parameter used to search for an element symbol. - - Returns - ------- - tiled.client.node.Node - If a match was found, it returns a client node that includes all the - child nodes containing each individual dataset. - """ - - return RawMongo({"metadata.element.symbol": symbol}) - - -def edge(edge): - """Wrapper method to generate a RawMongo query to search for an element - edge. - - Parameters - ---------- - edge : str - The parent node that the client will use to start going down using the - specified branch in child_names. - - Returns - ------- - tiled.client.node.Node - If a match was found, it returns a client node that includes all the - child nodes containing each individual dataset. - """ - - return RawMongo({"metadata.element.edge": edge}) - - -if __name__ == "__main__": - - # Example Code - client = from_uri("https://aimm.lbl.gov/api") - child = ["NCM", "BM_NCMA"] - result_nodes = search_sets_in_child( - client, child, search_symbol="Ni", search_edge="L3" - ) - - # TODO: Test multiple cases by passing a list with multiple paths. - # What is the best data structure to use as a container for the results? - # children = [['NCM', 'BM_NCMA'], - # ['NCM', 'BM_NCM622']] diff --git a/aimm_post_processing/operations.py b/aimm_post_processing/operations.py index 98a1383..c0999a3 100644 --- a/aimm_post_processing/operations.py +++ b/aimm_post_processing/operations.py @@ -1,6 +1,6 @@ """Module for housing post-processing operations.""" -from abc import ABC +from abc import ABC, abstractmethod from datetime import datetime from uuid import uuid4 @@ -9,10 +9,11 @@ import pandas as pd from scipy.interpolate import InterpolatedUnivariateSpline from sklearn.linear_model import LinearRegression -from sklearn.metrics import mean_squared_error + from aimm_post_processing import utils -from copy import deepcopy -from abc import ABC, abstractmethod +from tiled.client.dataframe import DataFrameClient +from tiled.client.node import Node + class Operator(MSONable, ABC): """Base operator class. Tracks everything required through a combination @@ -22,93 +23,99 @@ class Operator(MSONable, ABC): .. important:: The __call__ method must be derived for every operator. In particular, - this operator should take as arguments at least one other data point - (node). + this operator should take as arguments at least one data point (node). """ - def __init__( # Initialize common parameters for all operators - self, - x_column="energy", - y_columns=["mu"], - ): - self.x_column = x_column - self.y_columns = y_columns - self.operator_id = str(uuid4()) # UID for the defined operator. + @abstractmethod + def _process_data(self): + ... + @abstractmethod + def _process_metadata(self): + ... + + @abstractmethod + def __call__(self, client): + ... - def __call__(self, dataDict): - # meke a copy, otherwise python will make modification to input dataDict instead. 
-        copy_dataDict = deepcopy(dataDict)
-        new_metadata = self._process_metadata(copy_dataDict["metadata"])
-        new_df = self._process_data(copy_dataDict["data"])
+
+class UnaryOperator(Operator):
+    """Specialized operator class which takes only a single input. This input
+    must be an instance of :class:`DataFrameClient`.
 
-        return {"data": new_df, "metadata": new_metadata}
+    Specifically, the operator object's ``__call__`` method can be executed on
+    either a :class:`DataFrameClient` or :class:`Node` object. If run on a
+    :class:`DataFrameClient`, the operator call will return a single dictionary
+    with the keys "data" and "metadata", as one would expect. If the input is
+    of type :class:`Node`, then an attempt is made to iterate through all
+    children of that node, executing the operator on each instance
+    individually. A list of dictionaries is then returned.
+    """
 
     def _process_metadata(self, metadata):
-        """Preliminary pre-processing of the dictionary object that contains data and metadata.
-        Takes the:class:`dict` object as input and returns the untouched data in addition to an
-        augmented metadata dictionary.
-
+        """Processing of the metadata dictionary object. Takes the
+        :class:`dict` object as input and returns a modified
+        dictionary with the following changes:
+
+        1. metadata["_tiled"]["uid"] is replaced with a new uuid string.
+        2. metadata["post_processing"] is created, with keys that capture
+           the current state of the operator, the parent uid(s), and the
+           datetime at which the operation was run.
+
         Parameters
-        ---------
+        ----------
-        dataDict : dict
-            The data dictionary that contains data and metadata
-        local_kwargs : tuple
-            A tuple (usually `locals().items()` where it is called) of local variables in operator
-            space.
+        metadata : dict
+            The metadata dictionary accessed via ``df_client.metadata``.
 
-        Notes
-        -----
-        1. Measurement `id` is suspiciously `_id` that sits under `sample` in the metadata.
-        Need to check if this hierchy is universal for all data.
+        Returns
+        -------
+        dict
+            The new metadata object for the post-processed child.
         """
 
-        # parents are the uid of the last processed data, or the original sample id otherwise.
-        try:
-            parent_id = metadata["post_processing"]["id"]
-        except:
-            parent_id = metadata['_tiled']['uid']
-
         dt = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
-        metadata["post_processing"] = {
-            "id": str(uuid4()),
-            "parent": parent_id,
-            "operator": self.as_dict(),
-            "kwargs": self.__dict__,
-            "datetime": f"{dt} UTC",
+        return {
+            "_tiled": {"uid": str(uuid4())},  # Assign a new uid
+            "post_processing": {
+                "parents": [metadata["_tiled"]["uid"]],
+                "operator": self.as_dict(),
+                "kwargs": self.__dict__,
+                "datetime": f"{dt} UTC",
+            },
         }
-        return metadata
-
-    @abstractmethod
-    def _process_data(self, df) -> dict:
-        """User must override this method.
-        """
-        raise NotImplementedError
-
-
-class Pull(Operator):
-    def __init__(self):
-        super().__init__()
-
-    def __call__(self, dfClient):
-        """ This operator does nothing but return the data/metadata dictionary for a given
-        `tiled.client.dataframe.DataFrameClient`.
- """ - metadata = deepcopy(dict(dfClient.metadata)) - new_metadata = self._process_metadata(metadata) - df = deepcopy(dfClient.read()) - new_df = self._process_data(df) - - return {"data": new_df, "metadata": new_metadata} + def _call_on_client(self, client): + return { + "data": self._process_data(client.read()), + "metadata": self._process_metadata(client.metadata), + } - def _process_data(self, df): - return df + def __call__(self, client): + if isinstance(client, DataFrameClient): + return self._call_on_client(client) + + elif isinstance(client, Node): + # Apply the operator to each of the instances in the node + values = [value for value in client.values()] + if not all([isinstance(v, DataFrameClient) for v in values]): + raise RuntimeError( + "Provided client when iterated on has produced entries " + "that are not DataFrameClient objects. This is likely " + "due to passing a query such as " + 'df_client = c["edge"]["K"], when a query like ' + 'df_client = c["edge"]["K"]["uid"] is required' + ) + return list(map(self._call_on_client, values)) + + else: + raise ValueError( + f"client is of type {type(client)} but should be of " + "type DataFrameClient" + ) -class Identity(Operator): - """The identity operation. Does nothing. Used for testing purposes.""" +class Identity(UnaryOperator): + """The identity operation. Does nothing. Primarily used for testing + purposes.""" def __init__(self): super().__init__() @@ -117,65 +124,62 @@ def _process_data(self, df): """ Parameters ---------- - dataDict : pandas.DataFrame - The data is a :class:`pd.DataFrame` + df : pandas.DataFrame + The dataframe object to process. Returns ------- - dataDict : dict - Same as input + pandas.DataFrame """ + return df -class StandardizeGrid(Operator): - """Interpolates specified columns onto a common grid.""" +class StandardizeGrid(UnaryOperator): + """Interpolates specified columns onto a common grid. + + Parameters + ---------- + x0 : float + The lower bound of the grid to interpolate onto. + xf : float + The upper bound of the grid to interpolate onto. + nx : int + The number of interpolation points. + interpolated_univariate_spline_kwargs : dict, optional + Keyword arguments to be passed to the + :class:`InterpolatedUnivariateSpline`. See + [here](https://docs.scipy.org/doc/scipy/reference/generated/ + scipy.interpolate.InterpolatedUnivariateSpline.html) for the + documentation on this class. + x_column : str, optional + References a single column in the DataFrameClient (this is the + "x-axis"). + y_columns : list, optional + References a list of columns in the DataFrameClient (these are the + "y-axes"). + """ def __init__( self, + *, x0, xf, nx, interpolated_univariate_spline_kwargs=dict(), x_column="energy", - y_columns=["mu"] + y_columns=["mu"], ): - """Interpolates the provided DataFrameClient onto a grid as specified - by the provided parameters. - - Parameters - ---------- - x0 : float - The lower bound of the grid to interpolate onto. - xf : float - The upper bound of the grid to interpolate onto. - nx : int - The number of interpolation points. - interpolated_univariate_spline_kwargs : TYPE, optional - Keyword arguments to be passed to the - :class:`InterpolatedUnivariateSpline` class. See - [here](https://docs.scipy.org/doc/scipy/reference/generated/ - scipy.interpolate.InterpolatedUnivariateSpline.html) for the - documentation on this class. Default is {}. - x_column : str, optional - References a single column in the DataFrameClient (this is the - "x-axis"). Default is "energy". 
-        y_columns : list, optional
-            References a list of columns in the DataFrameClient (these are the
-            "y-axes"). Default is ["mu"].
-
-        Returns:
-        -------
-        An instance of StandardGrid operator.
-        """
-        super().__init__(x_column, y_columns)
         self.x0 = x0
         self.xf = xf
         self.nx = nx
-        self.interpolated_univariate_spline_kwargs = interpolated_univariate_spline_kwargs
-
-    def _process_data(self, df):
+        self.interpolated_univariate_spline_kwargs = (
+            interpolated_univariate_spline_kwargs
+        )
+        self.x_column = x_column
+        self.y_columns = y_columns
 
+    def _process_data(self, df):
         """Takes in a dictionary of the data amd metadata. The data is a
         :class:`pd.DataFrame`, and the metadata is itself a dictionary.
         Returns the same dictionary with processed data and metadata.
@@ -194,128 +198,92 @@
         return pd.DataFrame(new_data)
 
 
-class RemoveBackground(Operator):
-    """Fit the pre-edge region to a victoreen function and subtract it from the spectrum.
+class RemoveBackground(UnaryOperator):
+    """Fit the pre-edge region to a Victoreen function and subtract it from the
+    spectrum.
+
+    Parameters
+    ----------
+    x0 : float
+        The lower bound of energy range on which the background is fitted.
+    xf : float
+        The upper bound of energy range on which the background is fitted.
+    x_column : str, optional
+        References a single column in the DataFrameClient (this is the
+        "x-axis").
+    y_columns : list, optional
+        References a list of columns in the DataFrameClient (these are the
+        "y-axes").
+    victoreen_order : int
+        The order of the Victoreen function. The selected data is fitted to a
+        Victoreen pre-edge function (in which one fits a line to
+        :math:`E^n \\mu(E)` for some value of :math:`n`).
     """
+
     def __init__(
-        self,
-        *,
-        x0,
-        xf,
-        x_column="energy",
-        y_columns=["mu"],
-        victoreen_order=0
+        self, *, x0, xf, x_column="energy", y_columns=["mu"], victoreen_order=0
     ):
-        """Subtract background from data.
-        Fit the pre-edge data to a line with slope, and subtract slope info from data.
-
-        Parameters
-        ----------
-        dfClient : tiled.client.dataframe.DataFrameClient
-        x0 : float
-            The lower bound of energy range on which the background is fitted.
-        xf : flaot
-            The upper bound of energy range on which the background is fitted.
-        x_column : str, optional
-            References a single column in the DataFrameClient (this is the
-            "x-axis"). Default is "energy".
-        y_columns : list, optional
-            References a list of columns in the DataFrameClient (these are the
-            "y-axes"). Default is ["mu"].
-        victoreen_order : int
-            The order of Victoreen function. The selected data is fitted to Victoreen pre-edge
-            function (in which one fits a line to μ(E)*E^n for some value of n. Default is 0,
-            which is a linear fit.
-
-        Returns
-        -------
-        An instance of RemoveBackground operator
-        """
-        super().__init__(x_column, y_columns)
         self.x0 = x0
         self.xf = xf
         self.victoreen_order = victoreen_order
+        self.x_column = x_column
+        self.y_columns = y_columns
 
     def _process_data(self, df):
         """
-        Takes in a dictionary of the data amd metadata. The data is a
-        :class:`pd.DataFrame`, and the metadata is itself a dictionary.
-        Returns the same dictionary with processed data and metadata.
-
         Notes
         -----
-        `LinearRegression().fit()` takes 2-D arrays as input. This can be explored
-        for batch processing of multiple spectra
+        ``LinearRegression().fit()`` takes 2-D arrays as input. This can be
+        explored for batch processing of multiple spectra.
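+
+        Examples
+        --------
+        A minimal usage sketch (the energy bounds and Victoreen order below
+        are hypothetical, not values taken from a real spectrum):
+
+        >>> op = RemoveBackground(x0=7500.0, xf=7650.0, victoreen_order=2)
+        >>> result = op(df_client)  # df_client as in the basic tutorial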
        """
-        bg_data = df.loc[(df[self.x_column] >= self.x0) * (df[self.x_column] < self.xf)]
+        bg_data = df.loc[
+            (df[self.x_column] >= self.x0) * (df[self.x_column] < self.xf)
+        ]
         new_data = {self.x_column: df[self.x_column]}
         for column in self.y_columns:
-            y = bg_data[column] * bg_data[self.x_column]**self.victoreen_order
+            y = bg_data[column] * bg_data[self.x_column] ** self.victoreen_order
             reg = LinearRegression().fit(
-                bg_data[self.x_column].to_numpy().reshape(-1,1),
-                y.to_numpy().reshape(-1,1)
+                bg_data[self.x_column].to_numpy().reshape(-1, 1),
+                y.to_numpy().reshape(-1, 1),
+            )
+            background = reg.predict(
+                df[self.x_column].to_numpy().reshape(-1, 1)
+            )
+            new_data[column] = (
+                df.loc[:, column].to_numpy() - background.flatten()
             )
-            background = reg.predict(df[self.x_column].to_numpy().reshape(-1,1))
-            new_data[column] = df.loc[:,column].to_numpy() - background.flatten()
 
         return pd.DataFrame(new_data)
 
 
-# class Normalize(Operator):
-#     """
-#     """
-#     def __init__(
-#         self,
-#         x_column="energy",
-#         y_columns=["mu"]
-#     ):
-#         super().__init__(x_column, y_columns)
-
-#     def _process_data(self, df):
-#         xas_ds = XASDataSet(name="Shift XANES", energy=grid, mu=dd)
-#         xas_ds.norm1 = norm1 # update atribute for force_normalization
-#         xas_ds.normalize_force() # force the normalization again with updated atribute
-
-
-class StandardizeIntensity(Operator):
-    """ Scale the intensity so they vary in similar range.
+class StandardizeIntensity(UnaryOperator):
+    """Scale the intensity so that spectra vary over a similar range.
+    Specifically, aligns the intensity to the mean over a selected range, and
+    scales the intensity by its standard deviation.
+
+    Parameters
+    ----------
+    x0 : float
+        The lower bound of the energy range over which the mean is calculated.
+        If None, the first point in the energy grid is used.
+    xf : float
+        The upper bound of the energy range over which the mean is calculated.
+        If None, the last point in the energy grid is used.
+    x_column : str, optional
+        References a single column in the DataFrameClient (this is the
+        "x-axis"). Default is "energy".
+    y_columns : list, optional
+        References a list of columns in the DataFrameClient (these are the
+        "y-axes"). Default is ["mu"].
     """
-    def __init__(
-        self,
-        *,
-        x0 = None,
-        xf = None,
-        x_column="energy",
-        y_columns=["mu"]
-    ):
-        """Align the intensity to the mean of a selected range, and scale the intensity up to standard
-        deviation.
-
-        Parameters
-        ----------
-        dfClient : tiled.client.dataframe.DataFrameClient
-        x0 : float
-            The lower bound of energy range for which the mean is calculated. If None, the first
-            point in the energy grid is used. Default is None.
-        yf : float
-            The upper bound of energy range for which the mean is calculated. If None, the last
-            point in the energy grid is used. Default is None.
-        x_column : str, optional
-            References a single column in the DataFrameClient (this is the
-            "x-axis"). Default is "energy".
-        y_columns : list, optional
-            References a list of columns in the DataFrameClient (these are the
-            "y-axes"). Default is ["mu"]. 
- - Returns - ------- - An instance of StandardizeIntensity operator - """ - super().__init__(x_column, y_columns) + def __init__(self, *, x0, xf, x_column="energy", y_columns=["mu"]): self.x0 = x0 self.xf = xf + self.x_column = x_column + self.y_columns = y_columns def _process_data(self, df): """ @@ -326,90 +294,74 @@ def _process_data(self, df): """ grid = df.loc[:, self.x_column] - if self.x0 is None: self.x0 = grid[0] - if self.xf is None: self.xf = grid[-1] + if self.x0 is None: + self.x0 = grid[0] + if self.xf is None: + self.xf = grid[-1] assert self.x0 < self.xf, "Invalid range, make sure x0 < xf" select_mean_range = (grid > self.x0) & (grid < self.xf) - + new_data = {self.x_column: df[self.x_column]} for column in self.y_columns: mu = df.loc[:, column] mu_mean = mu[select_mean_range].mean() mu_std = mu.std() - new_data.update({column: (mu-mu_mean)/mu_std}) - + new_data.update({column: (mu - mu_mean) / mu_std}) + return pd.DataFrame(new_data) -class Smooth(Operator): +class Smooth(UnaryOperator): """Return the simple moving average of spectra with a rolling window. - Parameters - ---------- - - window : float, in eV. - The rolling window in eV over which the average intensity is taken. - x_column : str, optional - References a single column in the DataFrameClient (this is the - "x-axis"). Default is "energy". - y_columns : list, optional - References a list of columns in the DataFrameClient (these are the - "y-axes"). Default is ["mu"]. + + Parameters + ---------- + window : float, in eV. + The rolling window in eV over which the average intensity is taken. + x_column : str, optional + References a single column in the DataFrameClient (this is the + "x-axis"). + y_columns : list, optional + References a list of columns in the DataFrameClient (these are the + "y-axes"). """ - def __init__( - self, - *, - window=10, - x_column='energy', - y_columns=['mu'] - ): - super().__init__(x_column, y_columns) + + def __init__(self, *, window=10.0, x_column="energy", y_columns=["mu"]): self.window = window + self.x_column = x_column + self.y_columns = y_columns - def _apply(self, df): + def _process_data(self, df): """ Takes in a dictionary of the data amd metadata. The data is a :class:`pd.DataFrame`, and the metadata is itself a dictionary. Returns the same dictionary with processed data and metadata. - + Returns: -------- dict - A dictionary of the data and metadata. The data is a :class:`pd.DataFrame`, + A dictionary of the data and metadata. The data is a :class:`pd.DataFrame`, and the metadata is itself a dictionary. """ - - grid = df.loc[:,self.x_column] + + grid = df.loc[:, self.x_column] new_data = {self.x_column: df[self.x_column]} for column in self.y_columns: y = df.loc[:, column] y_smooth = utils.simple_moving_average(grid, y, window=self.window) new_data.update({column: y_smooth}) - mse = mean_squared_error(y, y_smooth) - n2s = mse / y_smooth.std() return pd.DataFrame(new_data) -class Classify(Operator): - """ Label the spectrum as "good", "noisy" or "discard" based on the quality of the spectrum. - """ - def __init__(self, classifier): - super().__init__() - self.classifier = classifier - - def _process_data(self, df): - """ - Parameters - ---------- - dfClient : tiled.client.dataframe.DataFrameClient - classifier : Callable - The classifier that takes in the spectrum and output a label. - - """ - return df +# TODO +class Classify(UnaryOperator): + """Label the spectrum as "good", "noisy" or "discard" based on the quality + of the spectrum.""" + ... 
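+
+
+# The sketch below illustrates the kind of callable a future ``Classify``
+# could wrap. It is a hypothetical example, not part of the current API,
+# and the thresholds are illustrative assumptions.
+def naive_noise_label(df, y_column="mu", noisy_threshold=0.05, discard_threshold=0.5):
+    """Label a spectrum by the relative point-to-point noise of one column."""
+    y = df[y_column]
+    # Compare the scatter of the first differences to the overall scatter
+    noise = y.diff().std() / y.std()
+    if noise < noisy_threshold:
+        return "good"
+    if noise < discard_threshold:
+        return "noisy"
+    return "discard"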
-class PreNormalize(Operator): - """ - """ \ No newline at end of file +# TODO +class PreNormalize(UnaryOperator): + ...
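+
+
+if __name__ == "__main__":
+    # A minimal end-to-end sketch, mirroring the example block the deleted
+    # client.py used to carry. It assumes network access to aimm.lbl.gov
+    # (and, for private data, a valid TILED_API_KEY). The path below is the
+    # same Co K-edge entry used in the basic tutorial notebook.
+    from tiled.client import from_uri
+
+    client = from_uri("https://aimm.lbl.gov/api")
+    df_client = client["dataset"]["newville"]["edge"]["K"]["element"]["Co"][
+        "uid"
+    ]["Bt5hUbgkfzR"]
+
+    # Interpolate onto a common grid, exactly as in the tutorial
+    op = StandardizeGrid(
+        x0=7550.0, xf=8900.0, nx=100, x_column="energy", y_columns=["itrans"]
+    )
+    result = op(df_client)
+    print(result["metadata"]["post_processing"])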