Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 77 additions & 30 deletions src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
*/

import {Axis} from '@bokehjs/models/axes/axis';
import {cumulativeSum} from '../stats/array';
import {arrayMean, arrayMedian, cumulativeSum} from '../stats/array';
import {scaleToOne} from '../stats/dataTransformation';
import {
interval as hdiInterval,
data as hdiData,
} from '../stats/highestDensityInterval';
import {oneD} from '../stats/marginal';
import {mean as computeMean} from '../stats/pointStatistic';
import {interpolatePoints} from '../stats/utils';
import * as interfaces from './interfaces';

Expand Down Expand Up @@ -46,6 +45,8 @@ export const updateAxisLabel = (axis: Axis, label: string | null): void => {
* @param {number[]} marginalX - The support of the Kernel Density Estimate of the
* random variable.
* @param {number[]} marginalY - The Kernel Density Estimate of the random variable.
* @param {number} activeStatistic - The statistic to show in the tool. 0 is the mean
* and 1 is the median.
* @param {number | null} [hdiProb=null] - The highest density interval probability
* value. If the default value is not overwritten, then the default HDI probability
* is 0.89. See Statistical Rethinking by McElreath for a description as to why this
Expand All @@ -62,6 +63,7 @@ export const computeStats = (
rawData: number[],
marginalX: number[],
marginalY: number[],
activeStatistic: number,
hdiProb: number | null = null,
text_align: string[] = ['right', 'center', 'left'],
x_offset: number[] = [-5, 0, 5],
Expand All @@ -72,24 +74,44 @@ export const computeStats = (

// Compute the point statistics for the KDE, and create labels to display them in the
// figures.
const mean = computeMean(rawData);
const mean = arrayMean(rawData);
const median = arrayMedian(rawData);
const hdiBounds = hdiInterval(rawData, hdiProbability);
const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound];
const y = interpolatePoints({x: marginalX, y: marginalY, points: x});
const text = [
let x = [hdiBounds.lowerBound, mean, median, hdiBounds.upperBound];
let y = interpolatePoints({x: marginalX, y: marginalY, points: x});
let text = [
`Lower HDI: ${hdiBounds.lowerBound.toFixed(3)}`,
`Mean: ${mean.toFixed(3)}`,
`Median: ${median.toFixed(3)}`,
`Upper HDI: ${hdiBounds.upperBound.toFixed(3)}`,
];

return {
// We will filter the output based on the active statistic from the tool.
let mask: number[] = [];
if (activeStatistic === 0) {
mask = [0, 1, 3];
} else if (activeStatistic === 1) {
mask = [0, 2, 3];
}
x = mask.map((i) => {
return x[i];
});
y = mask.map((i) => {
return y[i];
});
text = mask.map((i) => {
return text[i];
});

const output = {
x: x,
y: y,
text: text,
text_align: text_align,
x_offset: x_offset,
y_offset: y_offset,
};
return output;
};

/**
Expand All @@ -100,6 +122,8 @@ export const computeStats = (
* calculating the Kernel Density Estimate (KDE).
* @param {number} hdiProbability - The highest density interval probability to use when
* calculating the HDI.
* @param {number} activeStatistic - The statistic to show in the tool. 0 is the mean
* and 1 is the median.
* @returns {interfaces.Data} The marginal distribution and cumulative
* distribution calculated from the given random variable data. Point statistics are
* also calculated.
Expand All @@ -108,6 +132,7 @@ export const computeData = (
data: number[],
bwFactor: number,
hdiProbability: number,
activeStatistic: number,
): interfaces.Data => {
const output = {} as interfaces.Data;
for (let i = 0; i < figureNames.length; i += 1) {
Expand All @@ -125,7 +150,13 @@ export const computeData = (
}

// Compute the point statistics for the given data.
const stats = computeStats(data, distribution.x, distribution.y, hdiProbability);
const stats = computeStats(
data,
distribution.x,
distribution.y,
activeStatistic,
hdiProbability,
);

output[figureName] = {
distribution: distribution,
Expand All @@ -150,6 +181,7 @@ export const computeData = (
* application.
* @param {interfaces.Figures} figures - Bokeh figures shown in the application.
* @param {interfaces.Tooltips} tooltips - Bokeh tooltips shown on the glyphs.
* @param {interfaces.Widgets} widgets - Bokeh widget object for the tool.
* @returns {number} We display the value of the bandwidth used for computing the Kernel
* Density Estimate in a div, and must return that value here in order to update the
* value displayed to the user.
Expand All @@ -162,29 +194,44 @@ export const update = (
sources: interfaces.Sources,
figures: interfaces.Figures,
tooltips: interfaces.Tooltips,
widgets: interfaces.Widgets,
): number => {
const computedData = computeData(data, bwFactor, hdiProbability);
for (let i = 0; i < figureNames.length; i += 1) {
// Update all sources with new data calculated above.
const figureName = figureNames[i];
sources[figureName].distribution.data = {
x: computedData[figureName].distribution.x,
y: computedData[figureName].distribution.y,
};
sources[figureName].hdi.data = {
base: computedData[figureName].hdi.base,
lower: computedData[figureName].hdi.lower,
upper: computedData[figureName].hdi.upper,
};
sources[figureName].stats.data = computedData[figureName].stats;
sources[figureName].labels.data = computedData[figureName].labels;
const activeStatistic = widgets.stats_button.active as number;
const computedData = computeData(data, bwFactor, hdiProbability, activeStatistic);

// Update the axes labels.
updateAxisLabel(figures[figureName].below[0], rvName);
// Marginal figure.
// eslint-disable-next-line prefer-destructuring
const bandwidth = computedData.marginal.distribution.bandwidth;
sources.marginal.distribution.data = {
x: computedData.marginal.distribution.x,
y: computedData.marginal.distribution.y,
};
sources.marginal.hdi.data = {
base: computedData.marginal.hdi.base,
lower: computedData.marginal.hdi.lower,
upper: computedData.marginal.hdi.upper,
};
sources.marginal.stats.data = computedData.marginal.stats;
sources.marginal.labels.data = computedData.marginal.labels;
tooltips.marginal.distribution.tooltips = [[rvName, '@x']];
tooltips.marginal.stats.tooltips = [['', '@text']];
updateAxisLabel(figures.marginal.below[0] as Axis, rvName);

// Update the tooltips.
tooltips[figureName].stats.tooltips = [['', '@text']];
tooltips[figureName].distribution.tooltips = [[rvName, '@x']];
}
return computedData.marginal.distribution.bandwidth;
// Cumulative figure.
sources.cumulative.distribution.data = {
x: computedData.cumulative.distribution.x,
y: computedData.cumulative.distribution.y,
};
sources.cumulative.hdi.data = {
base: computedData.cumulative.hdi.base,
lower: computedData.cumulative.hdi.lower,
upper: computedData.cumulative.hdi.upper,
};
sources.cumulative.stats.data = computedData.cumulative.stats;
sources.cumulative.labels.data = computedData.cumulative.labels;
tooltips.cumulative.distribution.tooltips = [[rvName, '@x']];
tooltips.cumulative.stats.tooltips = [['', '@text']];
updateAxisLabel(figures.cumulative.below[0] as Axis, rvName);

return bandwidth;
};
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import {Plot} from '@bokehjs/models/plots/plot';
import {ColumnDataSource} from '@bokehjs/models/sources/column_data_source';
import {HoverTool} from '@bokehjs/models/tools/inspectors/hover_tool';
import {Div} from '@bokehjs/models/widgets/div';
import {RadioButtonGroup} from '@bokehjs/models/widgets/radio_button_group';
import {Select} from '@bokehjs/models/widgets/selectbox';
import {Slider} from '@bokehjs/models/widgets/slider';

// NOTE: In the corresponding Python typing files for the diagnostic tool, we define
// similar types using a TypedDict object. TypeScript allows us to maintain
Expand Down Expand Up @@ -95,3 +99,11 @@ export interface Tooltips {
marginal: Tooltip;
cumulative: Tooltip;
}

export interface Widgets {
rv_select: Select;
bw_factor_slider: Slider;
bw_div: Div;
hdi_slider: Slider;
stats_button: RadioButtonGroup;
}
43 changes: 43 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,29 @@
* LICENSE file in the root directory of this source tree.
*/

/**
* Syntactic sugar for summing an array of numbers.
*
* @param {number[]} data - The array of data.
* @returns {number} The sum of the array of data.
*/
export const arraySum = (data: number[]): number => {
return data.reduce((previousValue, currentValue) => {
return previousValue + currentValue;
});
};

/**
* Calculate the mean of the given array of data.
*
* @param {number[]} data - The array of data.
* @returns {number} The mean of the given data.
*/
export const arrayMean = (data: number[]): number => {
const dataSum = arraySum(data);
return dataSum / data.length;
};

/**
* Cumulative sum of the given data.
*
Expand Down Expand Up @@ -128,3 +151,23 @@ export const valueCounts = (data: number[]): {[key: string]: number} => {
}
return counts;
};

/**
* Calculate the median value for the given array.
*
* @param {number[]} data - Numerical array of data.
* @returns {number} The median value of the given data.
*/
export const arrayMedian = (data: number[]): number => {
const sortedArray = numericalSort(data);
const arrayLength = sortedArray.length;
const isEven = sortedArray.length % 2 === 0;
let median;
if (isEven) {
const index = arrayLength / 2;
median = (sortedArray[index - 1] + sortedArray[index]) / 2;
} else {
median = sortedArray[Math.floor(arrayLength / 2)];
}
return median;
};
17 changes: 14 additions & 3 deletions src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ class Marginal1d(DiagnosticToolBaseClass):

Attributes:
data (Dict[str, List[List[float]]]): JSON serializable representation of the
given `mcs` object.
given ``mcs`` object.
rv_names (List[str]): The list of random variables string names for the given
model.
num_chains (int): The number of chains of the model.
num_draws (int): The number of draws of the model for each chain.
palette (List[str]): A list of color values used for the glyphs in the figures.
The colors are specifically chosen from the Colorblind palette defined in
Bokeh.
The colors are specifically chosen from the ``Colorblind`` palette defined
in Bokeh.
tool_js (str):The JavaScript callbacks needed to render the Bokeh tool
independently from a Python server.
"""
Expand All @@ -40,6 +40,12 @@ def __init__(self: Marginal1d, mcs: MonteCarloSamples) -> None:
super(Marginal1d, self).__init__(mcs)

def create_document(self: Marginal1d) -> Model:
"""
Create the Bokeh document for the diagnostic tool.

Returns:
Model: A Bokeh Model object.
"""
# Initialize widget values using Python.
rv_name = self.rv_names[0]
bw_factor = 1.0
Expand Down Expand Up @@ -110,6 +116,7 @@ def create_document(self: Marginal1d) -> Model:
sources,
figures,
tooltips,
widgets,
);
}} catch (error) {{
{self.tool_js}
Expand All @@ -121,6 +128,7 @@ def create_document(self: Marginal1d) -> Model:
sources,
figures,
tooltips,
widgets,
);
}}
"""
Expand All @@ -135,6 +143,7 @@ def create_document(self: Marginal1d) -> Model:
"figures": figures,
"tooltips": tooltips,
"toolView": tool_view,
"widgets": widgets,
}

# Each widget requires slightly different JS, except for the sliders.
Expand All @@ -155,10 +164,12 @@ def create_document(self: Marginal1d) -> Model:
"""
rv_select_callback = CustomJS(args=callback_arguments, code=rv_select_js)
slider_callback = CustomJS(args=callback_arguments, code=slider_js)
button_callback = CustomJS(args=callback_arguments, code=slider_js)

# Tell Python to use the JavaScript.
widgets["rv_select"].js_on_change("value", rv_select_callback)
widgets["bw_factor_slider"].js_on_change("value", slider_callback)
widgets["hdi_slider"].js_on_change("value", slider_callback)
widgets["stats_button"].js_on_change("active", button_callback)

return tool_view
Loading