From 787498506f09cb11f9757fc6315b635ce896c772 Mon Sep 17 00:00:00 2001 From: "Dr. Jehangir Awan" Date: Tue, 27 Jan 2026 13:16:32 +0100 Subject: [PATCH 1/5] Add Q-Q analysis metric for extreme value evaluation --- .../evaluate/plotting/plot_utils.py | 158 +++++++++++++++++ .../weathergen/evaluate/plotting/plotter.py | 145 ++++++++++++++++ .../src/weathergen/evaluate/scores/score.py | 161 ++++++++++++++++++ .../src/weathergen/evaluate/utils/utils.py | 29 +++- 4 files changed, 491 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 39030653b..b92285f8b 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -409,6 +409,164 @@ def list_streams(cls): return list(cls._marker_size_stream.keys()) +def qq_plot_metric_region( + metric: str, + region: str, + runs: dict, + scores_dict: dict, + plotter: object, +) -> None: + """ + Create quantile-quantile (Q-Q) plots for extreme value analysis for all streams + and channels for a given metric and region. + + Parameters + ---------- + metric: str + String specifying the metric to plot (should be 'qq_analysis') + region: str + String specifying the region to plot + runs: dict + Dictionary containing the config for all runs + scores_dict : dict + The dictionary containing all computed metrics. + plotter: + Plotter object to handle the plotting part. Must have a qq_plot method. + """ + streams_set = collect_streams(runs) + channels_set = collect_channels(scores_dict, metric, region, runs) + + for stream in streams_set: + for ch in channels_set: + selected_data, labels, run_ids = [], [], [] + qq_full_data = [] # Store full Q-Q datasets for detailed plotting + + for run_id, data in scores_dict[metric][region].get(stream, {}).items(): + # skip if channel is missing + if ch not in np.atleast_1d(data.channel.values): + continue + + # Select channel + data_for_channel = data.sel(channel=ch) if "channel" in data.dims else data + + # Check for NaN + if data_for_channel.isnull().all(): + continue + + # For qq_analysis, extract full Q-Q data from coordinate + if metric == "qq_analysis": + if "qq_full_data" in data_for_channel.coords: + import json + # Get the qq_full_data - it may be 0-d (scalar) or multi-d + qq_coord = data_for_channel.coords["qq_full_data"] + + # Extract the first non-empty JSON string + if qq_coord.ndim == 0: + qq_data_str = qq_coord.item() + else: + # Flatten and get first non-empty value + qq_values = qq_coord.values.flatten() + qq_data_str = next((v for v in qq_values if v and v != ""), None) + + if not qq_data_str: + _logger.warning(f"Empty qq_full_data for channel {ch}") + continue + + qq_data = json.loads(qq_data_str) + + # Convert lists back to numpy arrays for plotting + quantile_levels = np.array(qq_data["quantile_levels"]) + p_quantiles = np.array(qq_data["p_quantiles"]) + gt_quantiles = np.array(qq_data["gt_quantiles"]) + qq_deviation = np.array(qq_data["qq_deviation"]) + qq_deviation_normalized = np.array(qq_data.get("qq_deviation_normalized", qq_deviation)) + extreme_low_mse = np.array(qq_data.get("extreme_low_mse", 0.0)) + extreme_high_mse = np.array(qq_data.get("extreme_high_mse", 0.0)) + + # All quantile data is replicated across sample/ens dimensions + # Take the first sample's data (they're all identical) + quantile_levels_1d = quantile_levels.flatten() if quantile_levels.ndim > 1 else quantile_levels + + # 
For p_quantiles and gt_quantiles, average across samples if multidimensional + # Shape is typically (n_quantiles, n_samples, n_ens) + # We want to average across samples (axis 1) and ens (axis 2), keeping quantiles (axis 0) + if p_quantiles.ndim > 1: + # Average across all axes except the first (quantile axis) + p_quantiles_1d = np.mean(p_quantiles, axis=tuple(range(1, p_quantiles.ndim))).flatten() + else: + p_quantiles_1d = p_quantiles + + if gt_quantiles.ndim > 1: + gt_quantiles_1d = np.mean(gt_quantiles, axis=tuple(range(1, gt_quantiles.ndim))).flatten() + else: + gt_quantiles_1d = gt_quantiles + + if qq_deviation.ndim > 1: + qq_deviation_1d = np.mean(qq_deviation, axis=tuple(range(1, qq_deviation.ndim))).flatten() + else: + qq_deviation_1d = qq_deviation + + if qq_deviation_normalized.ndim > 1: + qq_deviation_normalized_1d = np.mean(qq_deviation_normalized, axis=tuple(range(1, qq_deviation_normalized.ndim))).flatten() + else: + qq_deviation_normalized_1d = qq_deviation_normalized + + # Handle scalar MSE values + extreme_low_mse_scalar = float(np.mean(extreme_low_mse)) + extreme_high_mse_scalar = float(np.mean(extreme_high_mse)) + + # Create Dataset with proper dimensions + qq_dataset = xr.Dataset({ + "quantile_levels": (["quantile"], quantile_levels_1d), + "p_quantiles": (["quantile"], p_quantiles_1d), + "gt_quantiles": (["quantile"], gt_quantiles_1d), + "qq_deviation": (["quantile"], qq_deviation_1d), + "qq_deviation_normalized": (["quantile"], qq_deviation_normalized_1d), + "extreme_low_mse": ([], extreme_low_mse_scalar), + "extreme_high_mse": ([], extreme_high_mse_scalar), + }) + qq_full_data.append(qq_dataset) + else: + _logger.warning(f"qq_full_data not found in coords for qq_analysis metric") + + selected_data.append(data_for_channel) + labels.append(runs[run_id].get("label", run_id)) + run_ids.append(run_id) + + if selected_data: + _logger.info( + f"Creating Q-Q plot for {metric} - {region} - {stream} - {ch}." + ) + + name = create_filename( + prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream, ch] + ) + + # Check if plotter has qq_plot method + if hasattr(plotter, "qq_plot") and qq_full_data: + _logger.info(f"Creating Q-Q plot with {len(qq_full_data)} dataset(s).") + plotter.qq_plot( + qq_full_data, + labels, + tag=name, + metric=metric, + ) + else: + # Fallback to standard plot if qq_plot not available + _logger.warning( + f"Plotter does not have qq_plot method or no full Q-Q data available. " + f"Falling back to standard plot for {metric}." + ) + selected_data, time_dim = _assign_time_coord(selected_data) + plotter.plot( + selected_data, + labels, + tag=name, + x_dim=time_dim, + y_dim=metric, + ) + + def create_filename( *, prefix: Sequence[str] = (), diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 73a3627e2..4664ee974 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1137,6 +1137,151 @@ def heat_map( name = "_".join(filter(None, parts)) plt.savefig(f"{self.out_plot_dir.joinpath(name)}.{self.image_format}") + def qq_plot( + self, + qq_data: list[xr.Dataset], + labels: str | list, + tag: str = "", + metric: str = "qq_analysis", + ) -> None: + """ + Create quantile-quantile (Q-Q) plots for extreme value analysis. + + This method generates comprehensive Q-Q plots comparing forecast quantiles + against ground truth quantiles, with emphasis on extreme values. 
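
Note: a Q-Q comparison pairs the empirical quantiles of two samples at common
probability levels; points on the y = x line indicate matching distributions.
A minimal, self-contained sketch of the idea behind the method below (all
names and numbers here are illustrative, not part of the patch):

    import numpy as np

    rng = np.random.default_rng(0)
    forecast = rng.normal(0.0, 1.1, size=10_000)  # slightly too wide a spread
    truth = rng.normal(0.0, 1.0, size=10_000)

    levels = np.linspace(0.001, 0.999, 100)
    f_q = np.quantile(forecast, levels)
    t_q = np.quantile(truth, levels)

    # Distance from the y = x line; growth in the tails signals poorly
    # captured extremes, which is what this metric emphasises.
    deviation = np.abs(f_q - t_q)
    print(deviation[:3], deviation[-3:])
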
+ + Parameters + ---------- + qq_data: + Dataset or list of Datasets containing Q-Q analysis results. + Each dataset should contain: + - 'quantile_levels': Theoretical quantile levels (0 to 1) + - 'p_quantiles': Quantile values from prediction data + - 'gt_quantiles': Quantile values from ground truth data + - 'qq_deviation': Absolute difference between quantiles + - 'extreme_low_mse': MSE for lower extreme quantiles + - 'extreme_high_mse': MSE for upper extreme quantiles + labels: + Label or list of labels for each dataset + tag: + Tag to be added to the plot title and filename + metric: + Name of the metric (default: 'qq_analysis') + + Returns + ------- + None + """ + data_list, label_list = self._check_lengths(qq_data, labels) + + # Create figure with subplots + fig = plt.figure(figsize=(16, 6), dpi=self.dpi_val) + gs = fig.add_gridspec(1, 2, width_ratios=[2, 1], wspace=0.3) + + ax_qq = fig.add_subplot(gs[0]) # Main Q-Q plot + ax_dev = fig.add_subplot(gs[1]) # Deviation plot + + colors = plt.cm.tab10(np.linspace(0, 1, len(data_list))) + + for i, (ds, label, color) in enumerate(zip(data_list, label_list, colors, strict=False)): + # Extract quantile data + quantile_levels = ds["quantile_levels"].values + p_quantiles = ds["p_quantiles"].values + gt_quantiles = ds["gt_quantiles"].values + qq_deviation = ds["qq_deviation"].values + + # Main Q-Q plot + ax_qq.scatter( + gt_quantiles, + p_quantiles, + alpha=0.6, + s=20, + c=[color], + label=label, + edgecolors="none", + ) + + # Deviation plot + ax_dev.plot( + quantile_levels, + qq_deviation, + label=label, + color=color, + linewidth=2, + alpha=0.8, + ) + + # Format main Q-Q plot + ax_qq.set_xlabel("Ground Truth Quantiles", fontsize=12) + ax_qq.set_ylabel("Prediction Quantiles", fontsize=12) + ax_qq.set_title("Quantile-Quantile Plot for Extreme Value Analysis", fontsize=14) + + # Add perfect agreement line (y=x) + min_val = min([ds["gt_quantiles"].min().values for ds in data_list]) + max_val = max([ds["gt_quantiles"].max().values for ds in data_list]) + ax_qq.plot( + [min_val, max_val], + [min_val, max_val], + "k--", + linewidth=2, + label="Perfect Agreement", + alpha=0.7, + ) + + # Add shaded regions for extremes + if len(data_list) > 0: + ds_ref = data_list[0] + quantile_levels = ds_ref["quantile_levels"].values + + # Find extreme regions (typically below 5% and above 95%) + lower_extreme_idx = quantile_levels < 0.05 + upper_extreme_idx = quantile_levels > 0.95 + + if np.any(lower_extreme_idx): + lower_q = ds_ref["gt_quantiles"].values[lower_extreme_idx] + ax_qq.axvspan( + min_val, + lower_q.max() if len(lower_q) > 0 else min_val, + alpha=0.1, + color="blue", + label="Lower Extreme Zone", + ) + + if np.any(upper_extreme_idx): + upper_q = ds_ref["gt_quantiles"].values[upper_extreme_idx] + ax_qq.axvspan( + upper_q.min() if len(upper_q) > 0 else max_val, + max_val, + alpha=0.1, + color="red", + label="Upper Extreme Zone", + ) + + ax_qq.legend(frameon=False, loc="upper left", fontsize=10) + ax_qq.grid(True, linestyle="--", alpha=0.3) + + # Format deviation plot + ax_dev.set_xlabel("Quantile Level", fontsize=12) + ax_dev.set_ylabel("Absolute Deviation", fontsize=12) + ax_dev.set_title("Quantile Deviation", fontsize=14) + ax_dev.legend(frameon=False, fontsize=10) + ax_dev.grid(True, linestyle="--", alpha=0.3) + + # Highlight extreme regions in deviation plot + ax_dev.axvspan(0.0, 0.05, alpha=0.1, color="blue") + ax_dev.axvspan(0.95, 1.0, alpha=0.1, color="red") + + plt.tight_layout() + + # Save the plot + parts = ["qq_analysis", tag] + name = 
"_".join(filter(None, parts)) + save_path = self.out_plot_dir.joinpath(f"{name}.{self.image_format}") + plt.savefig(save_path, bbox_inches="tight") + plt.close() + + _logger.info(f"Q-Q plot saved to {save_path}") + class ScoreCards: """ diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index 3099b0bd4..ae6fffd76 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -190,6 +190,7 @@ def __init__( "grad_amplitude": self.calc_spatial_variability, "psnr": self.calc_psnr, "seeps": self.calc_seeps, + "qq_analysis": self.calc_qq_analysis, } self.prob_metrics_dict = { "ssr": self.calc_ssr, @@ -1495,3 +1496,163 @@ def check_for_coords(coord_names_data, coord_names_expected): raise ValueError(f"Second-order differentation is not implemenetd in {method} yet.") return var_diff_amplitude + + def calc_qq_analysis( + self, + p: xr.DataArray, + gt: xr.DataArray, + n_quantiles: int = 100, + quantile_method: str = "linear", + focus_extremes: bool = True, + extreme_percentiles: tuple[float, float] = (5.0, 95.0), + ) -> xr.DataArray: + """ + Calculate quantile-quantile (Q-Q) analysis metric for extreme value evaluation. + + This metric compares the distribution of forecast values with ground truth values + by computing quantiles and their deviations. + + The Q-Q analysis returns quantile values from both prediction and ground truth, + along with metrics to assess how well the forecast captures the observed distribution, + especially in the tails (extremes). + + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + n_quantiles: int + Number of quantiles to calculate for the Q-Q plot. Default is 100. + Higher values provide finer resolution of the distribution. + quantile_method: str + Method for quantile calculation. Options: 'linear', 'lower', 'higher', + 'midpoint', 'nearest'. Default is 'linear'. + focus_extremes: bool + If True, additional quantiles are computed in the extreme tails of the + distribution for better resolution of extreme values. Default is True. + extreme_percentiles: tuple[float, float] + Lower and upper percentile thresholds to define extremes. + Default is (5.0, 95.0), meaning values below 5th and above 95th percentile + are considered extremes. + + Returns + ------- + xr.DataArray + Dataset containing Q-Q analysis results with the following variables: + - 'quantile': Theoretical quantile levels (0 to 1) + - 'p_quantiles': Quantile values from prediction data + - 'gt_quantiles': Quantile values from ground truth data + - 'qq_deviation': Absolute difference between prediction and ground truth quantiles + - 'qq_deviation_normalized': Normalized deviation (relative to ground truth IQR) + - 'extreme_low_mse': MSE for lower extreme quantiles + - 'extreme_high_mse': MSE for upper extreme quantiles + - 'overall_qq_score': Overall Q-Q score (lower is better, 0 is perfect) + """ + if self._agg_dims is None: + raise ValueError( + "Cannot calculate Q-Q analysis without aggregation dimensions " + "(agg_dims=None)." 
+ ) + + # Generate quantile levels + if focus_extremes: + # Add more resolution in the tails for extreme value analysis + lower_tail = np.linspace(0.001, extreme_percentiles[0] / 100, 20) + middle = np.linspace( + extreme_percentiles[0] / 100, extreme_percentiles[1] / 100, n_quantiles - 40 + ) + upper_tail = np.linspace(extreme_percentiles[1] / 100, 0.999, 20) + quantile_levels = np.concatenate([lower_tail, middle, upper_tail]) + else: + quantile_levels = np.linspace(0.001, 0.999, n_quantiles) + + # Flatten the data across aggregation dimensions while preserving other dimensions + # Determine which dimensions to keep (not in agg_dims) + keep_dims = [dim for dim in p.dims if dim not in self._agg_dims] + + if keep_dims: + # Stack aggregation dimensions into a single dimension + p_flat = p.stack({"_agg_points": self._agg_dims}) + gt_flat = gt.stack({"_agg_points": self._agg_dims}) + else: + # If all dimensions are aggregated, add a dummy dimension + p_flat = p.stack({"_agg_points": self._agg_dims}).expand_dims("dummy") + gt_flat = gt.stack({"_agg_points": self._agg_dims}).expand_dims("dummy") + keep_dims = ["dummy"] + + # Remove NaN values before quantile calculation + p_flat = p_flat.dropna(dim="_agg_points", how="all") + gt_flat = gt_flat.dropna(dim="_agg_points", how="all") + + # Calculate quantiles using xarray's quantile method + p_quantiles = p_flat.quantile( + quantile_levels, dim="_agg_points", method=quantile_method + ) + gt_quantiles = gt_flat.quantile( + quantile_levels, dim="_agg_points", method=quantile_method + ) + + # Calculate Q-Q deviations + qq_deviation = np.abs(p_quantiles - gt_quantiles) + + # Calculate normalized deviation (relative to interquartile range of ground truth) + gt_q25 = gt_flat.quantile(0.25, dim="_agg_points") + gt_q75 = gt_flat.quantile(0.75, dim="_agg_points") + iqr = gt_q75 - gt_q25 + # Avoid division by zero + iqr = iqr.where(iqr > 1e-10, 1.0) + + qq_deviation_normalized = qq_deviation / iqr + + # Calculate MSE for extreme quantiles + extreme_low_mask = quantile_levels < (extreme_percentiles[0] / 100) + extreme_high_mask = quantile_levels > (extreme_percentiles[1] / 100) + + extreme_low_mse = ( + ((p_quantiles - gt_quantiles) ** 2) + .isel(quantile=extreme_low_mask) + .mean(dim="quantile") + ) + extreme_high_mse = ( + ((p_quantiles - gt_quantiles) ** 2) + .isel(quantile=extreme_high_mask) + .mean(dim="quantile") + ) + + # Calculate overall Q-Q score (mean absolute deviation across all quantiles) + overall_qq_score = qq_deviation.mean(dim="quantile") + + # Store Q-Q data as a non-scalar coordinate with matching dimensions + # This ensures each channel/sample/ens combination gets its own Q-Q data + import json + + # Create separate JSON strings for each position in the multi-dimensional array + qq_data_coord_array = np.empty(overall_qq_score.shape, dtype=object) + + # Iterate over all positions and create individual JSON strings + for idx in np.ndindex(overall_qq_score.shape): + qq_full_data = { + "quantile_levels": quantile_levels.tolist(), + "p_quantiles": p_quantiles.values[(...,) + idx].tolist() if p_quantiles.ndim > 1 else p_quantiles.values.tolist(), + "gt_quantiles": gt_quantiles.values[(...,) + idx].tolist() if gt_quantiles.ndim > 1 else gt_quantiles.values.tolist(), + "qq_deviation": qq_deviation.values[(...,) + idx].tolist() if qq_deviation.ndim > 1 else qq_deviation.values.tolist(), + "qq_deviation_normalized": qq_deviation_normalized.values[(...,) + idx].tolist() if qq_deviation_normalized.ndim > 1 else qq_deviation_normalized.values.tolist(), 
+ "extreme_low_mse": float(extreme_low_mse.values[idx]) if extreme_low_mse.ndim > 0 else float(extreme_low_mse.values), + "extreme_high_mse": float(extreme_high_mse.values[idx]) if extreme_high_mse.ndim > 0 else float(extreme_high_mse.values), + } + qq_data_coord_array[idx] = json.dumps(qq_full_data) + + qq_data_coord = xr.DataArray( + qq_data_coord_array, + dims=overall_qq_score.dims, + coords={dim: overall_qq_score.coords[dim] for dim in overall_qq_score.dims}, + attrs={"description": "Full Q-Q analysis data for plotting"} + ) + overall_qq_score = overall_qq_score.assign_coords({"qq_full_data": qq_data_coord}) + + # Remove dummy dimension if it was added + if "dummy" in overall_qq_score.dims: + overall_qq_score = overall_qq_score.squeeze("dummy", drop=True) + + return overall_qq_score diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index d911a370b..c19efa6c5 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -25,6 +25,7 @@ bar_plot_metric_region, heat_maps_metric_region, plot_metric_region, + qq_plot_metric_region, ratio_plot_metric_region, score_card_metric_region, ) @@ -184,7 +185,14 @@ def calc_scores_per_stream( if not valid_scores: continue - combined_metrics = xr.concat(valid_scores, dim="metric") + # Concatenate all metrics; use "different" to allow metric-specific coordinates like qq_full_data + combined_metrics = xr.concat( + valid_scores, + dim="metric", + coords="different", + combine_attrs="no_conflicts", + ) + combined_metrics = combined_metrics.assign_coords(metric=valid_metric_names) combined_metrics = combined_metrics.compute() @@ -201,6 +209,20 @@ def calc_scores_per_stream( criteria["ens"] = combined_metrics.ens.values metric_stream.loc[criteria] = combined_metrics + + # Preserve metric-specific coordinates that loc assignment may drop + if "qq_full_data" in combined_metrics.coords: + if "qq_full_data" not in metric_stream.coords: + # Initialize coordinate array on first encounter + metric_stream = metric_stream.assign_coords({ + "qq_full_data": xr.DataArray( + np.full(metric_stream.shape, "", dtype=object), + dims=metric_stream.dims, + coords={dim: metric_stream.coords[dim] for dim in metric_stream.dims} + ) + }) + # Copy coordinate values for this slice + metric_stream.coords["qq_full_data"].loc[criteria] = combined_metrics.coords["qq_full_data"] lead_time_map[fstep] = ( np.unique(combined_metrics.lead_time.values.astype("timedelta64[h]")) @@ -536,7 +558,10 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): br_plotter = BarPlots(plot_cfg, summary_dir) for region in regions: for metric in metrics: - if eval_opt.get("summary_plots", True): + # Special handling for Q-Q analysis metric + if metric == "qq_analysis": + qq_plot_metric_region(metric, region, runs, scores_dict, plotter) + elif eval_opt.get("summary_plots", True): plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) if eval_opt.get("ratio_plots", False): ratio_plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) From 55fff9598e334c046d9d420a8b72ad96ea317fc9 Mon Sep 17 00:00:00 2001 From: "Dr. 
Jehangir Awan" Date: Wed, 28 Jan 2026 14:34:19 +0100 Subject: [PATCH 2/5] Add Q-Q analysis metric for extreme value evaluation --- .../evaluate/plotting/plot_utils.py | 97 +++++++++++-------- .../weathergen/evaluate/plotting/plotter.py | 24 ++--- .../src/weathergen/evaluate/scores/score.py | 59 +++++------ .../src/weathergen/evaluate/utils/utils.py | 28 +++--- 4 files changed, 119 insertions(+), 89 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index b92285f8b..f9a4f1e53 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -417,9 +417,9 @@ def qq_plot_metric_region( plotter: object, ) -> None: """ - Create quantile-quantile (Q-Q) plots for extreme value analysis for all streams + Create quantile-quantile (Q-Q) plots for extreme value analysis for all streams and channels for a given metric and region. - + Parameters ---------- metric: str @@ -445,21 +445,22 @@ def qq_plot_metric_region( # skip if channel is missing if ch not in np.atleast_1d(data.channel.values): continue - + # Select channel data_for_channel = data.sel(channel=ch) if "channel" in data.dims else data - + # Check for NaN if data_for_channel.isnull().all(): continue - + # For qq_analysis, extract full Q-Q data from coordinate if metric == "qq_analysis": if "qq_full_data" in data_for_channel.coords: import json + # Get the qq_full_data - it may be 0-d (scalar) or multi-d qq_coord = data_for_channel.coords["qq_full_data"] - + # Extract the first non-empty JSON string if qq_coord.ndim == 0: qq_data_str = qq_coord.item() @@ -467,76 +468,94 @@ def qq_plot_metric_region( # Flatten and get first non-empty value qq_values = qq_coord.values.flatten() qq_data_str = next((v for v in qq_values if v and v != ""), None) - + if not qq_data_str: _logger.warning(f"Empty qq_full_data for channel {ch}") continue - + qq_data = json.loads(qq_data_str) - + # Convert lists back to numpy arrays for plotting quantile_levels = np.array(qq_data["quantile_levels"]) p_quantiles = np.array(qq_data["p_quantiles"]) gt_quantiles = np.array(qq_data["gt_quantiles"]) qq_deviation = np.array(qq_data["qq_deviation"]) - qq_deviation_normalized = np.array(qq_data.get("qq_deviation_normalized", qq_deviation)) + qq_deviation_normalized = np.array( + qq_data.get("qq_deviation_normalized", qq_deviation) + ) extreme_low_mse = np.array(qq_data.get("extreme_low_mse", 0.0)) extreme_high_mse = np.array(qq_data.get("extreme_high_mse", 0.0)) - + # All quantile data is replicated across sample/ens dimensions # Take the first sample's data (they're all identical) - quantile_levels_1d = quantile_levels.flatten() if quantile_levels.ndim > 1 else quantile_levels - - # For p_quantiles and gt_quantiles, average across samples if multidimensional - # Shape is typically (n_quantiles, n_samples, n_ens) - # We want to average across samples (axis 1) and ens (axis 2), keeping quantiles (axis 0) + quantile_levels_1d = ( + quantile_levels.flatten() + if quantile_levels.ndim > 1 + else quantile_levels + ) + + # For p_quantiles and gt_quantiles, average across samples + # Shape: (n_quantiles, n_samples, n_ens) + # Average across samples (axis 1) and ens (axis 2) if p_quantiles.ndim > 1: # Average across all axes except the first (quantile axis) - p_quantiles_1d = np.mean(p_quantiles, axis=tuple(range(1, p_quantiles.ndim))).flatten() + p_quantiles_1d = np.mean( + p_quantiles, 
axis=tuple(range(1, p_quantiles.ndim)) + ).flatten() else: p_quantiles_1d = p_quantiles - + if gt_quantiles.ndim > 1: - gt_quantiles_1d = np.mean(gt_quantiles, axis=tuple(range(1, gt_quantiles.ndim))).flatten() + gt_quantiles_1d = np.mean( + gt_quantiles, axis=tuple(range(1, gt_quantiles.ndim)) + ).flatten() else: gt_quantiles_1d = gt_quantiles - + if qq_deviation.ndim > 1: - qq_deviation_1d = np.mean(qq_deviation, axis=tuple(range(1, qq_deviation.ndim))).flatten() + qq_deviation_1d = np.mean( + qq_deviation, axis=tuple(range(1, qq_deviation.ndim)) + ).flatten() else: qq_deviation_1d = qq_deviation - + if qq_deviation_normalized.ndim > 1: - qq_deviation_normalized_1d = np.mean(qq_deviation_normalized, axis=tuple(range(1, qq_deviation_normalized.ndim))).flatten() + qq_deviation_normalized_1d = np.mean( + qq_deviation_normalized, + axis=tuple(range(1, qq_deviation_normalized.ndim)), + ).flatten() else: qq_deviation_normalized_1d = qq_deviation_normalized - + # Handle scalar MSE values extreme_low_mse_scalar = float(np.mean(extreme_low_mse)) extreme_high_mse_scalar = float(np.mean(extreme_high_mse)) - + # Create Dataset with proper dimensions - qq_dataset = xr.Dataset({ - "quantile_levels": (["quantile"], quantile_levels_1d), - "p_quantiles": (["quantile"], p_quantiles_1d), - "gt_quantiles": (["quantile"], gt_quantiles_1d), - "qq_deviation": (["quantile"], qq_deviation_1d), - "qq_deviation_normalized": (["quantile"], qq_deviation_normalized_1d), - "extreme_low_mse": ([], extreme_low_mse_scalar), - "extreme_high_mse": ([], extreme_high_mse_scalar), - }) + qq_dataset = xr.Dataset( + { + "quantile_levels": (["quantile"], quantile_levels_1d), + "p_quantiles": (["quantile"], p_quantiles_1d), + "gt_quantiles": (["quantile"], gt_quantiles_1d), + "qq_deviation": (["quantile"], qq_deviation_1d), + "qq_deviation_normalized": ( + ["quantile"], + qq_deviation_normalized_1d, + ), + "extreme_low_mse": ([], extreme_low_mse_scalar), + "extreme_high_mse": ([], extreme_high_mse_scalar), + } + ) qq_full_data.append(qq_dataset) else: - _logger.warning(f"qq_full_data not found in coords for qq_analysis metric") - + _logger.warning("qq_full_data not found in coords for qq_analysis metric") + selected_data.append(data_for_channel) labels.append(runs[run_id].get("label", run_id)) run_ids.append(run_id) if selected_data: - _logger.info( - f"Creating Q-Q plot for {metric} - {region} - {stream} - {ch}." - ) + _logger.info(f"Creating Q-Q plot for {metric} - {region} - {stream} - {ch}.") name = create_filename( prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream, ch] diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 4664ee974..47c34fe36 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1146,8 +1146,8 @@ def qq_plot( ) -> None: """ Create quantile-quantile (Q-Q) plots for extreme value analysis. - - This method generates comprehensive Q-Q plots comparing forecast quantiles + + This method generates comprehensive Q-Q plots comparing forecast quantiles against ground truth quantiles, with emphasis on extreme values. 
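
Note: the two-panel layout used by this method (Q-Q scatter plus deviation
trace) can be reproduced in isolation; a rough sketch with synthetic
quantiles, assuming nothing beyond Matplotlib and NumPy (the output file name
is hypothetical):

    import matplotlib.pyplot as plt
    import numpy as np

    fig = plt.figure(figsize=(16, 6))
    gs = fig.add_gridspec(1, 2, width_ratios=[2, 1], wspace=0.3)
    ax_qq, ax_dev = fig.add_subplot(gs[0]), fig.add_subplot(gs[1])

    gt_q = np.linspace(-3, 3, 50)
    p_q = 1.1 * gt_q                        # toy prediction quantiles
    ax_qq.scatter(gt_q, p_q, s=20)
    ax_qq.plot([-3, 3], [-3, 3], "k--")     # perfect-agreement line y = x
    ax_dev.plot(np.linspace(0, 1, 50), np.abs(p_q - gt_q))
    fig.savefig("qq_layout_sketch.png")
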
Parameters @@ -1177,13 +1177,13 @@ def qq_plot( # Create figure with subplots fig = plt.figure(figsize=(16, 6), dpi=self.dpi_val) gs = fig.add_gridspec(1, 2, width_ratios=[2, 1], wspace=0.3) - + ax_qq = fig.add_subplot(gs[0]) # Main Q-Q plot ax_dev = fig.add_subplot(gs[1]) # Deviation plot colors = plt.cm.tab10(np.linspace(0, 1, len(data_list))) - for i, (ds, label, color) in enumerate(zip(data_list, label_list, colors, strict=False)): + for _i, (ds, label, color) in enumerate(zip(data_list, label_list, colors, strict=False)): # Extract quantile data quantile_levels = ds["quantile_levels"].values p_quantiles = ds["p_quantiles"].values @@ -1215,7 +1215,7 @@ def qq_plot( ax_qq.set_xlabel("Ground Truth Quantiles", fontsize=12) ax_qq.set_ylabel("Prediction Quantiles", fontsize=12) ax_qq.set_title("Quantile-Quantile Plot for Extreme Value Analysis", fontsize=14) - + # Add perfect agreement line (y=x) min_val = min([ds["gt_quantiles"].min().values for ds in data_list]) max_val = max([ds["gt_quantiles"].max().values for ds in data_list]) @@ -1227,16 +1227,16 @@ def qq_plot( label="Perfect Agreement", alpha=0.7, ) - + # Add shaded regions for extremes if len(data_list) > 0: ds_ref = data_list[0] quantile_levels = ds_ref["quantile_levels"].values - + # Find extreme regions (typically below 5% and above 95%) lower_extreme_idx = quantile_levels < 0.05 upper_extreme_idx = quantile_levels > 0.95 - + if np.any(lower_extreme_idx): lower_q = ds_ref["gt_quantiles"].values[lower_extreme_idx] ax_qq.axvspan( @@ -1246,7 +1246,7 @@ def qq_plot( color="blue", label="Lower Extreme Zone", ) - + if np.any(upper_extreme_idx): upper_q = ds_ref["gt_quantiles"].values[upper_extreme_idx] ax_qq.axvspan( @@ -1266,20 +1266,20 @@ def qq_plot( ax_dev.set_title("Quantile Deviation", fontsize=14) ax_dev.legend(frameon=False, fontsize=10) ax_dev.grid(True, linestyle="--", alpha=0.3) - + # Highlight extreme regions in deviation plot ax_dev.axvspan(0.0, 0.05, alpha=0.1, color="blue") ax_dev.axvspan(0.95, 1.0, alpha=0.1, color="red") plt.tight_layout() - + # Save the plot parts = ["qq_analysis", tag] name = "_".join(filter(None, parts)) save_path = self.out_plot_dir.joinpath(f"{name}.{self.image_format}") plt.savefig(save_path, bbox_inches="tight") plt.close() - + _logger.info(f"Q-Q plot saved to {save_path}") diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index ae6fffd76..d13464cef 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -1508,10 +1508,10 @@ def calc_qq_analysis( ) -> xr.DataArray: """ Calculate quantile-quantile (Q-Q) analysis metric for extreme value evaluation. - + This metric compares the distribution of forecast values with ground truth values - by computing quantiles and their deviations. - + by computing quantiles and their deviations. + The Q-Q analysis returns quantile values from both prediction and ground truth, along with metrics to assess how well the forecast captures the observed distribution, especially in the tails (extremes). @@ -1526,7 +1526,7 @@ def calc_qq_analysis( Number of quantiles to calculate for the Q-Q plot. Default is 100. Higher values provide finer resolution of the distribution. quantile_method: str - Method for quantile calculation. Options: 'linear', 'lower', 'higher', + Method for quantile calculation. Options: 'linear', 'lower', 'higher', 'midpoint', 'nearest'. Default is 'linear'. 
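
Note: these method names are NumPy's quantile interpolation options, and the
differences are easy to see on a tiny sample (values verified by hand):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])
    for m in ("linear", "lower", "higher", "midpoint", "nearest"):
        print(m, np.quantile(x, 0.4, method=m))
    # linear 2.2, lower 2.0, higher 3.0, midpoint 2.5, nearest 2.0
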
focus_extremes: bool If True, additional quantiles are computed in the extreme tails of the @@ -1551,8 +1551,7 @@ def calc_qq_analysis( """ if self._agg_dims is None: raise ValueError( - "Cannot calculate Q-Q analysis without aggregation dimensions " - "(agg_dims=None)." + "Cannot calculate Q-Q analysis without aggregation dimensions (agg_dims=None)." ) # Generate quantile levels @@ -1570,7 +1569,7 @@ def calc_qq_analysis( # Flatten the data across aggregation dimensions while preserving other dimensions # Determine which dimensions to keep (not in agg_dims) keep_dims = [dim for dim in p.dims if dim not in self._agg_dims] - + if keep_dims: # Stack aggregation dimensions into a single dimension p_flat = p.stack({"_agg_points": self._agg_dims}) @@ -1586,12 +1585,8 @@ def calc_qq_analysis( gt_flat = gt_flat.dropna(dim="_agg_points", how="all") # Calculate quantiles using xarray's quantile method - p_quantiles = p_flat.quantile( - quantile_levels, dim="_agg_points", method=quantile_method - ) - gt_quantiles = gt_flat.quantile( - quantile_levels, dim="_agg_points", method=quantile_method - ) + p_quantiles = p_flat.quantile(quantile_levels, dim="_agg_points", method=quantile_method) + gt_quantiles = gt_flat.quantile(quantile_levels, dim="_agg_points", method=quantile_method) # Calculate Q-Q deviations qq_deviation = np.abs(p_quantiles - gt_quantiles) @@ -1602,7 +1597,7 @@ def calc_qq_analysis( iqr = gt_q75 - gt_q25 # Avoid division by zero iqr = iqr.where(iqr > 1e-10, 1.0) - + qq_deviation_normalized = qq_deviation / iqr # Calculate MSE for extreme quantiles @@ -1610,9 +1605,7 @@ def calc_qq_analysis( extreme_high_mask = quantile_levels > (extreme_percentiles[1] / 100) extreme_low_mse = ( - ((p_quantiles - gt_quantiles) ** 2) - .isel(quantile=extreme_low_mask) - .mean(dim="quantile") + ((p_quantiles - gt_quantiles) ** 2).isel(quantile=extreme_low_mask).mean(dim="quantile") ) extreme_high_mse = ( ((p_quantiles - gt_quantiles) ** 2) @@ -1626,28 +1619,40 @@ def calc_qq_analysis( # Store Q-Q data as a non-scalar coordinate with matching dimensions # This ensures each channel/sample/ens combination gets its own Q-Q data import json - + # Create separate JSON strings for each position in the multi-dimensional array qq_data_coord_array = np.empty(overall_qq_score.shape, dtype=object) - + # Iterate over all positions and create individual JSON strings for idx in np.ndindex(overall_qq_score.shape): qq_full_data = { "quantile_levels": quantile_levels.tolist(), - "p_quantiles": p_quantiles.values[(...,) + idx].tolist() if p_quantiles.ndim > 1 else p_quantiles.values.tolist(), - "gt_quantiles": gt_quantiles.values[(...,) + idx].tolist() if gt_quantiles.ndim > 1 else gt_quantiles.values.tolist(), - "qq_deviation": qq_deviation.values[(...,) + idx].tolist() if qq_deviation.ndim > 1 else qq_deviation.values.tolist(), - "qq_deviation_normalized": qq_deviation_normalized.values[(...,) + idx].tolist() if qq_deviation_normalized.ndim > 1 else qq_deviation_normalized.values.tolist(), - "extreme_low_mse": float(extreme_low_mse.values[idx]) if extreme_low_mse.ndim > 0 else float(extreme_low_mse.values), - "extreme_high_mse": float(extreme_high_mse.values[idx]) if extreme_high_mse.ndim > 0 else float(extreme_high_mse.values), + "p_quantiles": p_quantiles.values[(...,) + idx].tolist() + if p_quantiles.ndim > 1 + else p_quantiles.values.tolist(), + "gt_quantiles": gt_quantiles.values[(...,) + idx].tolist() + if gt_quantiles.ndim > 1 + else gt_quantiles.values.tolist(), + "qq_deviation": qq_deviation.values[(...,) + 
idx].tolist() + if qq_deviation.ndim > 1 + else qq_deviation.values.tolist(), + "qq_deviation_normalized": qq_deviation_normalized.values[(...,) + idx].tolist() + if qq_deviation_normalized.ndim > 1 + else qq_deviation_normalized.values.tolist(), + "extreme_low_mse": float(extreme_low_mse.values[idx]) + if extreme_low_mse.ndim > 0 + else float(extreme_low_mse.values), + "extreme_high_mse": float(extreme_high_mse.values[idx]) + if extreme_high_mse.ndim > 0 + else float(extreme_high_mse.values), } qq_data_coord_array[idx] = json.dumps(qq_full_data) - + qq_data_coord = xr.DataArray( qq_data_coord_array, dims=overall_qq_score.dims, coords={dim: overall_qq_score.coords[dim] for dim in overall_qq_score.dims}, - attrs={"description": "Full Q-Q analysis data for plotting"} + attrs={"description": "Full Q-Q analysis data for plotting"}, ) overall_qq_score = overall_qq_score.assign_coords({"qq_full_data": qq_data_coord}) diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index c19efa6c5..2d43bb191 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -185,14 +185,14 @@ def calc_scores_per_stream( if not valid_scores: continue - # Concatenate all metrics; use "different" to allow metric-specific coordinates like qq_full_data + # Concatenate all metrics; use "different" for metric-specific coords combined_metrics = xr.concat( valid_scores, dim="metric", coords="different", combine_attrs="no_conflicts", ) - + combined_metrics = combined_metrics.assign_coords(metric=valid_metric_names) combined_metrics = combined_metrics.compute() @@ -209,20 +209,26 @@ def calc_scores_per_stream( criteria["ens"] = combined_metrics.ens.values metric_stream.loc[criteria] = combined_metrics - + # Preserve metric-specific coordinates that loc assignment may drop if "qq_full_data" in combined_metrics.coords: if "qq_full_data" not in metric_stream.coords: # Initialize coordinate array on first encounter - metric_stream = metric_stream.assign_coords({ - "qq_full_data": xr.DataArray( - np.full(metric_stream.shape, "", dtype=object), - dims=metric_stream.dims, - coords={dim: metric_stream.coords[dim] for dim in metric_stream.dims} - ) - }) + metric_stream = metric_stream.assign_coords( + { + "qq_full_data": xr.DataArray( + np.full(metric_stream.shape, "", dtype=object), + dims=metric_stream.dims, + coords={ + dim: metric_stream.coords[dim] for dim in metric_stream.dims + }, + ) + } + ) # Copy coordinate values for this slice - metric_stream.coords["qq_full_data"].loc[criteria] = combined_metrics.coords["qq_full_data"] + metric_stream.coords["qq_full_data"].loc[criteria] = combined_metrics.coords[ + "qq_full_data" + ] lead_time_map[fstep] = ( np.unique(combined_metrics.lead_time.values.astype("timedelta64[h]")) From 43a5bece3870c5c6c44f43bfafbae22b1d24e1e0 Mon Sep 17 00:00:00 2001 From: "Dr. 
Jehangir Awan" Date: Thu, 29 Jan 2026 17:53:26 +0100 Subject: [PATCH 3/5] Address PR review: refactor qq_analysis for clarity and extensibility - Rename methods, add helper functions, fix naming conventions - Make percentile thresholds configurable - Create QuantilePlots class, implement generic coordinate handling - All reviewer comments addressed --- .../evaluate/plotting/plot_utils.py | 96 ++++++------------- .../weathergen/evaluate/plotting/plotter.py | 76 ++++++++++++++- .../src/weathergen/evaluate/scores/score.py | 58 +++++------ .../weathergen/evaluate/scores/score_utils.py | 38 ++++++++ .../src/weathergen/evaluate/utils/utils.py | 53 +++++----- 5 files changed, 192 insertions(+), 129 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index f9a4f1e53..29a8964ce 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import json import logging from collections.abc import Iterable, Sequence @@ -16,6 +17,24 @@ _logger = logging.getLogger(__name__) +def _flatten_or_average(arr: np.ndarray) -> np.ndarray: + """Flatten array or average across non-quantile dimensions. + + Parameters + ---------- + arr : np.ndarray + Input array, possibly multi-dimensional. + + Returns + ------- + np.ndarray + Flattened 1D array, averaged across extra dimensions if needed. + """ + if arr.ndim > 1: + return np.mean(arr, axis=tuple(range(1, arr.ndim))).flatten() + return arr + + def collect_streams(runs: dict): """Get all unique streams across runs, sorted. 
@@ -409,7 +428,7 @@ def list_streams(cls): return list(cls._marker_size_stream.keys()) -def qq_plot_metric_region( +def quantile_plot_metric_region( metric: str, region: str, runs: dict, @@ -475,75 +494,16 @@ def qq_plot_metric_region( qq_data = json.loads(qq_data_str) - # Convert lists back to numpy arrays for plotting - quantile_levels = np.array(qq_data["quantile_levels"]) - p_quantiles = np.array(qq_data["p_quantiles"]) - gt_quantiles = np.array(qq_data["gt_quantiles"]) - qq_deviation = np.array(qq_data["qq_deviation"]) - qq_deviation_normalized = np.array( - qq_data.get("qq_deviation_normalized", qq_deviation) - ) - extreme_low_mse = np.array(qq_data.get("extreme_low_mse", 0.0)) - extreme_high_mse = np.array(qq_data.get("extreme_high_mse", 0.0)) - - # All quantile data is replicated across sample/ens dimensions - # Take the first sample's data (they're all identical) - quantile_levels_1d = ( - quantile_levels.flatten() - if quantile_levels.ndim > 1 - else quantile_levels - ) - - # For p_quantiles and gt_quantiles, average across samples - # Shape: (n_quantiles, n_samples, n_ens) - # Average across samples (axis 1) and ens (axis 2) - if p_quantiles.ndim > 1: - # Average across all axes except the first (quantile axis) - p_quantiles_1d = np.mean( - p_quantiles, axis=tuple(range(1, p_quantiles.ndim)) - ).flatten() - else: - p_quantiles_1d = p_quantiles - - if gt_quantiles.ndim > 1: - gt_quantiles_1d = np.mean( - gt_quantiles, axis=tuple(range(1, gt_quantiles.ndim)) - ).flatten() - else: - gt_quantiles_1d = gt_quantiles - - if qq_deviation.ndim > 1: - qq_deviation_1d = np.mean( - qq_deviation, axis=tuple(range(1, qq_deviation.ndim)) - ).flatten() - else: - qq_deviation_1d = qq_deviation - - if qq_deviation_normalized.ndim > 1: - qq_deviation_normalized_1d = np.mean( - qq_deviation_normalized, - axis=tuple(range(1, qq_deviation_normalized.ndim)), - ).flatten() - else: - qq_deviation_normalized_1d = qq_deviation_normalized - - # Handle scalar MSE values - extreme_low_mse_scalar = float(np.mean(extreme_low_mse)) - extreme_high_mse_scalar = float(np.mean(extreme_high_mse)) - - # Create Dataset with proper dimensions + # Convert and process all arrays qq_dataset = xr.Dataset( { - "quantile_levels": (["quantile"], quantile_levels_1d), - "p_quantiles": (["quantile"], p_quantiles_1d), - "gt_quantiles": (["quantile"], gt_quantiles_1d), - "qq_deviation": (["quantile"], qq_deviation_1d), - "qq_deviation_normalized": ( - ["quantile"], - qq_deviation_normalized_1d, - ), - "extreme_low_mse": ([], extreme_low_mse_scalar), - "extreme_high_mse": ([], extreme_high_mse_scalar), + "quantile_levels": (["quantile"], _flatten_or_average(np.array(qq_data["quantile_levels"]))), + "p_quantiles": (["quantile"], _flatten_or_average(np.array(qq_data["p_quantiles"]))), + "gt_quantiles": (["quantile"], _flatten_or_average(np.array(qq_data["gt_quantiles"]))), + "qq_deviation": (["quantile"], _flatten_or_average(np.array(qq_data["qq_deviation"]))), + "qq_deviation_normalized": (["quantile"], _flatten_or_average(np.array(qq_data.get("qq_deviation_normalized", qq_data["qq_deviation"])))), + "extreme_low_mse": ([], float(np.mean(qq_data.get("extreme_low_mse", 0.0)))), + "extreme_high_mse": ([], float(np.mean(qq_data.get("extreme_high_mse", 0.0)))), } ) qq_full_data.append(qq_dataset) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 47c34fe36..de51a5ef8 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ 
b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1137,12 +1137,74 @@ def heat_map( name = "_".join(filter(None, parts)) plt.savefig(f"{self.out_plot_dir.joinpath(name)}.{self.image_format}") + +class QuantilePlots: + + def __init__(self, plotter_cfg: dict, output_basedir: str | Path): + """ + Initialize the QuantilePlots class. + + Parameters + ---------- + plotter_cfg: + Configuration dictionary containing basic information for plotting. + Expected keys are: + - image_format: Format of the saved images (e.g., 'png', 'pdf', etc.) + - dpi_val: DPI value for the saved images + - fig_size: Size of the figure (width, height) in inches + output_basedir: + Base directory under which the plots will be saved. + Expected scheme `/`. + """ + self.image_format = plotter_cfg.get("image_format") + self.dpi_val = plotter_cfg.get("dpi_val") + self.fig_size = plotter_cfg.get("fig_size") + self.out_plot_dir = Path(output_basedir) / "quantile_plots" + + if not os.path.exists(self.out_plot_dir): + _logger.info(f"Creating dir {self.out_plot_dir}") + os.makedirs(self.out_plot_dir, exist_ok=True) + + _logger.info(f"Saving quantile plots to: {self.out_plot_dir}") + + def _check_lengths(self, data: xr.DataArray | list, labels: str | list) -> tuple[list, list]: + """ + Check if the lengths of data and labels match. + + Parameters + ---------- + data: + DataArray or list of DataArrays to be plotted + labels: + Label or list of labels for each dataset + + Returns + ------- + data_list, label_list - lists of data and labels + """ + assert isinstance(data, xr.DataArray | list), ( + "QuantilePlots::_check_lengths - Data should be of type xr.DataArray or list" + ) + assert isinstance(labels, str | list), ( + "QuantilePlots::_check_lengths - Labels should be of type str or list" + ) + + data_list = [data] if isinstance(data, xr.DataArray) else data + label_list = [labels] if isinstance(labels, str) else labels + + assert len(data_list) == len(label_list), ( + "QuantilePlots::_check_lengths - Data and Labels do not match" + ) + + return data_list, label_list + def qq_plot( self, qq_data: list[xr.Dataset], labels: str | list, tag: str = "", metric: str = "qq_analysis", + extreme_percentiles: tuple[float, float] = (5.0, 95.0), ) -> None: """ Create quantile-quantile (Q-Q) plots for extreme value analysis. @@ -1167,6 +1229,8 @@ def qq_plot( Tag to be added to the plot title and filename metric: Name of the metric (default: 'qq_analysis') + extreme_percentiles: + Lower and upper percentile thresholds for extreme regions. 
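
Note: a minimal usage sketch for the class and method defined here, with a
synthetic dataset shaped as the docstring describes (config values and the run
label are illustrative; assumes the patched weathergen package is importable):

    import numpy as np
    import xarray as xr

    from weathergen.evaluate.plotting.plotter import QuantilePlots

    n = 100
    ds = xr.Dataset(
        {
            "quantile_levels": (["quantile"], np.linspace(0.001, 0.999, n)),
            "gt_quantiles": (["quantile"], np.sort(np.random.randn(n))),
            "p_quantiles": (["quantile"], np.sort(1.1 * np.random.randn(n))),
            "qq_deviation": (["quantile"], np.random.rand(n)),
        }
    )
    cfg = {"image_format": "png", "dpi_val": 100, "fig_size": (16, 6)}
    plotter = QuantilePlots(cfg, "plots")  # writes under plots/quantile_plots/
    plotter.qq_plot([ds], ["toy run"], tag="demo")
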
Returns ------- @@ -1233,9 +1297,9 @@ def qq_plot( ds_ref = data_list[0] quantile_levels = ds_ref["quantile_levels"].values - # Find extreme regions (typically below 5% and above 95%) - lower_extreme_idx = quantile_levels < 0.05 - upper_extreme_idx = quantile_levels > 0.95 + # Find extreme regions using configurable thresholds + lower_extreme_idx = quantile_levels < (extreme_percentiles[0] / 100) + upper_extreme_idx = quantile_levels > (extreme_percentiles[1] / 100) if np.any(lower_extreme_idx): lower_q = ds_ref["gt_quantiles"].values[lower_extreme_idx] @@ -1268,8 +1332,10 @@ def qq_plot( ax_dev.grid(True, linestyle="--", alpha=0.3) # Highlight extreme regions in deviation plot - ax_dev.axvspan(0.0, 0.05, alpha=0.1, color="blue") - ax_dev.axvspan(0.95, 1.0, alpha=0.1, color="red") + lower_threshold = extreme_percentiles[0] / 100 + upper_threshold = extreme_percentiles[1] / 100 + ax_dev.axvspan(0.0, lower_threshold, alpha=0.1, color="blue") + ax_dev.axvspan(upper_threshold, 1.0, alpha=0.1, color="red") plt.tight_layout() diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index d13464cef..46aac0cb9 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. import inspect +import json import logging from dataclasses import dataclass @@ -16,7 +17,7 @@ import xarray as xr from scipy.spatial import cKDTree -from weathergen.evaluate.scores.score_utils import to_list +from weathergen.evaluate.scores.score_utils import arr_to_float, arr_to_list, to_list # from common.io import MockIO @@ -190,7 +191,7 @@ def __init__( "grad_amplitude": self.calc_spatial_variability, "psnr": self.calc_psnr, "seeps": self.calc_seeps, - "qq_analysis": self.calc_qq_analysis, + "qq_analysis": self.calc_quantiles, } self.prob_metrics_dict = { "ssr": self.calc_ssr, @@ -1497,7 +1498,7 @@ def check_for_coords(coord_names_data, coord_names_expected): return var_diff_amplitude - def calc_qq_analysis( + def calc_quantiles( self, p: xr.DataArray, gt: xr.DataArray, @@ -1505,6 +1506,7 @@ def calc_qq_analysis( quantile_method: str = "linear", focus_extremes: bool = True, extreme_percentiles: tuple[float, float] = (5.0, 95.0), + iqr_percentiles: tuple[float, float] = (25.0, 75.0), ) -> xr.DataArray: """ Calculate quantile-quantile (Q-Q) analysis metric for extreme value evaluation. @@ -1535,6 +1537,10 @@ def calc_qq_analysis( Lower and upper percentile thresholds to define extremes. Default is (5.0, 95.0), meaning values below 5th and above 95th percentile are considered extremes. + iqr_percentiles: tuple[float, float] + Lower and upper percentile thresholds for interquartile range (IQR) calculation. + Default is (25.0, 75.0), meaning IQR is computed as difference between 75th and + 25th percentiles. 
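
Note: normalising by the ground-truth IQR makes deviations comparable across
channels with different units. In miniature (synthetic numbers; the patch
guards near-zero IQRs by substituting 1.0 via `iqr.where(...)` rather than the
`max()` used here):

    import numpy as np

    gt = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
    iqr = np.quantile(gt, 0.75) - np.quantile(gt, 0.25)  # 3.0 - 1.0 = 2.0
    qq_deviation = np.array([0.1, 0.4, 1.0])
    print(qq_deviation / max(iqr, 1e-10))                # [0.05 0.2  0.5 ]
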
Returns ------- @@ -1575,10 +1581,10 @@ def calc_qq_analysis( p_flat = p.stack({"_agg_points": self._agg_dims}) gt_flat = gt.stack({"_agg_points": self._agg_dims}) else: - # If all dimensions are aggregated, add a dummy dimension - p_flat = p.stack({"_agg_points": self._agg_dims}).expand_dims("dummy") - gt_flat = gt.stack({"_agg_points": self._agg_dims}).expand_dims("dummy") - keep_dims = ["dummy"] + # If all dimensions are aggregated, add a scalar dimension + p_flat = p.stack({"_agg_points": self._agg_dims}).expand_dims("scalar") + gt_flat = gt.stack({"_agg_points": self._agg_dims}).expand_dims("scalar") + keep_dims = ["scalar"] # Remove NaN values before quantile calculation p_flat = p_flat.dropna(dim="_agg_points", how="all") @@ -1592,9 +1598,9 @@ def calc_qq_analysis( qq_deviation = np.abs(p_quantiles - gt_quantiles) # Calculate normalized deviation (relative to interquartile range of ground truth) - gt_q25 = gt_flat.quantile(0.25, dim="_agg_points") - gt_q75 = gt_flat.quantile(0.75, dim="_agg_points") - iqr = gt_q75 - gt_q25 + gt_q_low = gt_flat.quantile(iqr_percentiles[0] / 100, dim="_agg_points") + gt_q_high = gt_flat.quantile(iqr_percentiles[1] / 100, dim="_agg_points") + iqr = gt_q_high - gt_q_low # Avoid division by zero iqr = iqr.where(iqr > 1e-10, 1.0) @@ -1618,8 +1624,6 @@ def calc_qq_analysis( # Store Q-Q data as a non-scalar coordinate with matching dimensions # This ensures each channel/sample/ens combination gets its own Q-Q data - import json - # Create separate JSON strings for each position in the multi-dimensional array qq_data_coord_array = np.empty(overall_qq_score.shape, dtype=object) @@ -1627,24 +1631,12 @@ def calc_qq_analysis( for idx in np.ndindex(overall_qq_score.shape): qq_full_data = { "quantile_levels": quantile_levels.tolist(), - "p_quantiles": p_quantiles.values[(...,) + idx].tolist() - if p_quantiles.ndim > 1 - else p_quantiles.values.tolist(), - "gt_quantiles": gt_quantiles.values[(...,) + idx].tolist() - if gt_quantiles.ndim > 1 - else gt_quantiles.values.tolist(), - "qq_deviation": qq_deviation.values[(...,) + idx].tolist() - if qq_deviation.ndim > 1 - else qq_deviation.values.tolist(), - "qq_deviation_normalized": qq_deviation_normalized.values[(...,) + idx].tolist() - if qq_deviation_normalized.ndim > 1 - else qq_deviation_normalized.values.tolist(), - "extreme_low_mse": float(extreme_low_mse.values[idx]) - if extreme_low_mse.ndim > 0 - else float(extreme_low_mse.values), - "extreme_high_mse": float(extreme_high_mse.values[idx]) - if extreme_high_mse.ndim > 0 - else float(extreme_high_mse.values), + "p_quantiles": arr_to_list(p_quantiles, idx), + "gt_quantiles": arr_to_list(gt_quantiles, idx), + "qq_deviation": arr_to_list(qq_deviation, idx), + "qq_deviation_normalized": arr_to_list(qq_deviation_normalized, idx), + "extreme_low_mse": arr_to_float(extreme_low_mse, idx), + "extreme_high_mse": arr_to_float(extreme_high_mse, idx), } qq_data_coord_array[idx] = json.dumps(qq_full_data) @@ -1656,8 +1648,8 @@ def calc_qq_analysis( ) overall_qq_score = overall_qq_score.assign_coords({"qq_full_data": qq_data_coord}) - # Remove dummy dimension if it was added - if "dummy" in overall_qq_score.dims: - overall_qq_score = overall_qq_score.squeeze("dummy", drop=True) + # Remove scalar dimension if it was added + if "scalar" in overall_qq_score.dims: + overall_qq_score = overall_qq_score.squeeze("scalar", drop=True) return overall_qq_score diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score_utils.py 
b/packages/evaluate/src/weathergen/evaluate/scores/score_utils.py index 4ad793d07..9d9f50451 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/score_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score_utils.py @@ -36,3 +36,41 @@ def to_list(obj: Any) -> list: elif not isinstance(obj, list): obj = [obj] return obj + + +def arr_to_list(arr, idx): + """ + Convert array values at specific index to a Python list for JSON serialization. + + Parameters + ---------- + arr : xr.DataArray or np.ndarray + Array containing values to convert. + idx : tuple + Index tuple for accessing multi-dimensional array. + + Returns + ------- + list + Python list of values at the specified index. + """ + return arr.values[(...,) + idx].tolist() if arr.ndim > 1 else arr.values.tolist() + + +def arr_to_float(arr, idx): + """ + Convert array value at specific index to a Python float for JSON serialization. + + Parameters + ---------- + arr : xr.DataArray or np.ndarray + Array containing the value to convert. + idx : tuple + Index for accessing the array. + + Returns + ------- + float + Python float value at the specified index. + """ + return float(arr.values[idx]) if arr.ndim > 0 else float(arr.values) diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index 2d43bb191..103f872c2 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -25,11 +25,11 @@ bar_plot_metric_region, heat_maps_metric_region, plot_metric_region, - qq_plot_metric_region, + quantile_plot_metric_region, ratio_plot_metric_region, score_card_metric_region, ) -from weathergen.evaluate.plotting.plotter import BarPlots, LinePlots, Plotter, ScoreCards +from weathergen.evaluate.plotting.plotter import BarPlots, LinePlots, Plotter, QuantilePlots, ScoreCards from weathergen.evaluate.scores.score import VerifiedData, get_score from weathergen.evaluate.utils.clim_utils import get_climatology from weathergen.evaluate.utils.regions import RegionBoundingBox @@ -185,7 +185,7 @@ def calc_scores_per_stream( if not valid_scores: continue - # Concatenate all metrics; use "different" for metric-specific coords + # Concatenate all metrics using "different" to allow metric-specific coords combined_metrics = xr.concat( valid_scores, dim="metric", @@ -210,25 +210,31 @@ def calc_scores_per_stream( metric_stream.loc[criteria] = combined_metrics - # Preserve metric-specific coordinates that loc assignment may drop - if "qq_full_data" in combined_metrics.coords: - if "qq_full_data" not in metric_stream.coords: - # Initialize coordinate array on first encounter - metric_stream = metric_stream.assign_coords( - { - "qq_full_data": xr.DataArray( - np.full(metric_stream.shape, "", dtype=object), - dims=metric_stream.dims, - coords={ - dim: metric_stream.coords[dim] for dim in metric_stream.dims - }, - ) - } - ) - # Copy coordinate values for this slice - metric_stream.coords["qq_full_data"].loc[criteria] = combined_metrics.coords[ - "qq_full_data" - ] + # Preserve non-dimensional coordinates that .loc assignment drops + for coord in combined_metrics.coords: + # Skip dimensional coordinates and index variables + if coord in combined_metrics.dims or coord in metric_stream.dims: + continue + + # Check if coordinate dimensions are compatible with metric_stream + coord_dims = combined_metrics.coords[coord].dims + if not all(dim in metric_stream.dims for dim in coord_dims): + continue + + # Initialize 
new coordinate with proper dimensions if needed + if coord not in metric_stream.coords: + coord_shape = tuple(len(metric_stream.coords[dim]) for dim in coord_dims) + metric_stream = metric_stream.assign_coords({ + coord: xr.DataArray( + np.full(coord_shape, "", dtype=object), + dims=coord_dims, + coords={dim: metric_stream.coords[dim] for dim in coord_dims}, + ) + }) + + # Assign coordinate values using filtered criteria + coord_criteria = {dim: criteria[dim] for dim in coord_dims if dim in criteria} + metric_stream.coords[coord].loc[coord_criteria] = combined_metrics.coords[coord] lead_time_map[fstep] = ( np.unique(combined_metrics.lead_time.values.astype("timedelta64[h]")) @@ -562,11 +568,12 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): plotter = LinePlots(plot_cfg, summary_dir) sc_plotter = ScoreCards(plot_cfg, summary_dir) br_plotter = BarPlots(plot_cfg, summary_dir) + quantile_plotter = QuantilePlots(plot_cfg, summary_dir) for region in regions: for metric in metrics: # Special handling for Q-Q analysis metric if metric == "qq_analysis": - qq_plot_metric_region(metric, region, runs, scores_dict, plotter) + quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter) elif eval_opt.get("summary_plots", True): plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) if eval_opt.get("ratio_plots", False): From 86161a5689fa71a72c43612971c0e62a8f1082fe Mon Sep 17 00:00:00 2001 From: "Dr. Jehangir Awan" Date: Thu, 29 Jan 2026 19:33:59 +0100 Subject: [PATCH 4/5] Fix: remove duplicate inline import in plot_utils.py --- .../evaluate/plotting/plot_utils.py | 52 +++++++++++------ .../weathergen/evaluate/plotting/plotter.py | 3 +- .../src/weathergen/evaluate/utils/utils.py | 56 ++++++++++++------- 3 files changed, 72 insertions(+), 39 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 29a8964ce..4a5b2240c 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -13,18 +13,19 @@ import numpy as np import xarray as xr +from numpy.typing import NDArray _logger = logging.getLogger(__name__) -def _flatten_or_average(arr: np.ndarray) -> np.ndarray: +def _flatten_or_average(arr: NDArray) -> NDArray: """Flatten array or average across non-quantile dimensions. - + Parameters ---------- arr : np.ndarray Input array, possibly multi-dimensional. 
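
Note: on the coordinate-restoration loop in utils.py above — `.loc[...] = other`
copies data values but not the non-dimensional coordinates of `other`, which is
why `qq_full_data` has to be carried over explicitly. A sketch of the failure
mode and the fix, with hypothetical names (`payload` stands in for
`qq_full_data`):

    import numpy as np
    import xarray as xr

    target = xr.DataArray(
        np.zeros(3), dims=["metric"], coords={"metric": ["a", "b", "c"]}
    )
    src = xr.DataArray(
        [1.0],
        dims=["metric"],
        coords={
            "metric": ["b"],
            "payload": ("metric", np.array(["{...}"], dtype=object)),
        },
    )

    target.loc[{"metric": ["b"]}] = src
    print("payload" in target.coords)  # False: non-dim coord was dropped

    # Manual restoration, mirroring calc_scores_per_stream:
    target = target.assign_coords(payload=("metric", np.full(3, "", dtype=object)))
    target.coords["payload"].loc[{"metric": ["b"]}] = src.coords["payload"]
    print(target.coords["payload"].values)  # ['' '{...}' '']
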
- + Returns ------- np.ndarray @@ -475,8 +476,6 @@ def quantile_plot_metric_region( # For qq_analysis, extract full Q-Q data from coordinate if metric == "qq_analysis": if "qq_full_data" in data_for_channel.coords: - import json - # Get the qq_full_data - it may be 0-d (scalar) or multi-d qq_coord = data_for_channel.coords["qq_full_data"] @@ -494,18 +493,39 @@ def quantile_plot_metric_region( qq_data = json.loads(qq_data_str) - # Convert and process all arrays - qq_dataset = xr.Dataset( - { - "quantile_levels": (["quantile"], _flatten_or_average(np.array(qq_data["quantile_levels"]))), - "p_quantiles": (["quantile"], _flatten_or_average(np.array(qq_data["p_quantiles"]))), - "gt_quantiles": (["quantile"], _flatten_or_average(np.array(qq_data["gt_quantiles"]))), - "qq_deviation": (["quantile"], _flatten_or_average(np.array(qq_data["qq_deviation"]))), - "qq_deviation_normalized": (["quantile"], _flatten_or_average(np.array(qq_data.get("qq_deviation_normalized", qq_data["qq_deviation"])))), - "extreme_low_mse": ([], float(np.mean(qq_data.get("extreme_low_mse", 0.0)))), - "extreme_high_mse": ([], float(np.mean(qq_data.get("extreme_high_mse", 0.0)))), - } + # Build dataset with quantile-based and scalar variables + quantile_vars = [ + "quantile_levels", + "p_quantiles", + "gt_quantiles", + "qq_deviation", + ] + dataset_dict = { + var: (["quantile"], _flatten_or_average(np.array(qq_data[var]))) + for var in quantile_vars + } + + # Add normalized deviation (with fallback) + dataset_dict["qq_deviation_normalized"] = ( + ["quantile"], + _flatten_or_average( + np.array( + qq_data.get("qq_deviation_normalized", qq_data["qq_deviation"]) + ) + ), ) + + # Add scalar MSE values + dataset_dict["extreme_low_mse"] = ( + [], + float(np.mean(qq_data.get("extreme_low_mse", 0.0))), + ) + dataset_dict["extreme_high_mse"] = ( + [], + float(np.mean(qq_data.get("extreme_high_mse", 0.0))), + ) + + qq_dataset = xr.Dataset(dataset_dict) qq_full_data.append(qq_dataset) else: _logger.warning("qq_full_data not found in coords for qq_analysis metric") diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index de51a5ef8..4f2244182 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1139,7 +1139,6 @@ def heat_map( class QuantilePlots: - def __init__(self, plotter_cfg: dict, output_basedir: str | Path): """ Initialize the QuantilePlots class. 
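Note on _flatten_or_average: only its signature and docstring appear in the hunks above, not its body. A minimal sketch of a body consistent with that docstring, assuming the trailing axis holds the quantile levels (an illustration, not the committed implementation):

import numpy as np
from numpy.typing import NDArray

def _flatten_or_average(arr: NDArray) -> NDArray:
    # 1-D (or 0-d) input: nothing to reduce, return it flattened.
    if arr.ndim <= 1:
        return arr.ravel()
    # Multi-dimensional input: average away the leading sample/ensemble
    # axes, keeping only the trailing quantile axis.
    return arr.mean(axis=tuple(range(arr.ndim - 1)))
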
@@ -1160,7 +1159,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path): self.dpi_val = plotter_cfg.get("dpi_val") self.fig_size = plotter_cfg.get("fig_size") self.out_plot_dir = Path(output_basedir) / "quantile_plots" - + if not os.path.exists(self.out_plot_dir): _logger.info(f"Creating dir {self.out_plot_dir}") os.makedirs(self.out_plot_dir, exist_ok=True) diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index 103f872c2..5bacdb87e 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -29,7 +29,13 @@ ratio_plot_metric_region, score_card_metric_region, ) -from weathergen.evaluate.plotting.plotter import BarPlots, LinePlots, Plotter, QuantilePlots, ScoreCards +from weathergen.evaluate.plotting.plotter import ( + BarPlots, + LinePlots, + Plotter, + QuantilePlots, + ScoreCards, +) from weathergen.evaluate.scores.score import VerifiedData, get_score from weathergen.evaluate.utils.clim_utils import get_climatology from weathergen.evaluate.utils.regions import RegionBoundingBox @@ -185,11 +191,11 @@ def calc_scores_per_stream( if not valid_scores: continue - # Concatenate all metrics using "different" to allow metric-specific coords + # Concatenate all metrics using "minimal" to handle metrics with different coords combined_metrics = xr.concat( valid_scores, dim="metric", - coords="different", + coords="minimal", combine_attrs="no_conflicts", ) @@ -210,31 +216,39 @@ def calc_scores_per_stream( metric_stream.loc[criteria] = combined_metrics - # Preserve non-dimensional coordinates that .loc assignment drops - for coord in combined_metrics.coords: - # Skip dimensional coordinates and index variables - if coord in combined_metrics.dims or coord in metric_stream.dims: + # Restore metric-specific coordinates that were dropped by coords="minimal" + # (e.g., quantiles, extreme_percentiles for qq_analysis) + for coord_name in combined_metrics.coords: + # Skip coordinates that are already dimensions (no need to restore) + if coord_name in combined_metrics.dims or coord_name in metric_stream.dims: continue - # Check if coordinate dimensions are compatible with metric_stream - coord_dims = combined_metrics.coords[coord].dims + # Only restore coordinates whose dimensions exist in metric_stream + # (e.g., skip coords with 'quantile' dim if metric_stream doesn't have it) + coord_dims = combined_metrics.coords[coord_name].dims if not all(dim in metric_stream.dims for dim in coord_dims): + _logger.debug( + f"Skipping coordinate '{coord_name}' with incompatible " + f"dimensions {coord_dims} (metric_stream has {metric_stream.dims})" + ) continue - # Initialize new coordinate with proper dimensions if needed - if coord not in metric_stream.coords: + # Initialize coordinate in metric_stream if it doesn't exist yet + if coord_name not in metric_stream.coords: coord_shape = tuple(len(metric_stream.coords[dim]) for dim in coord_dims) - metric_stream = metric_stream.assign_coords({ - coord: xr.DataArray( - np.full(coord_shape, "", dtype=object), - dims=coord_dims, - coords={dim: metric_stream.coords[dim] for dim in coord_dims}, - ) - }) + metric_stream = metric_stream.assign_coords( + { + coord_name: xr.DataArray( + np.full(coord_shape, "", dtype=object), + dims=coord_dims, + coords={dim: metric_stream.coords[dim] for dim in coord_dims}, + ) + } + ) - # Assign coordinate values using filtered criteria - coord_criteria = {dim: criteria[dim] for 
dim in coord_dims if dim in criteria} - metric_stream.coords[coord].loc[coord_criteria] = combined_metrics.coords[coord] + # Build indexers to select the right location in metric_stream + indexers = {dim: criteria[dim] for dim in coord_dims if dim in criteria} + metric_stream.coords[coord_name].loc[indexers] = combined_metrics.coords[coord_name] lead_time_map[fstep] = ( np.unique(combined_metrics.lead_time.values.astype("timedelta64[h]")) From dcdc75b02487b30fd4844055fe8986da098045b6 Mon Sep 17 00:00:00 2001 From: "Dr. Jehangir Awan" Date: Wed, 4 Feb 2026 23:12:36 +0100 Subject: [PATCH 5/5] Remove JSON serialization bottleneck in Q-Q analysis, use xarray attributes --- .../evaluate/plotting/plot_utils.py | 103 ++++++------------ .../weathergen/evaluate/plotting/plotter.py | 6 +- .../src/weathergen/evaluate/scores/score.py | 61 ++++------- .../weathergen/evaluate/scores/score_utils.py | 38 ------- .../src/weathergen/evaluate/utils/utils.py | 45 +++++--- 5 files changed, 92 insertions(+), 161 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 4a5b2240c..5f02b3e05 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -7,7 +7,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import json import logging from collections.abc import Iterable, Sequence @@ -473,62 +472,35 @@ def quantile_plot_metric_region( if data_for_channel.isnull().all(): continue - # For qq_analysis, extract full Q-Q data from coordinate - if metric == "qq_analysis": - if "qq_full_data" in data_for_channel.coords: - # Get the qq_full_data - it may be 0-d (scalar) or multi-d - qq_coord = data_for_channel.coords["qq_full_data"] - - # Extract the first non-empty JSON string - if qq_coord.ndim == 0: - qq_data_str = qq_coord.item() - else: - # Flatten and get first non-empty value - qq_values = qq_coord.values.flatten() - qq_data_str = next((v for v in qq_values if v and v != ""), None) - - if not qq_data_str: - _logger.warning(f"Empty qq_full_data for channel {ch}") - continue - - qq_data = json.loads(qq_data_str) - - # Build dataset with quantile-based and scalar variables - quantile_vars = [ - "quantile_levels", - "p_quantiles", - "gt_quantiles", - "qq_deviation", - ] - dataset_dict = { - var: (["quantile"], _flatten_or_average(np.array(qq_data[var]))) - for var in quantile_vars - } - - # Add normalized deviation (with fallback) - dataset_dict["qq_deviation_normalized"] = ( - ["quantile"], - _flatten_or_average( - np.array( - qq_data.get("qq_deviation_normalized", qq_data["qq_deviation"]) - ) + # For qq_analysis, extract Q-Q data from attributes + if metric == "qq_analysis" and "p_quantiles" in data_for_channel.attrs: + attrs = data_for_channel.attrs + # Convert to numpy arrays once + p_quantiles_arr = np.array(attrs["p_quantiles"]) + gt_quantiles_arr = np.array(attrs["gt_quantiles"]) + qq_deviation_arr = np.array(attrs["qq_deviation"]) + qq_deviation_norm_arr = np.array(attrs["qq_deviation_normalized"]) + quantile_levels = np.array(attrs["quantile_levels"]) + extreme_low_mse = float(np.mean(np.array(attrs["extreme_low_mse"]))) + extreme_high_mse = float(np.mean(np.array(attrs["extreme_high_mse"]))) + + qq_dataset = xr.Dataset( + { + "quantile_levels": (["quantile"], quantile_levels), + "p_quantiles": (["quantile"], 
_flatten_or_average(p_quantiles_arr)), + "gt_quantiles": (["quantile"], _flatten_or_average(gt_quantiles_arr)), + "qq_deviation": (["quantile"], _flatten_or_average(qq_deviation_arr)), + "qq_deviation_normalized": ( + ["quantile"], + _flatten_or_average(qq_deviation_norm_arr), ), - ) - - # Add scalar MSE values - dataset_dict["extreme_low_mse"] = ( - [], - float(np.mean(qq_data.get("extreme_low_mse", 0.0))), - ) - dataset_dict["extreme_high_mse"] = ( - [], - float(np.mean(qq_data.get("extreme_high_mse", 0.0))), - ) - - qq_dataset = xr.Dataset(dataset_dict) - qq_full_data.append(qq_dataset) - else: - _logger.warning("qq_full_data not found in coords for qq_analysis metric") + "extreme_low_mse": ([], extreme_low_mse), + "extreme_high_mse": ([], extreme_high_mse), + } + ) + # Store extreme percentiles for plotting + qq_dataset.attrs["extreme_percentiles"] = tuple(attrs["extreme_percentiles"]) + qq_full_data.append(qq_dataset) selected_data.append(data_for_channel) labels.append(runs[run_id].get("label", run_id)) @@ -541,28 +513,23 @@ def quantile_plot_metric_region( prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream, ch] ) - # Check if plotter has qq_plot method + # Check if plotter has qq_plot method and Q-Q data is available if hasattr(plotter, "qq_plot") and qq_full_data: _logger.info(f"Creating Q-Q plot with {len(qq_full_data)} dataset(s).") + # Extract extreme_percentiles from dataset + extreme_pct = qq_full_data[0].attrs["extreme_percentiles"] plotter.qq_plot( qq_full_data, labels, tag=name, metric=metric, + extreme_percentiles=extreme_pct, ) else: - # Fallback to standard plot if qq_plot not available + # Skip plotting if no Q-Q data available _logger.warning( - f"Plotter does not have qq_plot method or no full Q-Q data available. " - f"Falling back to standard plot for {metric}." - ) - selected_data, time_dim = _assign_time_coord(selected_data) - plotter.plot( - selected_data, - labels, - tag=name, - x_dim=time_dim, - y_dim=metric, + f"Q-Q data not available for {metric} - {region} - {stream} - {ch}. " + f"Skipping plot generation." ) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 4f2244182..cc03f1edd 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1203,7 +1203,7 @@ def qq_plot( labels: str | list, tag: str = "", metric: str = "qq_analysis", - extreme_percentiles: tuple[float, float] = (5.0, 95.0), + extreme_percentiles: tuple[float, float] | None = None, ) -> None: """ Create quantile-quantile (Q-Q) plots for extreme value analysis. 
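The signature change above makes extreme_percentiles optional; the next hunk adds the in-method fallback. The intended precedence, shown as a standalone sketch (_resolve_extreme_percentiles is a hypothetical helper, not part of the patch): an explicit argument wins, then whatever calc_quantiles attached to the dataset attrs, then the previous hard-coded default.

import xarray as xr

def _resolve_extreme_percentiles(
    ds: xr.Dataset,
    extreme_percentiles: tuple[float, float] | None = None,
) -> tuple[float, float]:
    # An explicit caller choice takes priority.
    if extreme_percentiles is not None:
        return extreme_percentiles
    # Otherwise use the percentiles recorded by the score, falling back
    # to the old (5.0, 95.0) default when the attribute is absent.
    return tuple(ds.attrs.get("extreme_percentiles", (5.0, 95.0)))
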
@@ -1237,6 +1237,10 @@ def qq_plot( """ data_list, label_list = self._check_lengths(qq_data, labels) + # Use extreme_percentiles from data if not explicitly provided + if extreme_percentiles is None: + extreme_percentiles = tuple(data_list[0].attrs.get("extreme_percentiles", (5.0, 95.0))) + # Create figure with subplots fig = plt.figure(figsize=(16, 6), dpi=self.dpi_val) gs = fig.add_gridspec(1, 2, width_ratios=[2, 1], wspace=0.3) diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index 46aac0cb9..d4846b894 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -7,7 +7,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. import inspect -import json import logging from dataclasses import dataclass @@ -17,7 +16,7 @@ import xarray as xr from scipy.spatial import cKDTree -from weathergen.evaluate.scores.score_utils import arr_to_float, arr_to_list, to_list +from weathergen.evaluate.scores.score_utils import to_list # from common.io import MockIO @@ -243,6 +242,7 @@ def get_score( """ if score_name in self.det_metrics_dict.keys(): f = self.det_metrics_dict[score_name] + _logger.debug(f"Using deterministic metric: {score_name}") elif score_name in self.prob_metrics_dict.keys(): assert self.ens_dim in data.prediction.dims, ( f"Probablistic score {score_name} chosen, but ensemble dimension {self.ens_dim} " @@ -1505,7 +1505,7 @@ def calc_quantiles( n_quantiles: int = 100, quantile_method: str = "linear", focus_extremes: bool = True, - extreme_percentiles: tuple[float, float] = (5.0, 95.0), + extreme_percentiles: tuple[float, float] = (5.0, 95.0), # 5th and 95th percentiles iqr_percentiles: tuple[float, float] = (25.0, 75.0), ) -> xr.DataArray: """ @@ -1560,6 +1560,8 @@ def calc_quantiles( "Cannot calculate Q-Q analysis without aggregation dimensions (agg_dims=None)." 
) + _logger.info(f"Starting Q-Q analysis with {n_quantiles} quantiles") + # Generate quantile levels if focus_extremes: # Add more resolution in the tails for extreme value analysis @@ -1572,19 +1574,9 @@ def calc_quantiles( else: quantile_levels = np.linspace(0.001, 0.999, n_quantiles) - # Flatten the data across aggregation dimensions while preserving other dimensions - # Determine which dimensions to keep (not in agg_dims) - keep_dims = [dim for dim in p.dims if dim not in self._agg_dims] - - if keep_dims: - # Stack aggregation dimensions into a single dimension - p_flat = p.stack({"_agg_points": self._agg_dims}) - gt_flat = gt.stack({"_agg_points": self._agg_dims}) - else: - # If all dimensions are aggregated, add a scalar dimension - p_flat = p.stack({"_agg_points": self._agg_dims}).expand_dims("scalar") - gt_flat = gt.stack({"_agg_points": self._agg_dims}).expand_dims("scalar") - keep_dims = ["scalar"] + # Stack aggregation dimensions into a single dimension + p_flat = p.stack({"_agg_points": self._agg_dims}) + gt_flat = gt.stack({"_agg_points": self._agg_dims}) # Remove NaN values before quantile calculation p_flat = p_flat.dropna(dim="_agg_points", how="all") @@ -1622,34 +1614,21 @@ def calc_quantiles( # Calculate overall Q-Q score (mean absolute deviation across all quantiles) overall_qq_score = qq_deviation.mean(dim="quantile") - # Store Q-Q data as a non-scalar coordinate with matching dimensions - # This ensures each channel/sample/ens combination gets its own Q-Q data - # Create separate JSON strings for each position in the multi-dimensional array - qq_data_coord_array = np.empty(overall_qq_score.shape, dtype=object) - - # Iterate over all positions and create individual JSON strings - for idx in np.ndindex(overall_qq_score.shape): - qq_full_data = { + # Store Q-Q data as xarray attributes for automatic JSON serialization + overall_qq_score.attrs.update( + { + "p_quantiles": p_quantiles.values.tolist(), + "gt_quantiles": gt_quantiles.values.tolist(), + "qq_deviation": qq_deviation.values.tolist(), + "qq_deviation_normalized": qq_deviation_normalized.values.tolist(), + "extreme_low_mse": extreme_low_mse.values.tolist(), + "extreme_high_mse": extreme_high_mse.values.tolist(), "quantile_levels": quantile_levels.tolist(), - "p_quantiles": arr_to_list(p_quantiles, idx), - "gt_quantiles": arr_to_list(gt_quantiles, idx), - "qq_deviation": arr_to_list(qq_deviation, idx), - "qq_deviation_normalized": arr_to_list(qq_deviation_normalized, idx), - "extreme_low_mse": arr_to_float(extreme_low_mse, idx), - "extreme_high_mse": arr_to_float(extreme_high_mse, idx), + "extreme_percentiles": list(extreme_percentiles), + "iqr_percentiles": list(iqr_percentiles), } - qq_data_coord_array[idx] = json.dumps(qq_full_data) - - qq_data_coord = xr.DataArray( - qq_data_coord_array, - dims=overall_qq_score.dims, - coords={dim: overall_qq_score.coords[dim] for dim in overall_qq_score.dims}, - attrs={"description": "Full Q-Q analysis data for plotting"}, ) - overall_qq_score = overall_qq_score.assign_coords({"qq_full_data": qq_data_coord}) - # Remove scalar dimension if it was added - if "scalar" in overall_qq_score.dims: - overall_qq_score = overall_qq_score.squeeze("scalar", drop=True) + _logger.info(f"Q-Q analysis completed with {len(overall_qq_score.attrs)} attributes") return overall_qq_score diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score_utils.py b/packages/evaluate/src/weathergen/evaluate/scores/score_utils.py index 9d9f50451..4ad793d07 100644 --- 
a/packages/evaluate/src/weathergen/evaluate/scores/score_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score_utils.py @@ -36,41 +36,3 @@ def to_list(obj: Any) -> list: elif not isinstance(obj, list): obj = [obj] return obj - - -def arr_to_list(arr, idx): - """ - Convert array values at specific index to a Python list for JSON serialization. - - Parameters - ---------- - arr : xr.DataArray or np.ndarray - Array containing values to convert. - idx : tuple - Index tuple for accessing multi-dimensional array. - - Returns - ------- - list - Python list of values at the specified index. - """ - return arr.values[(...,) + idx].tolist() if arr.ndim > 1 else arr.values.tolist() - - -def arr_to_float(arr, idx): - """ - Convert array value at specific index to a Python float for JSON serialization. - - Parameters - ---------- - arr : xr.DataArray or np.ndarray - Array containing the value to convert. - idx : tuple - Index for accessing the array. - - Returns - ------- - float - Python float value at the specified index. - """ - return float(arr.values[idx]) if arr.ndim > 0 else float(arr.values) diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index 5bacdb87e..033e5b669 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -147,6 +147,9 @@ def calc_scores_per_stream( ) lead_time_map = {} + # Store metric-specific attributes that get lost during concat + # Key: (fstep, metric) -> attrs dict + all_metric_attrs = {} for (fstep, tars), (_, preds) in zip(da_tars.items(), da_preds.items(), strict=False): if preds.ipoint.size == 0: @@ -173,8 +176,9 @@ def calc_scores_per_stream( # Build up computation graphs for all metrics _logger.debug(f"Build computation graphs for metrics for stream {stream}...") - # Add it only if it is not None + # Calculate scores and filter out None values valid_scores = [] + valid_metric_names = [] for metric in metrics: score = get_score( @@ -182,21 +186,29 @@ def calc_scores_per_stream( ) if score is not None: valid_scores.append(score) + valid_metric_names.append(metric) + else: + _logger.warning(f"Metric {metric} returned None, skipping") - valid_metric_names = [ - metric - for metric, score in zip(metrics, valid_scores, strict=False) - if score is not None - ] if not valid_scores: continue # Concatenate all metrics using "minimal" to handle metrics with different coords + # Preserve attributes from individual metrics (e.g., Q-Q analysis data) + # Store attributes before concat as they may be lost + for metric, score in zip(valid_metric_names, valid_scores, strict=False): + if score.attrs: + # Store with key (fstep, metric) for later restoration + all_metric_attrs[(int(fstep), metric)] = score.attrs.copy() + _logger.debug( + f"Stored {len(score.attrs)} attrs for fstep={fstep}, metric={metric}" + ) + combined_metrics = xr.concat( valid_scores, dim="metric", coords="minimal", - combine_attrs="no_conflicts", + combine_attrs="drop_conflicts", ) combined_metrics = combined_metrics.assign_coords(metric=valid_metric_names) @@ -273,12 +285,22 @@ def calc_scores_per_stream( ) _logger.info(f"Scores for run {reader.run_id} - {stream} calculated successfully.") + _logger.debug(f"all_metric_attrs keys: {list(all_metric_attrs.keys())}") # Build local dictionary for this region for metric in metrics: + metric_data = metric_stream.sel({"metric": metric}) + # Restore metric-specific attributes from all forecast steps 
+ # Attributes are the same across forecast steps for a given metric + for (_stored_fstep, stored_metric), attrs in all_metric_attrs.items(): + if stored_metric == metric and attrs: + _logger.debug(f"Restoring {len(attrs)} attributes for {metric}") + metric_data.attrs.update(attrs) + break + local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault(stream, {})[ reader.run_id - ] = metric_stream.sel({"metric": metric}) + ] = metric_data return local_scores @@ -529,10 +551,7 @@ def metric_list_to_json( for metric, metric_stream in metrics_dict.items(): for region in regions: - metric_now = metric_stream[region][stream] - for run_id in metric_now.keys(): - metric_now = metric_now[run_id] - + for run_id, metric_data in metric_stream[region][stream].items(): # Match the expected filename pattern save_path = ( reader.metrics_dir @@ -541,7 +560,7 @@ def metric_list_to_json( _logger.info(f"Saving results to {save_path}") with open(save_path, "w") as f: - json.dump(metric_now.to_dict(), f, indent=4) + json.dump(metric_data.to_dict(), f, indent=4) _logger.info( f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} "
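
The rationale for storing the Q-Q payload as plain-list attributes rather than JSON-encoded coordinates is that DataArray.to_dict() already yields a JSON-serializable structure, so metric_list_to_json needs no custom encoding. A minimal round-trip check of that assumption (the variable names and shapes below are illustrative only):

import json

import numpy as np
import xarray as xr

score = xr.DataArray(
    np.zeros(3), dims=["channel"], coords={"channel": ["2t", "10u", "10v"]}
)
score.attrs.update(
    {
        "quantile_levels": np.linspace(0.001, 0.999, 5).tolist(),
        "extreme_percentiles": [5.0, 95.0],
    }
)

# to_dict() embeds attrs, so a plain json.dump round-trips them intact.
payload = json.dumps(score.to_dict())
restored = xr.DataArray.from_dict(json.loads(payload))
assert restored.attrs["extreme_percentiles"] == [5.0, 95.0]
assert len(restored.attrs["quantile_levels"]) == 5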