From 9f1e35e39728c7032860e8a9c628d54319305969 Mon Sep 17 00:00:00 2001 From: Sorcha Date: Thu, 5 Feb 2026 16:09:20 +0100 Subject: [PATCH 1/3] power spectra --- .../example_extras/power_spectra/__init__.py | 0 .../power_spectra/psd_config.yml | 18 + .../example_extras/power_spectra/psd_main.py | 257 ++++++++ .../example_extras/power_spectra/psd_plots.py | 605 ++++++++++++++++++ 4 files changed, 880 insertions(+) create mode 100644 packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/__init__.py create mode 100644 packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml create mode 100644 packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py create mode 100644 packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_plots.py diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/__init__.py b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml new file mode 100644 index 000000000..8b24bf3df --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml @@ -0,0 +1,18 @@ +variables: ['u'] # , 'dew_point_temperature'] +regions: ['ShortGlobe'] #, 'N-Mid-Lats', 'S-Mid-Lats', 'Tropics'] +pressure_levels: [250] +forecast_steps : [6, 12] # choose the forecast stpes (must be common to all the datasets) +output_dir: "~/WeatherGen/test_output" +prefix: "comparison_gn3gotvh" +# define prefix for images +comparisons: + wg_bhwuev7l: + netcdf_paths: ['/p/home/jusers/owens1/juwels/WeatherGen/output_nc_bhwuev7l/pred*'] + wg_gn3gotvh_prediction: + netcdf_paths: ['/p/home/jusers/owens1/juwels/WeatherGen/gn3gotvh/pred*'] + # if directory all paths chosen, or use wildcard, or specify single path + wg_target: + netcdf_paths: ['/p/home/jusers/owens1/juwels/WeatherGen/gn3gotvh/targ*'] + +## it will produce powerspectra and average across all the samples +# available then plot the fsteps as single plots \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py new file mode 100644 index 000000000..8d62ce4fd --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py @@ -0,0 +1,257 @@ +#!/usr/bin/env -S uv run +# /// script +# dependencies = [ +# "omegaconf", +# "weathergen-common" +# ] +# [tool.uv.sources] +# weathergen-common = { path = "../../../../../../common" } +# /// + +""" +Plots the power spectrum of the analysis increments +Adapted from Martin Willet's code for power spectra +for use with the WeatherGenerator model: + +uv run packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py +--run-id gn3gotvh --export-dir /p/home/jusers/owens1/juwels/WeatherGen/gn3gotvh + +OR + +uv run packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py +--config packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml + +""" +import argparse +import glob +import logging +import os +import sys +from pathlib import Path + +import psd_plots as psd_plots +from omegaconf import DictConfig, OmegaConf + +# Local application / package +from weathergen.common.config import _REPO_ROOT +from weathergen.common.logger import init_loggers + +_logger = logging.getLogger(__name__) + +def extract_filepaths(netcdf_paths: list): + """ + Extracts filepaths from a list of netcdf paths. + If a directory is given, all files in the directory are returned. + Parameters + ---------- + netcdf_paths: + List of netcdf paths + Returns + ------- + list: + List of filepaths + """ + if len(netcdf_paths) > 1: + # list of files + return netcdf_paths + else: + netcdf_path = netcdf_paths[0] + if os.path.isfile(netcdf_path): + return netcdf_paths + elif os.path.isdir(netcdf_path): + glob_path = netcdf_path + "/*" + else: + glob_path = netcdf_path + return glob.glob(glob_path) + + +def psd_from_config(cfg: dict): + """ + Main function that controls power spectra density plotting. + Parameters + ---------- + cfg: + Configuration input stored as dictionary + """ + diags = cfg.variables + regions = cfg.regions + plevels = cfg.pressure_levels + comparison_dict = {} + for comp in cfg.comparisons: + # extract file paths + comparison_dict[comp] = extract_filepaths(cfg.comparisons[comp]["netcdf_paths"]) + outdir = cfg.output_dir + os.makedirs(outdir, exist_ok=True) + fname = cfg.prefix + fc_times = cfg.forecast_steps + + psd_plots.plot_psds( + comparison_dict, + regions, + diags, + fname=fname, + outdir=outdir, + usencname=True, + plevels=plevels, + fc_times=fc_times, + ) + + +def parse_args(args: list) -> None: + """ + Parse command line arguments. + Parameters + ---------- + args : List of command line arguments. + """ + parser = argparse.ArgumentParser(description="Plot power spectral densities from NetCDF files.") + + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to the configuration YAML file.", + ) + + parser.add_argument( + "--output-dir", + type=str, + default=_REPO_ROOT / "plots" / "power_spectra", + help="Directory to save the output plots.", + ) + + parser.add_argument( + "--run-id", + type=str, + help="Run ID to construct configuration if --config is not provided.", + ) + + parser.add_argument( + "--variables", + type=str, + nargs="+", + help="List of variables to plot (e.g., 'u', 't2m'). If None, uses all", + choices=["q", "t", "u", "v", "z", "t2m", "msl", "u10", "v10", "d2m", "skt", "sp"], + default=["z", "u10", "v10"], + ) + + parser.add_argument( + "--regions", + type=str, + nargs="+", + help="List of regions to plot (e.g., 'ShortGlobe', 'N-Mid-Lats'). If None, uses all", + choices=["ShortGlobe", "N-Mid-Lats", "S-Mid-Lats", "Tropics"], + default=["ShortGlobe"], + ) + + parser.add_argument( + "--pressure-levels", + type=int, + nargs="+", + help="List of pressure levels to plot (e.g., 250, 500). \ + If not provided uses all", + default=[100, 850], + ) + + parser.add_argument( + "--forecast-steps", + type=int, + nargs="+", + help="List of forecast steps to plot (e.g., 6, 12). \ + If not provided averages over all forecast steps", + default=None, + ) + + parser.add_argument( + "--prefix", + type=str, + default="", + help="Prefix for output files (default: empty).", + ) + + parser.add_argument( + "--export-dir", + type=str, + help="Directory where exported NetCDF files were saved.", + default=None, + ) + + args, unknown_args = parser.parse_known_args(args) + if unknown_args: + _logger.warning(f"Unknown arguments: {unknown_args}") + return args + + +def construct_config_from_run_id(run_id: str, args: argparse.Namespace) -> DictConfig: + """ + Construct configuration from run ID and command line arguments. + Parameters + ---------- + run_id : Run ID to construct configuration for. + args : Command line arguments. + Returns + ------- + DictConfig: Constructed configuration. + """ + run_id_config = { + "variables": args.variables, + "regions": args.regions, + "pressure_levels": args.pressure_levels, + "forecast_steps": args.forecast_steps, + "prefix": args.prefix, + "output_dir": Path(args.output_dir), + "comparisons": { + run_id: {"netcdf_paths": [f"{args.export_dir}/pred*.nc"]}, + "target": {"netcdf_paths": [f"{args.export_dir}/targ*.nc"]}, + }, + } + run_id_config = DictConfig(run_id_config) + return run_id_config + + +def psd_from_args(args: list) -> None: + # Get run_id zarr data as lists of xarray DataArrays + """ + Export data from Zarr store to NetCDF files based on command line arguments. + Parameters + ---------- + args : List of command line arguments. + """ + init_loggers() + + args = parse_args(sys.argv[1:]) + + # Load configuration + if args.config: + config_file = Path(args.config) + config = OmegaConf.load(config_file) + # check config loaded correctly + assert isinstance(config, DictConfig), "Config file not loaded correctly" + # use PosixPath for output_dir + config.output_dir = Path(config.output_dir) + + # Use run id to construct config if not provided + elif args.run_id: + if args.export_dir is None: + # TODO: automatically run export into results directory and use that path here + raise ValueError("When using --run-id, --export-dir must also be provided.") + config = construct_config_from_run_id(args.run_id, args) + + else: + raise ValueError("Either --config or --run-id must be provided.") + + _logger.info(f"starting power spectral density plotting with config: {config}") + + psd_from_config(config) + + +def psd() -> None: + """ + Main function to plot power spectral densities. + """ + # By default, arguments from the command line are read. + psd_from_args(sys.argv[1:]) + + +if __name__ == "__main__": + psd() diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_plots.py b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_plots.py new file mode 100644 index 000000000..4cdbcbc68 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_plots.py @@ -0,0 +1,605 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "scitools-iris", +# ] +# /// + +""" +Adapted from Martin Willet's code for power spectra for use with the WeatherGenerator model +""" + +import logging +import warnings + +import iris +import matplotlib.pyplot as plt +import numpy as np + +_logger = logging.getLogger(__name__) +warnings.simplefilter(action="ignore", category=FutureWarning) + +# A couple of physical constants +g = 9.81 # Acceleration due to gravity. USed to convert GP to GPH. +re = 6.37e6 # earth radius. +re2 = re * re # earth radius squared + +# Define dictionary of regions of interest +regions = { + "FullGlobe": dict(label="FullGlobe", lonW=0.0, lonE=360.0, latS=-90.0, latN=90.0), + "ShortGlobe": dict(label="ShortGlobe", lonW=0.0, lonE=360.0, latS=-60.0, latN=60.0), + "N-Mid-Lats": dict(label="N-Mid-Lats", lonW=0.0, lonE=360.0, latS=30.0, latN=60.0), + "S-Mid-Lats": dict(label="S-Mid-Lats", lonW=0.0, lonE=360.0, latS=-60.0, latN=-30.0), + "Tropics": dict(label="Tropics", lonW=0.0, lonE=360.0, latS=-30.0, latN=30.0), + "Deep_Tropics": dict(label="Deep_Tropics", lonW=0.0, lonE=360.0, latS=-10.0, latN=10.0), + "NN-Sub-Tropics": dict(label="NN-Sub-Tropics", lonW=0.0, lonE=360.0, latS=30.0, latN=40.0), + "SS-Sub-Tropics": dict(label="SS-Sub-Tropics", lonW=0.0, lonE=360.0, latS=-40.0, latN=-30.0), + "NN-Mid-Lats": dict(label="NN-Mid-Lats", lonW=0.0, lonE=360.0, latS=45.0, latN=75.0), + "SS-Mid-Lats": dict(label="SS-Mid-Lats", lonW=0.0, lonE=360.0, latS=-75.0, latN=-45.0), + "N-Polar": dict(label="N-Polar", lonW=0.0, lonE=360.0, latS=75.0, latN=90.0), + "S-Polar": dict(label="S-Polar", lonW=0.0, lonE=360.0, latS=-90.0, latN=-75.0), +} + +# Define dictionary of potential diagnostics to plot +diags = { + "q": { + "ncvar": "q", + "ncname": "specific_humidity_at_pressure_levels", + "std": "specific_humidity", + "units": "kg kg-1", + "levtype": "pressure", + "scale": 0.01, + "slope": -5, + "yscale": 1.0, + }, + "t": { + "ncvar": "t", + "ncname": "temperature_at_pressure_levels", + "std": "air_temperature", + "units": "K", + "levtype": "pressure", + "scale": 1.0, + "slope": -3, + "yscale": 0.1, + }, + "u": { + "ncvar": "u", + "ncname": "u_wind_at_pressure_levels", + "std": "x_wind", + "units": "m s-1", + "levtype": "pressure", + "scale": 1.0, + "slope": -3, + "yscale": 1.0, + }, + "v": { + "ncvar": "v", + "ncname": "v_wind_at_pressure_levels", + "std": "y_wind", + "units": "m s-1", + "levtype": "pressure", + "scale": 1.0, + "slope": -3, + "yscale": 1.0, + }, + "z": { + "ncvar": "z", + "ncname": "geopotential_at_pressure_levels", + "std": "geopotential", + "units": "m2 s-2", + "levtype": "pressure", + "scale": 0.1 / g, + "slope": -5, + "yscale": 1.0, + }, + "u10": { + "ncvar": "10u", + "ncname": "u_wind_at_10m", + "std": "x_wind", + "units": "m s-1", + "levtype": "surface", + "scale": 1.0, + "slope": -3, + "yscale": 1.0, + }, + "v10": { + "ncvar": "v10", + "ncname": "v_wind_at_10m", + "std": "y_wind", + "units": "m s-1", + "levtype": "surface", + "scale": 1.0, + "slope": -3, + "yscale": 1.0, + }, + "d2m": { + "ncvar": "d2m", + "ncname": "dew_point_temperature_at_screen_level", + "std": "dew_point_temperature", + "units": "K", + "levtype": "surface", + "scale": 1.0, + "slope": -3, + "yscale": 1.0, + }, + "t2m": { + "ncvar": "t2m", + "ncname": "temperature_at_screen_level", + "std": "air_temperature", + "units": "K", + "levtype": "surface", + "scale": 1.0, + "slope": -3, + "yscale": 1.0, + }, + "msl": { + "ncvar": "msl", + "ncname": "mean_sea_level_pressure", + "std": "air_pressure_at_mean_sea_level", + "units": "Pa", + "levtype": "surface", + "scale": 0.01, + "slope": -5, + "yscale": 1.0, + }, + "skt": { + "ncvar": "skt", + "ncname": "skin_temperature", + "std": "sea_surface_skin_temperature", + "units": "K", + "levtype": "surface", + "scale": 1.0, + "slope": -3, + "yscale": 1.0, + }, + "sp": { + "ncvar": "sp", + "ncname": "surface_pressure", + "std": "surface_air_pressure", + "units": "Pa", + "levtype": "surface", + "scale": 0.01, + "slope": -5, + "yscale": 1.0, + }, +} + + +def add_levels(weathergen_diags, plevels): + for var_dict in weathergen_diags: + if weathergen_diags[var_dict]["levtype"] == "pressure": + weathergen_diags[var_dict]["levels"] = plevels + elif weathergen_diags[var_dict]["levtype"] == "surface": + weathergen_diags[var_dict]["levels"] = [0] + + +def psd(ht): + """ + Returns a power spectrum density for the positive non-zero frequencies + Assumes ht has an even number of points + """ + n = len(ht) + # Hf = np.fft.fft(ht, norm='forward') + hf = np.fft.rfft(ht, norm="forward") + psd = np.abs(hf[1 : round(n / 2 + 1)]) ** 2 + # Compensate for positive frequencies only + psd *= 2.0 + return psd + + +def old_cubepsd(cube, dimension="longitude"): + """ + Returns a power spectrum density for a cube + Assumes that cube.data has an even number of points in dimension dim + """ + + npoints = len(cube.coord(dimension).points) + field_psd = np.zeros([round(npoints / 2)]) + + _logger.info("About to calc PSDs") + + nloc = 0 + + for field_slice in cube.slices([dimension]): + nloc += 1 + field_psd += psd(field_slice.data) + + field_psd /= nloc + + return field_psd + + +def cubepsd(cubes, dimension="longitude"): + """ + Returns a power spectrum density for a cube + Assumes that cube.data has an even number of points in dimension dim + """ + + if isinstance(cubes, iris.cube.CubeList): + # being passed a cube list + npoints = len(cubes[0].coord(dimension).points) + else: + # Assume it is just a cube + npoints = len(cubes.coord(dimension).points) + + field_psd = np.zeros([round(npoints / 2)]) + + _logger.info("About to calc PSDs") + + nloc = 0 + if isinstance(cubes, iris.cube.CubeList): + for cube in cubes: + for field_slice in cube.slices([dimension]): + nloc += 1 + field_psd += psd(field_slice.data) + else: + for field_slice in cubes.slices([dimension]): + nloc += 1 + field_psd += psd(field_slice.data) + + field_psd /= nloc + + return field_psd + + +def addwvns(axes): + """ + Adds lines of equal wavenumber to plots + """ + yscale = axes.yaxis.get_scale() + ylims = axes.get_ylim() + + if yscale == "log": + ytxt = 10.0 ** (0.85 * (np.log10(ylims[1] / ylims[0])) + np.log10(ylims[0])) + else: + ytxt = 0.85 * (ylims[1] - ylims[0]) + ylims[0] + + wvns = [1, 2, 4, 8, 16, 24, 48, 96, 144, 216, 320, 640, 1280, 2560] + for wvn in wvns: + axes.plot( + np.array([wvn / 360.0, wvn / 360.0]), + np.array(ylims), + color="black", + lw=1.0, + scalex=False, + scaley=False, + ) + axes.text(wvn / 360.0, ytxt, f"n{wvn:3.0f}", rotation="vertical") + + return None + + +def addlengths(axes, region): + """ + Adds lines of equal spatial scale in km to plots + """ + + re = 6.37e6 # earth radius. Used to plot phyical lengths on plots. + yscale = axes.yaxis.get_scale() + ylims = axes.get_ylim() + + if yscale == "log": + ytxt = 10.0 ** (0.05 * (np.log10(ylims[1] / ylims[0])) + np.log10(ylims[0])) + else: + ytxt = 0.05 * (ylims[1] - ylims[0]) + ylims[0] + + lengths = np.array([1.0e4, 3.0e3, 1.0e3, 3.0e2, 1.0e2, 3e1, 1e1]) + + flengths = ( + 2.0 + * np.pi + * re + * np.cos((region["latN"] + region["latS"]) / 360.0 * np.pi) + / (1000.0 * lengths * 360.0) + ) + + for ilength in range(len(lengths)): + axes.plot( + np.array([flengths[ilength], flengths[ilength]]), + np.array(ylims), + color="black", + linestyle="dashed", + lw=1.0, + scalex=False, + scaley=False, + ) + axes.text(flengths[ilength], ytxt, f"{lengths[ilength]:5.0f}km", rotation="vertical") + + return None + + +def addidealslope(axes, slope, defxs=None, defy0=10.0): + """ + Adds an idealised slope to a log-log spectra plot + """ + if defxs is None: + defxs = [0.01, 0.1] + slopexs = np.array(defxs) + slopeys = defy0 * np.array([1.0, (slopexs[1] / slopexs[0]) ** slope]) + xtxt = np.sqrt(np.prod(slopexs)) + ytxt = np.sqrt(np.prod(slopeys)) + + axes.plot(slopexs, slopeys, color="black", lw=2.0, scalex=False, scaley=False) + axes.text(xtxt, ytxt, "$k^{" + str(slope) + "}$", fontsize="xx-large", weight="bold") + + return None + + +def calcposfreq(cube, dimension="longitude"): + """ + Given a cube and dimension returns the positive frequencies + Assumes gridpoints are evenly spaced + """ + npoints = len(cube.coord(dimension).points) + + # Create frequencies + freq = np.fft.fftfreq(npoints, d=360.0 / npoints) + + # Positive half + posfreq = np.absolute(freq[1 : round(npoints / 2 + 1)]) + + return posfreq + + +def region_constraint(region): + """ + Given a region definition, returns a longitude and latitude constraint + """ + # Setup iris constraint to extract data for this region: + lat_constraint = iris.Constraint( + latitude=lambda lat: lat >= region["latS"] and lat <= region["latN"] + ) + # Case where region straddles the Greenwich Meridian: + if region["lonW"] > region["lonE"]: + lon_constraint = iris.Constraint( + longitude=lambda lon: lon >= region["lonW"] or lon <= region["lonE"] + ) + # Normal case + else: + lon_constraint = iris.Constraint( + longitude=lambda lon: lon >= region["lonW"] and lon <= region["lonE"] + ) + # end if + + return lat_constraint, lon_constraint + + +def tidy_plot(axes, plttitle, ylabel, ylims, region): + """ + Add plots stuff common to all plots + """ + axes.set_title(plttitle) + axes.set_xlabel("Frequency (1/deg long)") + axes.set_ylabel(ylabel) + axes.grid(True, which="major", linewidth=1.0) + axes.grid(True, which="minor", linewidth=0.5) + axes.set_xlim(1.0e-3, 1.0e1) + axes.set_ylim(ylims[0], ylims[1]) + addwvns(axes) + addlengths(axes, region) + + return None + + +def setuppage(): + plt.rc("figure", figsize=(8.27, 11.69)) + plt.subplots_adjust(hspace=0.3) + # plt.rcParams['font.size']=11 + plt.rcParams["font.size"] = 13 + return None + + +def plot_psds( + comparison_dict, + regkeys, + diagkeys, + usencname=False, + fc_times=None, + fname=None, + outdir=None, + plevels=None, +): + """ + Calculates and plots power spectra + comparison_dict containing + testnames - a list of the names of test + testfiles - a list of the filenames - wildcards can be used. One for each test. + It is assumed that the first of the tests is to be used as the reference. + regkeys - a list of keys specifying which regions to produce plots for. + diagkeys - a list of keys specifying which diagnostics are required. + usencname - if True the diagnostic contraint will use the netcdf name. + If False (default) stash code will be used. + fctimes - an optional 2d-array containing the forecast-times for each plot. + fname - optional prefix for plot filename. + outdir - optional output directory. + """ + + failed_string = "" + + if fname is None: + fname = "" + + if fc_times is None: + n_fc_times = 1 + else: + n_fc_times = len(fc_times[:]) + + # Setup some standard settings. + loglog_ylims = np.array([1.0e-5, 1.0e2]) + # loglog_ylims = np.array([1.e-5, 1.e3]) + # semilogx_ylims = [0.0, 2.0] + semilogx_ylims = [0.0, 3.0] + colors = ["b", "r", "m", "c", "g", "orange"] + + # prep diag_keys + if plevels is None: + plevels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] + add_levels(diags, plevels) + + if regions is None: + regkeys = ["ShortGlobe", "N-Mid-Lats", "S-Mid-Lats", "Tropics"] + + if diagkeys is None: + diagkeys = ["q", "t", "u", "v", "z", "t2m", "msl", "u10", "v10", "d2m", "skt", "sp"] + + for diagkey in diagkeys: + # For each diagnostic... + diag = diags[diagkey] + if usencname: + if diag["ncname"] is not None: + field_constraint = diag["ncname"] + else: + field_constraint = None + + _logger.debug("field_constraint is:", field_constraint) + scale = diag["scale"] + levels = diag["levels"] + n_levels = len(levels) + + for regkey in regkeys: + # For each region... + region = regions[regkey] + + # Calc lat and lon contraint from region + lat_constraint, lon_constraint = region_constraint(region) + + # for level in np.nditer(levels): + for i_level in range(n_levels): + level = levels[i_level] + if diag["levtype"] == "pressure": + lev_constraint = iris.Constraint(pressure=level) + elif diag["levtype"] == "mlevels": + lev_constraint = iris.Constraint(model_level_number=level) + + for i_fc_time in range(n_fc_times): + plttitle = diag["ncname"] + ": " + figname = fname + diag["ncname"] + "_" + + if diag["levtype"] == "pressure": + plttitle += str(level) + "hPa: " + figname += str(level) + "_" + elif diag["levtype"] == "mlevels": + plttitle += "ML" + str(level) + ": " + figname += str(level) + "_" + + plttitle += region["label"] + figname += region["label"] + + if fc_times is not None: + fc_time = fc_times[i_fc_time] + desired_time = iris.Constraint(forecast_period=fc_time) + plttitle += "T" + str(fc_time) + figname += "T" + str(fc_time) + else: + fc_time = None + + figname += "_spectra.png" + _logger.info("Creating " + plttitle) + + # Initialise plot page + setuppage() + + n_test = len(comparison_dict.keys()) + # for testkey in testkeys: + for i_test, comp_key in enumerate(comparison_dict): + testname = comp_key + testfiles = comparison_dict[comp_key] + ############## Read data ########################## + _logger.info("About to read field") + + try: + psds = [] + for testfile in testfiles: + field = iris.load(testfile, field_constraint) + + _logger.debug(field) + + coord_names = [coord.name() for coord in field[0].coords()] + _logger.debug(coord_names) + + tot_constraint = lat_constraint & lon_constraint + + if "air_pressure" in coord_names: + for cube in field: + cube.coord("air_pressure").rename("pressure") + + if diag["levtype"] == "pressure" or diag["levtype"] == "mlevels": + tot_constraint = tot_constraint & lev_constraint + + if "forecast_period" in coord_names and fc_time is not None: + _logger.debug(field[0].coord("forecast_period")) + for cube in field: + cube.coord("forecast_period").convert_units("hours") + _logger.debug( + "forecast_period:", field[0].coord("forecast_period") + ) + + if "forecast_period" in coord_names and fc_time is not None: + # create the constraint + tot_constraint = tot_constraint & desired_time + _logger.debug("Tot_constraint:", tot_constraint) + field = field.extract(tot_constraint) + _logger.info("Completed read field") + + ############## PSD calcs ########################## + # Create frequencies + posfreq = calcposfreq(field[0]) + + # Calculate PSD + field_psd = scale * scale * cubepsd(field, dimension="longitude") + psds.append(field_psd) + + _logger.info("Averaging PSDs over all samples for one forecast_time") + field_psd = np.mean(psds, axis=0) + + if i_test == 0: + # Take a copy of the data + ref_psd = np.copy(field_psd) + + _logger.info("Completed calc PSDs") + + ############## Plotting ########################## + plt.subplot(2, 1, 1) + plt.loglog(posfreq, field_psd, color=colors[i_test], label=testname) + + if i_test == n_test - 1: + # last test. Add plt stuff + _logger.info(plttitle) + tidy_plot( + plt.gca(), + plttitle + ": zonal spectra", + "Power ((" + diag["units"] + ")^2 deg)", + diag["yscale"] * loglog_ylims, + region, + ) + plt.legend(loc="lower left") + + # Add idealised slopes + addidealslope( + plt.gca(), float(diag["slope"]), defy0=10.0 * diag["yscale"] + ) + + plt.subplot(2, 1, 2) + plt.semilogx( + posfreq, field_psd / ref_psd, color=colors[i_test], label=testname + ) + + if i_test == n_test - 1: + tidy_plot( + plt.gca(), + plttitle + ": ratio of zonal spectra", + "Power ratio", + semilogx_ylims, + region, + ) + plt.legend() + except Exception as e: + _logger.error(e) + _logger.error( + f"Plotting power spectra failed for {testfile, tot_constraint}" + ) + failed_string += f"{testfile, tot_constraint}" + plt.savefig(outdir / figname) + plt.close() + _logger.info(f"Runs failed: {failed_string}") From 651713eaa624258b301eecb697b582989cd8e2ac Mon Sep 17 00:00:00 2001 From: Sorcha Date: Thu, 5 Feb 2026 16:30:38 +0100 Subject: [PATCH 2/3] minor fixes --- .../evaluate/example_extras/power_spectra/psd_config.yml | 8 ++++---- .../evaluate/example_extras/power_spectra/psd_main.py | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml index 8b24bf3df..d7bcc5922 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml @@ -1,8 +1,8 @@ -variables: ['u'] # , 'dew_point_temperature'] +variables: ['u'] # , 'v10', 'z'] etc. regions: ['ShortGlobe'] #, 'N-Mid-Lats', 'S-Mid-Lats', 'Tropics'] -pressure_levels: [250] -forecast_steps : [6, 12] # choose the forecast stpes (must be common to all the datasets) -output_dir: "~/WeatherGen/test_output" +pressure_levels: [250] +forecast_steps : [6, 12] # choose the forecast steps (must be common to all the datasets) +output_dir: "./plots/power_spectra" #relative to dir in which script is run prefix: "comparison_gn3gotvh" # define prefix for images comparisons: diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py index 8d62ce4fd..80d7bbfb9 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py @@ -21,6 +21,13 @@ uv run packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py --config packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_config.yml +Prerequisties: + +Please export the inference into a regular gridded netcdf first using the export package: +e.g. +uv run export --run-id --stream ERA5 \ +--output-dir ../output_nc --format netcdf --regrid-degree 1 \ +--regrid-type regular_ll """ import argparse import glob From bb3047bce17caef5951236a8a36b9a22e2e6e2ba Mon Sep 17 00:00:00 2001 From: Sorcha Date: Fri, 6 Feb 2026 12:15:38 +0100 Subject: [PATCH 3/3] target as baseline --- .../evaluate/example_extras/power_spectra/psd_main.py | 2 +- .../evaluate/example_extras/power_spectra/psd_plots.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py index 80d7bbfb9..d777b3582 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_main.py @@ -208,8 +208,8 @@ def construct_config_from_run_id(run_id: str, args: argparse.Namespace) -> DictC "prefix": args.prefix, "output_dir": Path(args.output_dir), "comparisons": { - run_id: {"netcdf_paths": [f"{args.export_dir}/pred*.nc"]}, "target": {"netcdf_paths": [f"{args.export_dir}/targ*.nc"]}, + run_id: {"netcdf_paths": [f"{args.export_dir}/pred*.nc"]} }, } run_id_config = DictConfig(run_id_config) diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_plots.py b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_plots.py index 4cdbcbc68..a9f20908e 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/power_spectra/psd_plots.py @@ -597,9 +597,9 @@ def plot_psds( except Exception as e: _logger.error(e) _logger.error( - f"Plotting power spectra failed for {testfile, tot_constraint}" + f"Plotting power spectra failed for {testfile}" ) - failed_string += f"{testfile, tot_constraint}" + failed_string += f"{testfile}" plt.savefig(outdir / figname) plt.close() _logger.info(f"Runs failed: {failed_string}")