diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 39030653b..5f02b3e05 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -12,10 +12,29 @@ import numpy as np import xarray as xr +from numpy.typing import NDArray _logger = logging.getLogger(__name__) +def _flatten_or_average(arr: NDArray) -> NDArray: + """Flatten array or average across non-quantile dimensions. + + Parameters + ---------- + arr : np.ndarray + Input array, possibly multi-dimensional. + + Returns + ------- + np.ndarray + Flattened 1D array, averaged across extra dimensions if needed. + """ + if arr.ndim > 1: + return np.mean(arr, axis=tuple(range(1, arr.ndim))).flatten() + return arr + + def collect_streams(runs: dict): """Get all unique streams across runs, sorted. @@ -409,6 +428,111 @@ def list_streams(cls): return list(cls._marker_size_stream.keys()) +def quantile_plot_metric_region( + metric: str, + region: str, + runs: dict, + scores_dict: dict, + plotter: object, +) -> None: + """ + Create quantile-quantile (Q-Q) plots for extreme value analysis for all streams + and channels for a given metric and region. + + Parameters + ---------- + metric: str + String specifying the metric to plot (should be 'qq_analysis') + region: str + String specifying the region to plot + runs: dict + Dictionary containing the config for all runs + scores_dict : dict + The dictionary containing all computed metrics. + plotter: + Plotter object to handle the plotting part. Must have a qq_plot method. + """ + streams_set = collect_streams(runs) + channels_set = collect_channels(scores_dict, metric, region, runs) + + for stream in streams_set: + for ch in channels_set: + selected_data, labels, run_ids = [], [], [] + qq_full_data = [] # Store full Q-Q datasets for detailed plotting + + for run_id, data in scores_dict[metric][region].get(stream, {}).items(): + # skip if channel is missing + if ch not in np.atleast_1d(data.channel.values): + continue + + # Select channel + data_for_channel = data.sel(channel=ch) if "channel" in data.dims else data + + # Check for NaN + if data_for_channel.isnull().all(): + continue + + # For qq_analysis, extract Q-Q data from attributes + if metric == "qq_analysis" and "p_quantiles" in data_for_channel.attrs: + attrs = data_for_channel.attrs + # Convert to numpy arrays once + p_quantiles_arr = np.array(attrs["p_quantiles"]) + gt_quantiles_arr = np.array(attrs["gt_quantiles"]) + qq_deviation_arr = np.array(attrs["qq_deviation"]) + qq_deviation_norm_arr = np.array(attrs["qq_deviation_normalized"]) + quantile_levels = np.array(attrs["quantile_levels"]) + extreme_low_mse = float(np.mean(np.array(attrs["extreme_low_mse"]))) + extreme_high_mse = float(np.mean(np.array(attrs["extreme_high_mse"]))) + + qq_dataset = xr.Dataset( + { + "quantile_levels": (["quantile"], quantile_levels), + "p_quantiles": (["quantile"], _flatten_or_average(p_quantiles_arr)), + "gt_quantiles": (["quantile"], _flatten_or_average(gt_quantiles_arr)), + "qq_deviation": (["quantile"], _flatten_or_average(qq_deviation_arr)), + "qq_deviation_normalized": ( + ["quantile"], + _flatten_or_average(qq_deviation_norm_arr), + ), + "extreme_low_mse": ([], extreme_low_mse), + "extreme_high_mse": ([], extreme_high_mse), + } + ) + # Store extreme percentiles for plotting + qq_dataset.attrs["extreme_percentiles"] = tuple(attrs["extreme_percentiles"]) + qq_full_data.append(qq_dataset) + + selected_data.append(data_for_channel) + labels.append(runs[run_id].get("label", run_id)) + run_ids.append(run_id) + + if selected_data: + _logger.info(f"Creating Q-Q plot for {metric} - {region} - {stream} - {ch}.") + + name = create_filename( + prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream, ch] + ) + + # Check if plotter has qq_plot method and Q-Q data is available + if hasattr(plotter, "qq_plot") and qq_full_data: + _logger.info(f"Creating Q-Q plot with {len(qq_full_data)} dataset(s).") + # Extract extreme_percentiles from dataset + extreme_pct = qq_full_data[0].attrs["extreme_percentiles"] + plotter.qq_plot( + qq_full_data, + labels, + tag=name, + metric=metric, + extreme_percentiles=extreme_pct, + ) + else: + # Skip plotting if no Q-Q data available + _logger.warning( + f"Q-Q data not available for {metric} - {region} - {stream} - {ch}. " + f"Skipping plot generation." + ) + + def create_filename( *, prefix: Sequence[str] = (), diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 73a3627e2..cc03f1edd 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1138,6 +1138,220 @@ def heat_map( plt.savefig(f"{self.out_plot_dir.joinpath(name)}.{self.image_format}") +class QuantilePlots: + def __init__(self, plotter_cfg: dict, output_basedir: str | Path): + """ + Initialize the QuantilePlots class. + + Parameters + ---------- + plotter_cfg: + Configuration dictionary containing basic information for plotting. + Expected keys are: + - image_format: Format of the saved images (e.g., 'png', 'pdf', etc.) + - dpi_val: DPI value for the saved images + - fig_size: Size of the figure (width, height) in inches + output_basedir: + Base directory under which the plots will be saved. + Expected scheme `/`. + """ + self.image_format = plotter_cfg.get("image_format") + self.dpi_val = plotter_cfg.get("dpi_val") + self.fig_size = plotter_cfg.get("fig_size") + self.out_plot_dir = Path(output_basedir) / "quantile_plots" + + if not os.path.exists(self.out_plot_dir): + _logger.info(f"Creating dir {self.out_plot_dir}") + os.makedirs(self.out_plot_dir, exist_ok=True) + + _logger.info(f"Saving quantile plots to: {self.out_plot_dir}") + + def _check_lengths(self, data: xr.DataArray | list, labels: str | list) -> tuple[list, list]: + """ + Check if the lengths of data and labels match. + + Parameters + ---------- + data: + DataArray or list of DataArrays to be plotted + labels: + Label or list of labels for each dataset + + Returns + ------- + data_list, label_list - lists of data and labels + """ + assert isinstance(data, xr.DataArray | list), ( + "QuantilePlots::_check_lengths - Data should be of type xr.DataArray or list" + ) + assert isinstance(labels, str | list), ( + "QuantilePlots::_check_lengths - Labels should be of type str or list" + ) + + data_list = [data] if isinstance(data, xr.DataArray) else data + label_list = [labels] if isinstance(labels, str) else labels + + assert len(data_list) == len(label_list), ( + "QuantilePlots::_check_lengths - Data and Labels do not match" + ) + + return data_list, label_list + + def qq_plot( + self, + qq_data: list[xr.Dataset], + labels: str | list, + tag: str = "", + metric: str = "qq_analysis", + extreme_percentiles: tuple[float, float] | None = None, + ) -> None: + """ + Create quantile-quantile (Q-Q) plots for extreme value analysis. + + This method generates comprehensive Q-Q plots comparing forecast quantiles + against ground truth quantiles, with emphasis on extreme values. + + Parameters + ---------- + qq_data: + Dataset or list of Datasets containing Q-Q analysis results. + Each dataset should contain: + - 'quantile_levels': Theoretical quantile levels (0 to 1) + - 'p_quantiles': Quantile values from prediction data + - 'gt_quantiles': Quantile values from ground truth data + - 'qq_deviation': Absolute difference between quantiles + - 'extreme_low_mse': MSE for lower extreme quantiles + - 'extreme_high_mse': MSE for upper extreme quantiles + labels: + Label or list of labels for each dataset + tag: + Tag to be added to the plot title and filename + metric: + Name of the metric (default: 'qq_analysis') + extreme_percentiles: + Lower and upper percentile thresholds for extreme regions. + + Returns + ------- + None + """ + data_list, label_list = self._check_lengths(qq_data, labels) + + # Use extreme_percentiles from data if not explicitly provided + if extreme_percentiles is None: + extreme_percentiles = tuple(data_list[0].attrs.get("extreme_percentiles", (5.0, 95.0))) + + # Create figure with subplots + fig = plt.figure(figsize=(16, 6), dpi=self.dpi_val) + gs = fig.add_gridspec(1, 2, width_ratios=[2, 1], wspace=0.3) + + ax_qq = fig.add_subplot(gs[0]) # Main Q-Q plot + ax_dev = fig.add_subplot(gs[1]) # Deviation plot + + colors = plt.cm.tab10(np.linspace(0, 1, len(data_list))) + + for _i, (ds, label, color) in enumerate(zip(data_list, label_list, colors, strict=False)): + # Extract quantile data + quantile_levels = ds["quantile_levels"].values + p_quantiles = ds["p_quantiles"].values + gt_quantiles = ds["gt_quantiles"].values + qq_deviation = ds["qq_deviation"].values + + # Main Q-Q plot + ax_qq.scatter( + gt_quantiles, + p_quantiles, + alpha=0.6, + s=20, + c=[color], + label=label, + edgecolors="none", + ) + + # Deviation plot + ax_dev.plot( + quantile_levels, + qq_deviation, + label=label, + color=color, + linewidth=2, + alpha=0.8, + ) + + # Format main Q-Q plot + ax_qq.set_xlabel("Ground Truth Quantiles", fontsize=12) + ax_qq.set_ylabel("Prediction Quantiles", fontsize=12) + ax_qq.set_title("Quantile-Quantile Plot for Extreme Value Analysis", fontsize=14) + + # Add perfect agreement line (y=x) + min_val = min([ds["gt_quantiles"].min().values for ds in data_list]) + max_val = max([ds["gt_quantiles"].max().values for ds in data_list]) + ax_qq.plot( + [min_val, max_val], + [min_val, max_val], + "k--", + linewidth=2, + label="Perfect Agreement", + alpha=0.7, + ) + + # Add shaded regions for extremes + if len(data_list) > 0: + ds_ref = data_list[0] + quantile_levels = ds_ref["quantile_levels"].values + + # Find extreme regions using configurable thresholds + lower_extreme_idx = quantile_levels < (extreme_percentiles[0] / 100) + upper_extreme_idx = quantile_levels > (extreme_percentiles[1] / 100) + + if np.any(lower_extreme_idx): + lower_q = ds_ref["gt_quantiles"].values[lower_extreme_idx] + ax_qq.axvspan( + min_val, + lower_q.max() if len(lower_q) > 0 else min_val, + alpha=0.1, + color="blue", + label="Lower Extreme Zone", + ) + + if np.any(upper_extreme_idx): + upper_q = ds_ref["gt_quantiles"].values[upper_extreme_idx] + ax_qq.axvspan( + upper_q.min() if len(upper_q) > 0 else max_val, + max_val, + alpha=0.1, + color="red", + label="Upper Extreme Zone", + ) + + ax_qq.legend(frameon=False, loc="upper left", fontsize=10) + ax_qq.grid(True, linestyle="--", alpha=0.3) + + # Format deviation plot + ax_dev.set_xlabel("Quantile Level", fontsize=12) + ax_dev.set_ylabel("Absolute Deviation", fontsize=12) + ax_dev.set_title("Quantile Deviation", fontsize=14) + ax_dev.legend(frameon=False, fontsize=10) + ax_dev.grid(True, linestyle="--", alpha=0.3) + + # Highlight extreme regions in deviation plot + lower_threshold = extreme_percentiles[0] / 100 + upper_threshold = extreme_percentiles[1] / 100 + ax_dev.axvspan(0.0, lower_threshold, alpha=0.1, color="blue") + ax_dev.axvspan(upper_threshold, 1.0, alpha=0.1, color="red") + + plt.tight_layout() + + # Save the plot + parts = ["qq_analysis", tag] + name = "_".join(filter(None, parts)) + save_path = self.out_plot_dir.joinpath(f"{name}.{self.image_format}") + plt.savefig(save_path, bbox_inches="tight") + plt.close() + + _logger.info(f"Q-Q plot saved to {save_path}") + + class ScoreCards: """ Initialize the ScoreCards class. diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index 3099b0bd4..d4846b894 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -190,6 +190,7 @@ def __init__( "grad_amplitude": self.calc_spatial_variability, "psnr": self.calc_psnr, "seeps": self.calc_seeps, + "qq_analysis": self.calc_quantiles, } self.prob_metrics_dict = { "ssr": self.calc_ssr, @@ -241,6 +242,7 @@ def get_score( """ if score_name in self.det_metrics_dict.keys(): f = self.det_metrics_dict[score_name] + _logger.debug(f"Using deterministic metric: {score_name}") elif score_name in self.prob_metrics_dict.keys(): assert self.ens_dim in data.prediction.dims, ( f"Probablistic score {score_name} chosen, but ensemble dimension {self.ens_dim} " @@ -1495,3 +1497,138 @@ def check_for_coords(coord_names_data, coord_names_expected): raise ValueError(f"Second-order differentation is not implemenetd in {method} yet.") return var_diff_amplitude + + def calc_quantiles( + self, + p: xr.DataArray, + gt: xr.DataArray, + n_quantiles: int = 100, + quantile_method: str = "linear", + focus_extremes: bool = True, + extreme_percentiles: tuple[float, float] = (5.0, 95.0), # 5th and 95th percentiles + iqr_percentiles: tuple[float, float] = (25.0, 75.0), + ) -> xr.DataArray: + """ + Calculate quantile-quantile (Q-Q) analysis metric for extreme value evaluation. + + This metric compares the distribution of forecast values with ground truth values + by computing quantiles and their deviations. + + The Q-Q analysis returns quantile values from both prediction and ground truth, + along with metrics to assess how well the forecast captures the observed distribution, + especially in the tails (extremes). + + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + n_quantiles: int + Number of quantiles to calculate for the Q-Q plot. Default is 100. + Higher values provide finer resolution of the distribution. + quantile_method: str + Method for quantile calculation. Options: 'linear', 'lower', 'higher', + 'midpoint', 'nearest'. Default is 'linear'. + focus_extremes: bool + If True, additional quantiles are computed in the extreme tails of the + distribution for better resolution of extreme values. Default is True. + extreme_percentiles: tuple[float, float] + Lower and upper percentile thresholds to define extremes. + Default is (5.0, 95.0), meaning values below 5th and above 95th percentile + are considered extremes. + iqr_percentiles: tuple[float, float] + Lower and upper percentile thresholds for interquartile range (IQR) calculation. + Default is (25.0, 75.0), meaning IQR is computed as difference between 75th and + 25th percentiles. + + Returns + ------- + xr.DataArray + Dataset containing Q-Q analysis results with the following variables: + - 'quantile': Theoretical quantile levels (0 to 1) + - 'p_quantiles': Quantile values from prediction data + - 'gt_quantiles': Quantile values from ground truth data + - 'qq_deviation': Absolute difference between prediction and ground truth quantiles + - 'qq_deviation_normalized': Normalized deviation (relative to ground truth IQR) + - 'extreme_low_mse': MSE for lower extreme quantiles + - 'extreme_high_mse': MSE for upper extreme quantiles + - 'overall_qq_score': Overall Q-Q score (lower is better, 0 is perfect) + """ + if self._agg_dims is None: + raise ValueError( + "Cannot calculate Q-Q analysis without aggregation dimensions (agg_dims=None)." + ) + + _logger.info(f"Starting Q-Q analysis with {n_quantiles} quantiles") + + # Generate quantile levels + if focus_extremes: + # Add more resolution in the tails for extreme value analysis + lower_tail = np.linspace(0.001, extreme_percentiles[0] / 100, 20) + middle = np.linspace( + extreme_percentiles[0] / 100, extreme_percentiles[1] / 100, n_quantiles - 40 + ) + upper_tail = np.linspace(extreme_percentiles[1] / 100, 0.999, 20) + quantile_levels = np.concatenate([lower_tail, middle, upper_tail]) + else: + quantile_levels = np.linspace(0.001, 0.999, n_quantiles) + + # Stack aggregation dimensions into a single dimension + p_flat = p.stack({"_agg_points": self._agg_dims}) + gt_flat = gt.stack({"_agg_points": self._agg_dims}) + + # Remove NaN values before quantile calculation + p_flat = p_flat.dropna(dim="_agg_points", how="all") + gt_flat = gt_flat.dropna(dim="_agg_points", how="all") + + # Calculate quantiles using xarray's quantile method + p_quantiles = p_flat.quantile(quantile_levels, dim="_agg_points", method=quantile_method) + gt_quantiles = gt_flat.quantile(quantile_levels, dim="_agg_points", method=quantile_method) + + # Calculate Q-Q deviations + qq_deviation = np.abs(p_quantiles - gt_quantiles) + + # Calculate normalized deviation (relative to interquartile range of ground truth) + gt_q_low = gt_flat.quantile(iqr_percentiles[0] / 100, dim="_agg_points") + gt_q_high = gt_flat.quantile(iqr_percentiles[1] / 100, dim="_agg_points") + iqr = gt_q_high - gt_q_low + # Avoid division by zero + iqr = iqr.where(iqr > 1e-10, 1.0) + + qq_deviation_normalized = qq_deviation / iqr + + # Calculate MSE for extreme quantiles + extreme_low_mask = quantile_levels < (extreme_percentiles[0] / 100) + extreme_high_mask = quantile_levels > (extreme_percentiles[1] / 100) + + extreme_low_mse = ( + ((p_quantiles - gt_quantiles) ** 2).isel(quantile=extreme_low_mask).mean(dim="quantile") + ) + extreme_high_mse = ( + ((p_quantiles - gt_quantiles) ** 2) + .isel(quantile=extreme_high_mask) + .mean(dim="quantile") + ) + + # Calculate overall Q-Q score (mean absolute deviation across all quantiles) + overall_qq_score = qq_deviation.mean(dim="quantile") + + # Store Q-Q data as xarray attributes for automatic JSON serialization + overall_qq_score.attrs.update( + { + "p_quantiles": p_quantiles.values.tolist(), + "gt_quantiles": gt_quantiles.values.tolist(), + "qq_deviation": qq_deviation.values.tolist(), + "qq_deviation_normalized": qq_deviation_normalized.values.tolist(), + "extreme_low_mse": extreme_low_mse.values.tolist(), + "extreme_high_mse": extreme_high_mse.values.tolist(), + "quantile_levels": quantile_levels.tolist(), + "extreme_percentiles": list(extreme_percentiles), + "iqr_percentiles": list(iqr_percentiles), + } + ) + + _logger.info(f"Q-Q analysis completed with {len(overall_qq_score.attrs)} attributes") + + return overall_qq_score diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index d911a370b..033e5b669 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -25,10 +25,17 @@ bar_plot_metric_region, heat_maps_metric_region, plot_metric_region, + quantile_plot_metric_region, ratio_plot_metric_region, score_card_metric_region, ) -from weathergen.evaluate.plotting.plotter import BarPlots, LinePlots, Plotter, ScoreCards +from weathergen.evaluate.plotting.plotter import ( + BarPlots, + LinePlots, + Plotter, + QuantilePlots, + ScoreCards, +) from weathergen.evaluate.scores.score import VerifiedData, get_score from weathergen.evaluate.utils.clim_utils import get_climatology from weathergen.evaluate.utils.regions import RegionBoundingBox @@ -140,6 +147,9 @@ def calc_scores_per_stream( ) lead_time_map = {} + # Store metric-specific attributes that get lost during concat + # Key: (fstep, metric) -> attrs dict + all_metric_attrs = {} for (fstep, tars), (_, preds) in zip(da_tars.items(), da_preds.items(), strict=False): if preds.ipoint.size == 0: @@ -166,8 +176,9 @@ def calc_scores_per_stream( # Build up computation graphs for all metrics _logger.debug(f"Build computation graphs for metrics for stream {stream}...") - # Add it only if it is not None + # Calculate scores and filter out None values valid_scores = [] + valid_metric_names = [] for metric in metrics: score = get_score( @@ -175,16 +186,31 @@ def calc_scores_per_stream( ) if score is not None: valid_scores.append(score) + valid_metric_names.append(metric) + else: + _logger.warning(f"Metric {metric} returned None, skipping") - valid_metric_names = [ - metric - for metric, score in zip(metrics, valid_scores, strict=False) - if score is not None - ] if not valid_scores: continue - combined_metrics = xr.concat(valid_scores, dim="metric") + # Concatenate all metrics using "minimal" to handle metrics with different coords + # Preserve attributes from individual metrics (e.g., Q-Q analysis data) + # Store attributes before concat as they may be lost + for metric, score in zip(valid_metric_names, valid_scores, strict=False): + if score.attrs: + # Store with key (fstep, metric) for later restoration + all_metric_attrs[(int(fstep), metric)] = score.attrs.copy() + _logger.debug( + f"Stored {len(score.attrs)} attrs for fstep={fstep}, metric={metric}" + ) + + combined_metrics = xr.concat( + valid_scores, + dim="metric", + coords="minimal", + combine_attrs="drop_conflicts", + ) + combined_metrics = combined_metrics.assign_coords(metric=valid_metric_names) combined_metrics = combined_metrics.compute() @@ -202,6 +228,40 @@ def calc_scores_per_stream( metric_stream.loc[criteria] = combined_metrics + # Restore metric-specific coordinates that were dropped by coords="minimal" + # (e.g., quantiles, extreme_percentiles for qq_analysis) + for coord_name in combined_metrics.coords: + # Skip coordinates that are already dimensions (no need to restore) + if coord_name in combined_metrics.dims or coord_name in metric_stream.dims: + continue + + # Only restore coordinates whose dimensions exist in metric_stream + # (e.g., skip coords with 'quantile' dim if metric_stream doesn't have it) + coord_dims = combined_metrics.coords[coord_name].dims + if not all(dim in metric_stream.dims for dim in coord_dims): + _logger.debug( + f"Skipping coordinate '{coord_name}' with incompatible " + f"dimensions {coord_dims} (metric_stream has {metric_stream.dims})" + ) + continue + + # Initialize coordinate in metric_stream if it doesn't exist yet + if coord_name not in metric_stream.coords: + coord_shape = tuple(len(metric_stream.coords[dim]) for dim in coord_dims) + metric_stream = metric_stream.assign_coords( + { + coord_name: xr.DataArray( + np.full(coord_shape, "", dtype=object), + dims=coord_dims, + coords={dim: metric_stream.coords[dim] for dim in coord_dims}, + ) + } + ) + + # Build indexers to select the right location in metric_stream + indexers = {dim: criteria[dim] for dim in coord_dims if dim in criteria} + metric_stream.coords[coord_name].loc[indexers] = combined_metrics.coords[coord_name] + lead_time_map[fstep] = ( np.unique(combined_metrics.lead_time.values.astype("timedelta64[h]")) if "lead_time" in combined_metrics.coords @@ -225,12 +285,22 @@ def calc_scores_per_stream( ) _logger.info(f"Scores for run {reader.run_id} - {stream} calculated successfully.") + _logger.debug(f"all_metric_attrs keys: {list(all_metric_attrs.keys())}") # Build local dictionary for this region for metric in metrics: + metric_data = metric_stream.sel({"metric": metric}) + # Restore metric-specific attributes from all forecast steps + # Attributes are the same across forecast steps for a given metric + for (_stored_fstep, stored_metric), attrs in all_metric_attrs.items(): + if stored_metric == metric and attrs: + _logger.debug(f"Restoring {len(attrs)} attributes for {metric}") + metric_data.attrs.update(attrs) + break + local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault(stream, {})[ reader.run_id - ] = metric_stream.sel({"metric": metric}) + ] = metric_data return local_scores @@ -481,10 +551,7 @@ def metric_list_to_json( for metric, metric_stream in metrics_dict.items(): for region in regions: - metric_now = metric_stream[region][stream] - for run_id in metric_now.keys(): - metric_now = metric_now[run_id] - + for run_id, metric_data in metric_stream[region][stream].items(): # Match the expected filename pattern save_path = ( reader.metrics_dir @@ -493,7 +560,7 @@ def metric_list_to_json( _logger.info(f"Saving results to {save_path}") with open(save_path, "w") as f: - json.dump(metric_now.to_dict(), f, indent=4) + json.dump(metric_data.to_dict(), f, indent=4) _logger.info( f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} " @@ -534,9 +601,13 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): plotter = LinePlots(plot_cfg, summary_dir) sc_plotter = ScoreCards(plot_cfg, summary_dir) br_plotter = BarPlots(plot_cfg, summary_dir) + quantile_plotter = QuantilePlots(plot_cfg, summary_dir) for region in regions: for metric in metrics: - if eval_opt.get("summary_plots", True): + # Special handling for Q-Q analysis metric + if metric == "qq_analysis": + quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter) + elif eval_opt.get("summary_plots", True): plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) if eval_opt.get("ratio_plots", False): ratio_plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary)