diff --git a/config/default_forecast_config.yml b/config/default_forecast_config.yml index 0080ed252..056f2a506 100644 --- a/config/default_forecast_config.yml +++ b/config/default_forecast_config.yml @@ -182,4 +182,4 @@ wgtags: # Examples: "rollout_ablation_grid" exp: None # *** Experiment-specific tags *** - grid: None \ No newline at end of file + grid: None diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index 3e3570bf1..8d8c5b918 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -29,6 +29,7 @@ dev = [ ] [project.scripts] +ssl_analysis= "weathergen.evaluate.ssl.ssl_eval:ssl_analysis" evaluation = "weathergen.evaluate.run_evaluation:evaluate" export = "weathergen.evaluate.export.export_inference:export" diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 1c819f78b..5bb3d4c01 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -8,7 +8,6 @@ # nor does it submit to any jurisdiction. # Standard library -import logging import re from pathlib import Path @@ -20,16 +19,15 @@ # Local application / package from weathergen.evaluate.io.io_reader import Reader -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - class CsvReader(Reader): """ Reader class to read evaluation data from CSV files and convert to xarray DataArray. """ - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + def __init__( + self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True + ): """ Initialize the CsvReader. @@ -43,7 +41,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non list of private paths for the supported HPC """ - super().__init__(eval_cfg, run_id, private_paths) + super().__init__(eval_cfg, run_id, private_paths, verbose) self.metrics_dir = Path(self.eval_cfg.get("metrics_dir")) self.metrics_base_dir = self.metrics_dir diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index 658339cac..f5024e6fa 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -15,9 +15,6 @@ # Third-party import xarray as xr -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - @dataclass class ReaderOutput: @@ -62,7 +59,9 @@ class DataAvailability: class Reader: - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None): + def __init__( + self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None, verbose=True + ): """ Generic data reader class. 
@@ -89,6 +88,10 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | "results_base_dir", None ) # base directory where results will be stored + self._logger = logging.getLogger(__name__) + logger_level = logging.INFO if verbose else logging.CRITICAL + 1 + self._logger.setLevel(logger_level) + def get_stream(self, stream: str): """ returns the dictionary associated to a particular stream @@ -217,7 +220,7 @@ def check_availability( requested[name] = reader_data[name] # If file with metrics exists, must exactly match if available_data is not None and reader_data[name] != available[name]: - _logger.info( + self._logger.info( f"Requested all {name}s for {mode}, but previous config was a " "strict subset. Recomputation required." ) @@ -230,7 +233,7 @@ def check_availability( if name == "ensemble" and "mean" in missing: missing.remove("mean") if missing: - _logger.info( + self._logger.info( f"Requested {name}(s) {missing} is unavailable. " f"Removing missing {name}(s) for {mode}." ) @@ -240,7 +243,7 @@ def check_availability( # Must be a subset of available_data (if provided) if available_data is not None and not requested[name] <= available[name]: missing = requested[name] - available[name] - _logger.info( + self._logger.info( f"{name.capitalize()}(s) {missing} missing in previous evaluation." "Recomputation required." ) @@ -248,7 +251,7 @@ def check_availability( if check_score and not corrected: scope = "metric file" if available_data is not None else "Zarr file" - _logger.info( + self._logger.info( f"All checks passed – All channels, samples, fsteps requested for {mode} are " f"present in {scope}..." ) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 062031273..6291b0b61 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -9,7 +9,6 @@ # Standard library import json -import logging from collections import defaultdict from pathlib import Path @@ -30,13 +29,12 @@ from weathergen.evaluate.scores.score_utils import to_list from weathergen.evaluate.utils.derived_channels import DeriveChannels -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - class WeatherGenReader(Reader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): - super().__init__(eval_cfg, run_id, private_paths) + def __init__( + self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True + ): + super().__init__(eval_cfg, run_id, private_paths, verbose) # TODO: remove backwards compatibility to "epoch" in Feb. 
2026 self.mini_epoch = eval_cfg.get("mini_epoch", 0) @@ -46,9 +44,11 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if not self.results_base_dir: self.results_base_dir = get_path_run(self.inference_cfg) - _logger.info(f"Results directory obtained from private config: {self.results_base_dir}") + self._logger.info( + f"Results directory obtained from private config: {self.results_base_dir}" + ) else: - _logger.info(f"Results directory parsed: {self.results_base_dir}") + self._logger.info(f"Results directory parsed: {self.results_base_dir}") self.runplot_base_dir = Path( self.eval_cfg.get("runplot_base_dir", self.results_base_dir) @@ -80,18 +80,18 @@ def get_inference_config(self): configuration file from the inference run """ if self.private_paths: - _logger.info( + self._logger.info( f"Loading config for run {self.run_id} from private paths: {self.private_paths}" ) config = load_merge_configs(self.private_paths, self.run_id, self.mini_epoch) else: - _logger.info( + self._logger.info( f"Loading config for run {self.run_id} from model directory: {self.model_base_dir}" ) config = load_run_config(self.run_id, self.mini_epoch, self.model_base_dir) if type(config) not in [dict, oc.DictConfig]: - _logger.warning("Model config not found. inference config will be empty.") + self._logger.warning("Model config not found. inference config will be empty.") config = {} return config @@ -125,7 +125,7 @@ def get_climatology_filename(self, stream: str) -> str | None: if clim_base_dir and clim_fn: clim_data_path = Path(clim_base_dir).join(clim_fn) else: - _logger.warning( + self._logger.warning( f"No climatology path specified for stream {stream}. Setting climatology to " "NaN. Add 'climatology_path' to evaluation config to use metrics like ACC." ) @@ -145,9 +145,9 @@ def get_channels(self, stream: str) -> list[str]: ------- A list of channel names. 
""" - _logger.debug(f"Getting channels for stream {stream}...") + self._logger.debug(f"Getting channels for stream {stream}...") all_channels = self.get_inference_stream_attr(stream, "val_target_channels") - _logger.debug(f"Channels found in config: {all_channels}") + self._logger.debug(f"Channels found in config: {all_channels}") return all_channels def load_scores( @@ -208,7 +208,7 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr Path(self.metrics_dir) / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" ) - _logger.debug(f"Looking for: {score_path}") + self._logger.debug(f"Looking for: {score_path}") if score_path.exists(): with open(score_path) as f: data_dict = json.load(f) @@ -253,8 +253,9 @@ def __init__( private_paths: dict | None = None, regions: list[str] | None = None, metrics: list[str] | None = None, + verbose=True, ): - super().__init__(eval_cfg, run_id, private_paths) + super().__init__(eval_cfg, run_id, private_paths, verbose) # goes looking for the coordinates available for all streams, regions, metrics streams = list(self.eval_cfg.streams.keys()) coord_names = ["sample", "forecast_step", "ens"] @@ -280,7 +281,7 @@ def __init__( message = [f"Some {name}(s) were not common among streams, regions and metrics:"] for val in skipped: message.append(f" {val} only in {provenance[name][val]}") - _logger.warning("\n".join(message)) + self.self._logger.warning("\n".join(message)) def get_samples(self) -> set[int]: return self.common_coords["sample"] @@ -297,16 +298,20 @@ def get_data(self, *args, **kwargs): raise ValueError(f"Missing JSON data for run {self.run_id}.") def get_recomputable_metrics(self, metrics): - _logger.info( - f"The following metrics have not yet been computed:{metrics}. Use type: zarr for that." - ) + if metrics: + self._logger.info( + f"The following metrics have not yet been computed:{metrics}. " + "Use type: zarr for that. Skipping them." + ) return {} class WeatherGenZarrReader(WeatherGenReader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + def __init__( + self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True + ): """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" - super().__init__(eval_cfg, run_id, private_paths) + super().__init__(eval_cfg, run_id, private_paths, verbose) zarr_ext = self.inference_cfg.get("zarr_store", "zarr") # for backwards compatibility assume zarr store is local i.e. 
.zarr format @@ -320,7 +325,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non ): self.fname_zarr = fname_zarr else: - _logger.error(f"Zarr file {self.fname_zarr} does not exist.") + self._logger.error(f"Zarr file {self.fname_zarr} does not exist.") raise FileNotFoundError(f"Zarr file {self.fname_zarr} does not exist") def get_data( @@ -367,7 +372,7 @@ def get_data( with zarrio_reader(self.fname_zarr) as zio: stream_cfg = self.get_stream(stream) all_channels = self.get_channels(stream) - _logger.info(f"RUN {self.run_id}: Processing stream {stream}...") + self._logger.info(f"RUN {self.run_id}: Processing stream {stream}...") fsteps = self.get_forecast_steps() if fsteps is None else fsteps @@ -401,14 +406,14 @@ def get_data( fsteps_final = [] for fstep in fsteps: - _logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...") + self._logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...") da_tars_fs, da_preds_fs, pps = [], [], [] for sample in tqdm(samples, desc=f"Processing {self.run_id} - {stream} - {fstep}"): out = zio.get_data(sample, stream, fstep) if out.target is None or out.prediction is None: - _logger.info( + self._logger.info( f"Skipping {stream} sample {sample} forecast step: {fstep}. " "No data found." ) @@ -420,31 +425,31 @@ def get_data( pps.append(npoints) if npoints == 0: - _logger.info( + self._logger.info( f"Skipping {stream} sample {sample} forecast step: {fstep}. " "Dataset is empty." ) continue if ensemble == ["mean"]: - _logger.debug("Averaging over ensemble members.") + self._logger.debug("Averaging over ensemble members.") pred = pred.mean("ens", keepdims=True) else: - _logger.debug(f"Selecting ensemble members {ensemble}.") + self._logger.debug(f"Selecting ensemble members {ensemble}.") pred = pred.sel(ens=ensemble) da_tars_fs.append(target.squeeze()) da_preds_fs.append(pred.squeeze()) if not da_tars_fs: - _logger.info( + self._logger.info( f"[{self.run_id} - {stream}] No valid data found for fstep {fstep}." ) continue fsteps_final.append(fstep) - _logger.debug( + self._logger.debug( f"Concatenating targets and predictions for stream {stream}, " f"forecast_step {fstep}..." ) @@ -464,7 +469,7 @@ def get_data( da_preds_fs = xr.concat(da_preds_fs, dim="ipoint") if len(samples) == 1: - _logger.debug("Repeating sample coordinate for single-sample case.") + self._logger.debug("Repeating sample coordinate for single-sample case.") for da in (da_tars_fs, da_preds_fs): da.assign_coords( sample=( @@ -474,7 +479,7 @@ def get_data( ) if set(channels) != set(all_channels): - _logger.debug( + self._logger.debug( f"Restricting targets and predictions to channels {channels} " f"for stream {stream}..." ) @@ -601,7 +606,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: ------- A list of ensemble members. """ - _logger.debug(f"Getting ensembles for stream {stream}...") + self._logger.debug(f"Getting ensembles for stream {stream}...") # TODO: improve this to get ensemble from io class with zarrio_reader(self.fname_zarr) as zio: @@ -620,7 +625,7 @@ def is_regular(self, stream: str) -> bool: ------- True if the stream is regularly spaced. False otherwise. 
""" - _logger.debug(f"Checking regular spacing for stream {stream}...") + self._logger.debug(f"Checking regular spacing for stream {stream}...") with zarrio_reader(self.fname_zarr) as zio: dummy = zio.get_data(0, stream, zio.forecast_steps[0]) @@ -642,10 +647,10 @@ def is_regular(self, stream: str) -> bool: and np.allclose(sorted(da["lon"].values), sorted(da1["lon"].values)) ) ): - _logger.debug("Latitude and/or longitude coordinates are not regularly spaced.") + self._logger.debug("Latitude and/or longitude coordinates are not regularly spaced.") return False - _logger.debug("Latitude and longitude coordinates are regularly spaced.") + self._logger.debug("Latitude and longitude coordinates are regularly spaced.") return True @@ -701,7 +706,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non super().__init__(eval_cfg, run_id, private_paths) self.readers = [] - _logger.info(f"MERGE READERS: {self.run_ids} ...") + self._logger.info(f"MERGE READERS: {self.run_ids} ...") for run_id in self.run_ids: reader = WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths) @@ -753,7 +758,7 @@ def get_data( for reader in self.readers: da_tars, da_preds, da_fsteps = [], [], [] - _logger.info(f"MERGE READERS: Processing run_id {reader.run_id}...") + self._logger.info(f"MERGE READERS: Processing run_id {reader.run_id}...") out = reader.get_data( stream, @@ -764,7 +769,7 @@ def get_data( ) for fstep in out.target.keys(): - _logger.debug(f"MERGE READERS: Processing fstep {fstep}...") + self._logger.debug(f"MERGE READERS: Processing fstep {fstep}...") da_tars.append(out.target[fstep]) da_preds.append(out.prediction[fstep]) @@ -924,7 +929,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: ------- A range of ensemble members equal to the number of merged readers. """ - _logger.debug(f"Getting ensembles for stream {stream}...") + self._logger.debug(f"Getting ensembles for stream {stream}...") all_ensembles = [] for reader in self.readers: all_ensembles.append(reader.get_ensemble(stream)) @@ -949,5 +954,5 @@ def is_regular(self, stream: str) -> bool: ------- True if the stream is regularly spaced. False otherwise. """ - _logger.debug(f"Checking regular spacing for stream {stream}...") + self._logger.debug(f"Checking regular spacing for stream {stream}...") return all(reader.is_regular(stream) for reader in self.readers) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 39030653b..9a13702f6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -7,14 +7,11 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import logging from collections.abc import Iterable, Sequence import numpy as np import xarray as xr -_logger = logging.getLogger(__name__) - def collect_streams(runs: dict): """Get all unique streams across runs, sorted. 
@@ -105,14 +102,29 @@ def plot_metric_region(
            run_ids.append(run_id)

        if selected_data:
-            _logger.info(f"Creating plot for {metric} - {region} - {stream} - {ch}.")
+            plotter._logger.info(f"Creating plot for {metric} - {region} - {stream} - {ch}.")
            name = create_filename(
-                prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream, ch]
+                plotter=plotter,
+                prefix=[metric, region],
+                middle=sorted(set(run_ids)),
+                suffix=[stream, ch],
            )
            selected_data, time_dim = _assign_time_coord(selected_data)
+            if time_dim != "lead_time":
+                plotter._logger.warning(
+                    "lead_time coordinate not found for all plotted data; "
+                    "using forecast_step as x-axis."
+                )
+
            plotter.plot(
                selected_data,
                labels,
@@ -149,10 +161,6 @@ def _assign_time_coord(selected_data: list[xr.DataArray]) -> tuple[xr.DataArray,
        )
        if "lead_time" not in data.coords and "lead_time" not in data.dims:
-            _logger.warning(
-                "lead_time coordinate not found for all plotted data; "
-                "using forecast_step as x-axis."
-            )
            return selected_data, time_dim

    # Swap forecast_step with lead_time if all available run_ids have lead_time coord
@@ -209,10 +217,13 @@ def ratio_plot_metric_region(
        run_ids.append(run_id)

    if len(selected_data) > 0:
-        _logger.info(f"Creating Ratio plot for {metric} - {stream}")
+        plotter._logger.info(f"Creating Ratio plot for {metric} - {stream}")
        name = create_filename(
-            prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream]
+            plotter=plotter,
+            prefix=[metric, region],
+            middle=sorted(set(run_ids)),
+            suffix=[stream],
        )
        plotter.ratio_plot(
            selected_data,
@@ -269,12 +280,21 @@ def heat_maps_metric_region(
        run_ids.append(run_id)

    if len(selected_data) > 0:
-        _logger.info(f"Creating Heat maps for {metric} - {stream}")
+        plotter._logger.info(f"Creating Heat maps for {metric} - {stream}")
        name = create_filename(
-            prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream]
+            plotter=plotter,
+            prefix=[metric, region],
+            middle=sorted(set(run_ids)),
+            suffix=[stream],
        )
        selected_data, time_dim = _assign_time_coord(selected_data)
+        if time_dim != "lead_time":
+            plotter._logger.warning(
+                "lead_time coordinate not found for all plotted data; "
+                "using forecast_step as x-axis."
+            )
+
        plotter.heat_map(
            selected_data,
            labels,
@@ -319,7 +339,7 @@ def score_card_metric_region(
        run_ids.append(run_id)

    if selected_data:
-        _logger.info(f"Creating score cards for {metric} - {region} - {stream}.")
+        sc_plotter._logger.info(f"Creating score cards for {metric} - {region} - {stream}.")
        name = "_".join([metric, region, stream])
        sc_plotter.plot(selected_data, run_ids, metric, channels_set, name)
@@ -360,7 +380,7 @@ def bar_plot_metric_region(
        run_ids.append(run_id)

    if selected_data:
-        _logger.info(f"Creating bar plots for {metric} - {region} - {stream}.")
+        br_plotter._logger.info(f"Creating bar plots for {metric} - {region} - {stream}.")
        name = "_".join([metric, region, stream])
        br_plotter.plot(selected_data, run_ids, metric, channels_set, name)
@@ -411,6 +431,7 @@ def list_streams(cls):
def create_filename(
    *,
+    plotter,
    prefix: Sequence[str] = (),
    middle: Iterable[str] = (),
    suffix: Sequence[str] = (),
@@ -423,6 +444,8 @@
    Parameters
    ----------
+    plotter:
+        Plotter object to handle the plotting part
    prefix : Sequence[str]
        Parts that must appear before the truncated section.
middle : Iterable[str] @@ -459,7 +482,7 @@ def create_filename( used += d if len(truncated_middle) < len(mid): - _logger.warning( + plotter._logger.warning( f"Filename truncated: only {len(truncated_middle)} of {len(mid)} middle parts used " f"to keep length <= {max_len}." ) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 2d311dc44..1a58ec16b 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -33,18 +33,15 @@ logging.getLogger("matplotlib.category").setLevel(logging.ERROR) -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - -_logger.debug(f"Taking cartopy paths from {work_dir}") - class Plotter: """ Contains all basic plotting functions. """ - def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None): + def __init__( + self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None, verbose=True + ): """ Initialize the Plotter class. @@ -64,8 +61,8 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | Stream identifier for which the plots will be created. It can also be set later via update_data_selection. """ - - _logger.info(f"Taking cartopy paths from {work_dir}") + self._logger = setup_logger(__name__, verbose) + self._logger.info(f"Taking cartopy paths from {work_dir}") self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") @@ -79,8 +76,10 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | self.out_plot_basedir = Path(output_basedir) / "plots" + self._logger.debug(f"Taking cartopy paths from {work_dir}") + if not os.path.exists(self.out_plot_basedir): - _logger.info(f"Creating dir {self.out_plot_basedir}") + self._logger.info(f"Creating dir {self.out_plot_basedir}") os.makedirs(self.out_plot_basedir, exist_ok=True) self.sample = None @@ -103,17 +102,19 @@ def update_data_selection(self, select: dict): self.select = select if "sample" not in select: - _logger.warning("No sample in the selection. Might lead to unexpected results.") + self._logger.warning("No sample in the selection. Might lead to unexpected results.") else: self.sample = select["sample"] if "stream" not in select: - _logger.warning("No stream in the selection. Might lead to unexpected results.") + self._logger.warning("No stream in the selection. Might lead to unexpected results.") else: self.stream = select["stream"] if "forecast_step" not in select: - _logger.warning("No forecast_step in the selection. Might lead to unexpected results.") + self._logger.warning( + "No forecast_step in the selection. Might lead to unexpected results." + ) else: self.fstep = select["forecast_step"] @@ -191,7 +192,7 @@ def create_histograms_per_sample( hist_output_dir = self.out_plot_basedir / self.stream / "histograms" if not os.path.exists(hist_output_dir): - _logger.info(f"Creating dir {hist_output_dir}") + self._logger.info(f"Creating dir {hist_output_dir}") os.makedirs(hist_output_dir) for var in variables: @@ -210,19 +211,19 @@ def create_histograms_per_sample( if self.plot_subtimesteps: ntimes_unique = len(np.unique(targ.valid_time)) - _logger.info( + self._logger.info( f"Creating histograms for {ntimes_unique} valid times of variable {var}." 
) groups = zip(targ.groupby("valid_time"), prd.groupby("valid_time"), strict=False) else: - _logger.info(f"Plotting histogram for all valid times of {var}") + self._logger.info(f"Plotting histogram for all valid times of {var}") groups = [((None, targ), (None, prd))] # wrap once with dummy valid_time for (valid_time, targ_t), (_, prd_t) in groups: if valid_time is not None: - _logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}") + self._logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}") name = self.plot_histogram(targ_t, prd_t, hist_output_dir, var, tag=tag) plot_names.append(name) @@ -297,7 +298,7 @@ def plot_histogram( name = "_".join(filter(None, parts)) fname = hist_output_dir / f"{name}.{self.image_format}" - _logger.debug(f"Saving histogram to {fname}") + self._logger.debug(f"Saving histogram to {fname}") plt.savefig(fname) plt.close() @@ -352,7 +353,7 @@ def create_maps_per_sample( map_output_dir = self.get_map_output_dir(tag) if not os.path.exists(map_output_dir): - _logger.info(f"Creating dir {map_output_dir}") + self._logger.info(f"Creating dir {map_output_dir}") os.makedirs(map_output_dir) for region in self.regions: @@ -369,22 +370,22 @@ def create_maps_per_sample( if self.plot_subtimesteps: ntimes_unique = len(np.unique(da.valid_time)) - _logger.info( + self._logger.info( f"Creating maps for {ntimes_unique} valid times of variable {var} - {tag}" ) if ntimes_unique == 0: - _logger.warning( + self._logger.warning( f"No valid times found for variable {var} - {tag}. Skipping." ) continue groups = da.groupby("valid_time") else: - _logger.info(f"Creating maps for all valid times of {var} - {tag}") + self._logger.info(f"Creating maps for all valid times of {var} - {tag}") groups = [(None, da)] # single dummy group for valid_time, da_t in groups: if valid_time is not None: - _logger.debug(f"Plotting map for {var} at valid_time {valid_time}") + self._logger.debug(f"Plotting map for {var} at valid_time {valid_time}") da_t = da_t.dropna(dim="ipoint") assert da_t.size > 0, "Data array must not be empty or contain only NAs" @@ -540,7 +541,7 @@ def scatter_plot( name = "_".join(filter(None, parts)) fname = f"{map_output_dir.joinpath(name)}.{self.image_format}" - _logger.debug(f"Saving map to {fname}") + self._logger.debug(f"Saving map to {fname}") plt.savefig(fname) plt.close() @@ -578,7 +579,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: for region in self.regions: for _, sa in enumerate(samples): for _, var in enumerate(variables): - _logger.info(f"Creating animation for {var} sample: {sa} - {tag}") + self._logger.info(f"Creating animation for {var} sample: {sa} - {tag}") image_paths = [] for _, fstep in enumerate(fsteps): # TODO: refactor to avoid code duplication with scatter_plot @@ -612,7 +613,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: ) else: - _logger.warning(f"No images found for animation {var} sample {sa}") + self._logger.warning(f"No images found for animation {var} sample {sa}") return image_paths @@ -621,7 +622,7 @@ def get_map_output_dir(self, tag): class LinePlots: - def __init__(self, plotter_cfg: dict, output_basedir: str | Path): + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose=True): """ Initialize the LinePlots class. @@ -643,6 +644,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path): Base directory under which the plots will be saved. Expected scheme `/`. 
""" + self._logger = setup_logger(__name__, verbose) self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") @@ -653,10 +655,10 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path): self.baseline = plotter_cfg.get("baseline") self.out_plot_dir = Path(output_basedir) / "line_plots" if not os.path.exists(self.out_plot_dir): - _logger.info(f"Creating dir {self.out_plot_dir}") + self._logger.info(f"Creating dir {self.out_plot_dir}") os.makedirs(self.out_plot_dir, exist_ok=True) - _logger.info(f"Saving summary plots to: {self.out_plot_dir}") + self._logger.info(f"Saving summary plots to: {self.out_plot_dir}") def _check_lengths(self, data: xr.DataArray | list, labels: str | list) -> tuple[list, list]: """ @@ -695,12 +697,12 @@ def print_all_points_from_graph(self, fig: plt.Figure) -> None: ydata = line.get_ydata() xdata = line.get_xdata() label = line.get_label() - _logger.info(f"Summary for {label} plot:") + self._logger.info(f"Summary for {label} plot:") for xi, yi in zip(xdata, ydata, strict=False): xi = xi if isinstance(xi, str) else f"{float(xi):.3f}" yi = yi if isinstance(yi, str) else f"{float(yi):.3f}" - _logger.info(f" x: {xi}, y: {yi}") - _logger.info("--------------------------") + self._logger.info(f" x: {xi}, y: {yi}") + self._logger.info("--------------------------") return def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: @@ -770,7 +772,7 @@ def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: alpha=0.2, ) else: - _logger.warning( + self._logger.warning( f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. " "Skipping ensemble plotting." ) @@ -802,7 +804,7 @@ def _preprocess_data( non_x_dims = [dim for dim in data.dims if dim not in x_dims] if any(data.sizes.get(dim, 1) > 1 for dim in non_x_dims) and verbose: - logging.info(f"Averaging over dimensions: {non_x_dims}") + self._logger.info(f"Averaging over dimensions: {non_x_dims}") out = data.mean(dim=non_x_dims, skipna=True) @@ -856,7 +858,7 @@ def plot( non_zero_dims = [dim for dim in data.dims if dim != x_dim and data[dim].shape[0] > 1] if self.plot_ensemble and "ens" in non_zero_dims: - _logger.info(f"LinePlot:: Plotting ensemble with option {self.plot_ensemble}.") + self._logger.info(f"LinePlot:: Plotting ensemble with option {self.plot_ensemble}.") self._plot_ensemble(data, x_dim, label_list[i]) else: averaged = self._preprocess_data(data, x_dim) @@ -920,7 +922,7 @@ def _plot_base( plt.yscale("log") if print_summary: - _logger.info(f"Summary values for {name}") + self._logger.info(f"Summary values for {name}") self.print_all_points_from_graph(fig) if line: @@ -989,7 +991,7 @@ def ratio_plot( baseline_name = self.baseline baseline_idx = run_ids.index(self.baseline) if self.baseline in run_ids else None if baseline_idx is not None: - _logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") + self._logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") baseline = data_list[baseline_idx] else: @@ -1091,7 +1093,7 @@ def heat_map( ref = self._preprocess_data(ref, "channel", verbose=False) if ref.isnull().all(): - _logger.warning( + self._logger.warning( f"Heatmap:: Reference data for metric {metric} and label {label} contains " "only NaNs. Skipping heatmap." ) @@ -1148,14 +1150,17 @@ class ScoreCards: Base directory under which the score cards will be saved. 
""" - def __init__(self, plotter_cfg: dict, output_basedir: str | Path) -> None: + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose=True) -> None: self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") self.improvement = plotter_cfg.get("improvement_scale", 0.2) self.out_plot_dir = Path(output_basedir) / "score_cards" self.baseline = plotter_cfg.get("baseline") + + self._logger = setup_logger(__name__, verbose) + if not os.path.exists(self.out_plot_dir): - _logger.info(f"Creating dir {self.out_plot_dir}") + self._logger.info(f"Creating dir {self.out_plot_dir}") os.makedirs(self.out_plot_dir, exist_ok=True) def plot( @@ -1275,7 +1280,7 @@ def plot( ] plt.legend(handles=legend, loc="upper left", bbox_to_anchor=(1.02, 1.0)) - _logger.info(f"Saving scorecards to: {self.out_plot_dir}") + self._logger.info(f"Saving scorecards to: {self.out_plot_dir}") parts = ["score_card", tag] + runs name = "_".join(filter(None, parts)) @@ -1341,7 +1346,9 @@ def compare_models( baseline_var = baseline.sel({"channel": var}) data_var = data[run_index].sel({"channel": var}) - baseline_score, model_score = calculate_average_over_dim(x_dim, baseline_var, data_var) + baseline_score, model_score = calculate_average_over_dim( + x_dim, baseline_var, data_var, self._logger + ) diff = baseline_score - model_score skill = self.get_skill_score(model_score, baseline_score, metric) @@ -1482,15 +1489,19 @@ class BarPlots: Base directory under which the score cards will be saved. """ - def __init__(self, plotter_cfg: dict, output_basedir: str | Path) -> None: + def __init__( + self, plotter_cfg: dict, output_basedir: str | Path, verbose: bool = False + ) -> None: self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") self.cmap = plotter_cfg.get("cmap", "bwr") self.out_plot_dir = Path(output_basedir) / "bar_plots" self.baseline = plotter_cfg.get("baseline") - _logger.info(f"Saving bar plots to: {self.out_plot_dir}") + + self._logger = setup_logger(__name__, verbose) + self._logger.info(f"Saving bar plots to: {self.out_plot_dir}") if not os.path.exists(self.out_plot_dir): - _logger.info(f"Creating dir {self.out_plot_dir}") + self._logger.info(f"Creating dir {self.out_plot_dir}") os.makedirs(self.out_plot_dir, exist_ok=True) def plot( @@ -1532,7 +1543,7 @@ def plot( runs = [runs[baseline_idx]] + runs[:baseline_idx] + runs[baseline_idx + 1 :] data = [data[baseline_idx]] + data[:baseline_idx] + data[baseline_idx + 1 :] elif len(runs) < 2: - _logger.warning( + self._logger.warning( "BarPlots:: Less than two runs provided. Generating bar plot against ones." 
) ones_array = xr.full_like(data[0], 1.0) @@ -1578,7 +1589,7 @@ def plot( transform=ax[run_index - 1].transAxes, ) - _logger.info(f"Saving bar plots to: {self.out_plot_dir}") + self._logger.info(f"Saving bar plots to: {self.out_plot_dir}") parts = ["bar_plot", tag] + runs name = "_".join(filter(None, parts)) plt.savefig( @@ -1627,7 +1638,9 @@ def calc_ratio_per_run_id( data_var = data[run_index].sel({"channel": var}) channels_per_comparison.append(var) - baseline_score, model_score = calculate_average_over_dim(x_dim, baseline_var, data_var) + baseline_score, model_score = calculate_average_over_dim( + x_dim, baseline_var, data_var, self._logger + ) ratio_score.append(model_score / baseline_score) @@ -1664,7 +1677,7 @@ def colors(self, ratio_score: np.array, metric: str) -> list[tuple]: def calculate_average_over_dim( - x_dim: str, baseline_var: xr.DataArray, data_var: xr.DataArray + x_dim: str, baseline_var: xr.DataArray, data_var: xr.DataArray, logger ) -> tuple[xr.DataArray, xr.DataArray]: """ Calculate average over xarray dimensions that are larger than 1. Those might be the @@ -1678,6 +1691,8 @@ def calculate_average_over_dim( xarray DataArray with the scores of the baseline model for a specific channel/variable data_var: xr.DataArray xarray DataArray with the scores of the comparison model for a specific channel/variable + logger: logging.Logger + Logger instance for logging information Returns ------- @@ -1691,7 +1706,7 @@ def calculate_average_over_dim( ] if non_zero_dims: - _logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") + logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") baseline_score = baseline_var.mean( dim=[dim for dim in baseline_var.dims if dim != x_dim], skipna=True @@ -1742,3 +1757,26 @@ def channel_sort_key(name: str) -> tuple[int, str, int]: return (0, prefix, int(number)) else: return (1, name, float("inf")) + + +def setup_logger(name: str, verbose: bool) -> logging.Logger: + """ + Set up a logger with the specified name and verbosity level. + + Parameters + ---------- + name : str + Name of the logger. + verbose : bool + If True, set logging level to INFO; otherwise, set to WARNING. + + Returns + ------- + logging.Logger + Configured logger instance. + """ + + logger = logging.getLogger(name) + logging_level = logging.INFO if verbose else logging.CRITICAL + 1 + logger.setLevel(logging_level) + return logger diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 98326444f..413815f3e 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -83,11 +83,26 @@ def setup_main_logger(log_file: str | None, log_queue: mp.Queue) -> QueueListene return listener -def setup_worker_logger(log_queue: mp.Queue) -> logging.Logger: - """""" +def setup_worker_logger(log_queue: mp.Queue, verbose: bool) -> logging.Logger: + """ + Set up worker process logger with QueueHandler. + Parameters + ---------- + log_queue: + Multiprocessing queue for logging. + verbose: + Verbosity flag. + Returns + ------- + Configured logger. 
+ """ + if verbose: + logging_level = logging.INFO + else: + logging_level = logging.CRITICAL qh = QueueHandler(log_queue) logger = logging.getLogger() - logger.setLevel(logging.INFO) + logger.setLevel(logging_level) logger.handlers.clear() logger.addHandler(qh) return logger @@ -164,15 +179,16 @@ def get_reader( private_paths: dict[str, str], region: str | None = None, metric: str | None = None, + verbose: bool = True, ): if reader_type == "zarr": - reader = WeatherGenZarrReader(run, run_id, private_paths) + reader = WeatherGenZarrReader(run, run_id, private_paths, verbose) elif reader_type == "csv": - reader = CsvReader(run, run_id, private_paths) + reader = CsvReader(run, run_id, private_paths, verbose) elif reader_type == "json": - reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric) + reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric, verbose) elif reader_type == "merge": - reader = WeatherGenMergeReader(run, run_id, private_paths) + reader = WeatherGenMergeReader(run, run_id, private_paths, verbose) else: raise ValueError(f"Unknown reader type: {reader_type}") return reader @@ -280,6 +296,7 @@ def evaluate_from_config( plot_score_maps = cfg.evaluation.get("plot_score_maps", False) global_plotting_opts = cfg.get("global_plotting_options", {}) use_parallel = cfg.evaluation.get("num_processes", 0) + verbose = cfg.evaluation.get("verbose", True) default_streams = cfg.get("default_streams", {}) if use_parallel == "auto": @@ -311,7 +328,7 @@ def evaluate_from_config( regions = cfg.evaluation.regions metrics = cfg.evaluation.metrics - reader = get_reader(type_, run, run_id, private_paths, regions, metrics) + reader = get_reader(type_, run, run_id, private_paths, regions, metrics, verbose) for stream in reader.streams: tasks.append( @@ -330,13 +347,13 @@ def evaluate_from_config( scores_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) if num_processes == 0: if log_queue is not None: - setup_worker_logger(log_queue) + setup_worker_logger(log_queue, verbose) results = [_process_stream(**task) for task in tasks] else: with mp.Pool( processes=num_processes, initializer=setup_worker_logger, - initargs=(log_queue,), + initargs=(log_queue, verbose), ) as pool: results = pool.map( _process_stream_wrapper, @@ -383,7 +400,7 @@ def evaluate_from_config( # summary plots if scores_dict: _logger.info("Started creating summary plots...") - plot_summary(cfg, scores_dict, summary_dir) + plot_summary(cfg, scores_dict, summary_dir, verbose=verbose) if __name__ == "__main__": diff --git a/packages/evaluate/src/weathergen/evaluate/ssl/README.md b/packages/evaluate/src/weathergen/evaluate/ssl/README.md new file mode 100644 index 000000000..122e861fe --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/ssl/README.md @@ -0,0 +1,167 @@ +# Multi-Stream Self Supervised Learning Inference & Evaluation + +This script runs **inference and evaluation for a multi-stream weather generation model** and produces quantitative metrics and diagnostic plots for each stream. + +The workflow is designed to: + +1. Run inference from an existing trained model run +2. Evaluate forecasts across multiple data streams +3. Print training, validation, and evaluation summaries + +--- + +## Supported Streams + +The script currently evaluates the following streams: + +```python +streams = ["ERA5", "SurfaceCombined", "NPPATMS"] +``` + +Each stream has its own channels, metrics, and plotting configuration. + +--- + +## Overview of the Pipeline + +### 1. 
Inference (`infer_multi_stream`)
+
+* Runs inference using a trained model (`run_id`)
+* Creates a new run ID with `_inf` suffix
+* Generates forecasts over a fixed date range
+* Outputs results in Zarr format
+
+**Inference period**
+
+* Start: `2021-10-10`
+* End: `2022-10-11`
+* Samples: `10`
+
+---
+
+### 2. Evaluation Configuration (`get_evaluation_config`)
+
+Defines:
+
+* Metrics: `rmse`, `froct`
+* Regions: `global`
+* Plotting options (maps, histograms, summary plots)
+* Stream-specific channels and evaluation settings
+
+Each stream is evaluated independently but summarized jointly.
+
+---
+
+### 3. Evaluation (`evaluate_multi_stream_results`)
+
+* Loads inference outputs
+* Computes metrics per stream, region, and forecast step
+* Saves plots and summaries under:
+
+```
+./results/<run_id>_inf/plots/summary/
+```
+
+---
+
+### 4. Reporting Utilities
+
+#### Print Losses (`print_losses`)
+
+Prints training or validation losses for each stream:
+
+* Uses `LossPhysical.<stream>.mse.avg`
+* Supports `train` and `val` stages
+
+#### Print Evaluation Results (`print_evaluation_results`)
+
+Prints mean evaluation scores averaged over:
+
+* samples
+* forecast steps
+* ensemble members
+
+Results are grouped by:
+
+* stream
+* metric
+* region
+
+---
+
+## File Structure Expectations
+
+The script assumes the following directory layout:
+
+```
+.
+├── models/
+│   └── <run_id>/
+│       ├── <run_id>_latest.chkpt
+│       ├── model_<run_id>.json
+├── results/
+│   └── <run_id>_inf/
+│       ├── metrics.json
+│       └── plots/
+
+```
+
+---
+
+## Running the Script
+
+### Command Line
+
+```bash
+uv run --offline ssl_analysis --run-id <run_id>
+```
+
+Optional verbose mode:
+
+```bash
+uv run --offline ssl_analysis --run-id <run_id> --verbose
+```
+
+---
+
+## Execution Flow (Main)
+
+When executed, the script performs the following steps:
+
+1. **Inference**
+
+   * Creates a new inference run ID
+   * Generates forecasts from the trained run
+
+2. **Evaluation**
+
+   * Computes metrics and generates plots
+
+3. **Reporting**
+
+   * Prints training losses from the original run
+   * Prints validation losses from the inference run
+   * Prints evaluation metrics for all streams
+
+---
+
+## Metrics & Outputs
+
+### Metrics (for now)
+
+* RMSE
+* FROCT
+
+### Outputs
+
+* Summary plots
+* Per-stream maps and histograms
+* Console summaries of losses and evaluation scores
+
+---
+
+## Notes & Assumptions
+
+* The original training run must already exist under `./models/` (checkpoint and config) and `./results/` (training metrics)
+* The `_inf` suffix is currently used for inference run naming
+* Forecast steps and samples default to `"all"` during evaluation
diff --git a/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py
new file mode 100644
index 000000000..18d848815
--- /dev/null
+++ b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py
@@ -0,0 +1,271 @@
+# (C) Copyright 2025 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""
+Script to run inference and evaluation for the Weather Generator with
+multiple streams and observations.
+This script must run on a GPU machine.
+It performs a standardised routine of inference and evaluation
+over multiple data sources including gridded and obs data.
+ +Command: +uv run --offline ssl_analysis --run-id j2dkivn8 +""" + +import argparse +import json +import logging +import sys +from pathlib import Path + +import omegaconf + +from weathergen.evaluate.io.wegen_reader import ( + WeatherGenJSONReader, +) +from weathergen.evaluate.run_evaluation import evaluate_from_config +from weathergen.run_train import inference_from_args +from weathergen.utils.metrics import get_train_metrics_path + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +if not logger.handlers: + h = logging.StreamHandler(sys.stdout) + h.setLevel(logging.INFO) + logger.addHandler(h) + +logger.propagate = False + +# TODO: define WEATHERGEN_HOME properly to avoid partial paths + +streams = ["ERA5", "SurfaceCombined", "NPPATMS"] + +############## INFERENCE ################# + + +def infer_multi_stream(run_id): + """Run inference for multi-stream model.""" + logger.info("run multi-stream inference") + new_run_id = run_id + "_inf" # TODO: better naming + inference_from_args( + [ + "-start", + "2021-10-10", + "-end", + "2022-10-11", + "--samples", + "10", + "--options", + # "training_config.forecast.offset=0", + # "training_config.forecast.num_steps=0" + "zarr_store=zip", + ] + + ["--from-run-id", run_id, "--run-id", new_run_id, "--streams-output"] + + streams + # + [ + # "--config", + # "./config/evaluate/latent_space_eval_config.yaml", + # ] + ) + return new_run_id + + +############## EVALUATION ################# + + +def get_evaluation_config(run_id, verbose=False): + """Create evaluation configuration for multiple streams.""" + cfg = omegaconf.OmegaConf.create( + { + "global_plotting_options": { + "image_format": "png", + "dpi_val": 300, + }, + "evaluation": { + "regions": ["global"], + "metrics": ["rmse", "froct"], + "summary_plots": True, + "summary_dir": f"./results/{run_id}/plots/summary/", + "print_summary": False, + "verbose": verbose, + "ratio_plots": True, + "bar_plots": True, + }, + "run_ids": { + run_id: { + "streams": { + "ERA5": { + "channels": ["q_850", "z_500", "2t", "10u", "10v", "msl"], + "evaluation": {"forecast_steps": "all", "sample": "all"}, + "plotting": { + "sample": [0, 1], + "forecast_step": [0], + "plot_maps": True, + "plot_histograms": True, + "plot_animations": False, + }, + }, + "SurfaceCombined": { + "channels": ["obsvalue_t2m_0"], + "evaluation": {"forecast_steps": "all", "sample": "all"}, + "plotting": { + "sample": [0, 1], + "forecast_step": [0], + "plot_maps": True, + "plot_histograms": True, + "plot_animations": False, + }, + }, + "NPPATMS": { + "channels": ["obsvalue_rawbt_1"], + "evaluation": {"forecast_steps": "all", "sample": "all"}, + "plotting": { + "sample": [0, 1], + "forecast_step": [0], + "plot_maps": True, + "plot_histograms": True, + "plot_animations": False, + }, + }, + }, + "label": "Multi-Stream Test", + "mini_epoch": 0, + "rank": 0, + } + }, + } + ) + return cfg + + +def evaluate_multi_stream_results(run_id, verbose=False): + """Run evaluation for multiple streams.""" + + logger.info("run multi-stream evaluation") + cfg = get_evaluation_config(run_id, verbose=verbose) + try: + evaluate_from_config(cfg, None, None) + except FileNotFoundError as e: + logger.error(f"Error during evaluation: {e}") + + +############## PRINT FUNCTIONS ################# + + +def print_losses(run_id, stage="val"): + """Print validation losses for specified streams.""" + logger.info(f"{stage.capitalize()} Losses for run_id: {run_id}") + metrics = load_metrics(run_id) + + losses = {} + + for stream_name in streams: + loss = next( + ( + 
metric.get(f"LossPhysical.{stream_name}.mse.avg") + for metric in reversed(metrics) + if metric.get("stage") == stage + ), + None, + ) + + losses[stream_name] = loss + stage_label = "Train" if stage == "train" else "Validation" + # TODO: understand why logger is not working + logger.info( + f"{stage_label} losses – " + ", ".join(f"{k}: {v:.4f}" for k, v in losses.items()) + "\n" + ) + + +def print_evaluation_results(run_id, verbose=False): + """Print evaluation results for specified streams.""" + + eval_cfg = get_evaluation_config(run_id, verbose=verbose) + scores = load_scores(eval_cfg, run_id) + + metrics = list(eval_cfg.evaluation.get("metrics")) + regions = list(eval_cfg.evaluation.get("regions")) + for stream_name in streams: + stream_scores = scores[stream_name] + + for metric in metrics: + logger.info("------------------------------------------") + for region in regions: + da = stream_scores[metric][region][stream_name][run_id] + logger.info(f"\nEvaluation scores for {region} {stream_name} {metric}:") + + mean_da = da.mean(dim=["sample", "forecast_step", "ens"]) + logger.info(mean_da.to_dataframe(name=f"{metric} {region} {stream_name}")) + + +############## HELPERS ################# + + +def load_metrics(run_id): + """Helper function to load metrics""" + + file_path = get_train_metrics_path(base_path=Path("./results"), run_id=run_id) + + if not file_path.is_file(): + raise FileNotFoundError(f"Metrics file not found for run_id: {run_id}") + with open(file_path) as f: + json_str = f.readlines() + return json.loads("[" + "".join([s.replace("\n", ",") for s in json_str])[:-1] + "]") + + +def load_scores(eval_cfg, run_id): + """Helper function to load metrics""" + run_cfg = eval_cfg.run_ids[run_id] + + metrics = list(eval_cfg.evaluation.get("metrics")) + regions = list(eval_cfg.evaluation.get("regions")) + verbose = eval_cfg.get("verbose", False) + + reader = WeatherGenJSONReader(run_cfg, run_id, None, regions, metrics, verbose=verbose) + scores = {} + + for stream_name in streams: + stream_loaded_scores, _ = reader.load_scores( + stream_name, + regions, + metrics, + ) + + scores[stream_name] = stream_loaded_scores + + return scores + + +############## MAIN ################# + + +def ssl_analysis(): + parser = argparse.ArgumentParser(description="Run multi-stream SSL evaluation") + parser.add_argument( + "--run-id", type=str, required=True, help="Run identifier for the model to evaluate" + ) + parser.add_argument( + "--verbose", action="store_true", help="Enable verbose output", default=False + ) + args = parser.parse_args() + + run_id = args.run_id + verbose = args.verbose + + infer_run_id = infer_multi_stream(run_id) + + # Evaluate results + evaluate_multi_stream_results(infer_run_id, verbose=verbose) + logger.info("\n\nFinal Results Summary: \n") + logger.info("TRAINING & INFERENCE LOSSES: \n") + print_losses(run_id, stage="train") + print_losses(infer_run_id, stage="val") + logger.info("EVALUATION: \n") + print_evaluation_results(infer_run_id, verbose=verbose) diff --git a/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.yaml b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.yaml new file mode 100644 index 000000000..a4ab555bc --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.yaml @@ -0,0 +1,125 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# forecast_freeze_model: False +# forecast_att_dense_rate: 1.0 +# fe_num_blocks: 2 +# fe_num_heads: 16 +# fe_dropout_rate: 0.1 +# fe_with_qk_lnorm: True +# fe_layer_norm_after_blocks: [] +# impute_latent_noise_std: 0.0 + +# healpix_level: 4 + +################ + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # paths in the private config + # data_path_*, + # model_path, + # run_path, + # path_shared_ + # multiprocessing_method + + desc: "" + + run_id: ??? + run_history: [] + +train_log_freq: + terminal: 10 + metrics: 20 + checkpoint: 100 + +# config for training +training_config: + + # training_mode: ["masking"] + + # num_mini_epochs: 1 + # samples_per_mini_epoch: 512 #128 #256 + # shuffle: True + + # start_date: 2012-01-01T00:00 + # end_date: 2020-12-31T00:00 + + # time_window_step: 06:00:00 + # time_window_len: 06:00:00 + + # window_offset_prediction : 1 + + # learning_rate_scheduling : + # lr_start: 1e-6 + # lr_max: 0.00005 + # lr_final_decay: 1e-6 + # lr_final: 0.0 + # num_steps_warmup: 20 + # num_steps_cooldown: 10 + # policy_warmup: "cosine" + # policy_decay: "constant" + # policy_cooldown: "linear" + # parallel_scaling_policy: "sqrt" + + # optimizer: + # grad_clip: 1.0 + # weight_decay: 0.1 + # log_grad_norms: False + # adamw : + # # parameters are scaled by number of DDP workers + # beta1 : 0.975 + # beta2 : 0.9875 + # eps : 2e-08 + + # losses : { + # "physical": { + # type: LossPhysical, + # loss_fcts: { "mse": { }, }, + # }, + # } + + model_input: { + "forecasting" : { + masking_strategy: "forecast", + } + } + + forecast : + time_step: 06:00:00 + num_steps: 2 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +# validation_config: + +# samples_per_mini_epoch: 32 +# shuffle: False + +# start_date: 2021-10-10T00:00 +# end_date: 2022-10-11T00:00 + +# validate_with_ema: +# enabled : True +# ema_ramp_up_ratio: 0.09 +# ema_halflife_in_thousands: 1e-3 + +# write_num_samples: 0 + +# test_config: + +# write_num_samples: 2 diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index d911a370b..10ae1c694 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -156,7 +156,6 @@ def calc_scores_per_stream( _logger.debug( f"Applying bounding box mask for region '{region}' to targets and predictions." ) - tars, preds, tars_next, preds_next = [ bbox.apply_mask(x) if x is not None else None for x in (tars, preds, tars_next, preds_next) @@ -316,7 +315,9 @@ def _plot_score_maps_per_stream( plotter.scatter_plot(data, map_dir, channel, region, tag=tag, title=title) -def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: +def plot_data( + reader: Reader, stream: str, global_plotting_opts: dict, verbose: bool = True +) -> None: """ Plot the data for a given run and stream. @@ -328,6 +329,8 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: Stream name to plot data for. 
global_plotting_opts: dict Dictionary containing all plotting options that apply globally to all run_ids + verbose: bool + Option to print verbose log messages """ run_id = reader.run_id @@ -356,7 +359,7 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: "regions": global_plotting_opts.get("regions", ["global"]), "plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False), } - plotter = Plotter(plotter_cfg, reader.runplot_dir) + plotter = Plotter(plotter_cfg, reader.runplot_dir, verbose=verbose) available_data = reader.check_availability(stream, mode="plotting") @@ -501,7 +504,7 @@ def metric_list_to_json( ) -def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): +def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path, verbose: bool = True) -> None: """ Plot summary of the evaluation results. This function is a placeholder for future implementation. @@ -512,6 +515,10 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): Configuration dictionary containing all information for the evaluation. scores_dict : Dictionary containing scores for each metric and stream. + summary_dir : + Directory where the summary plots will be saved. + verbose: bool + Option to print verbose log messages """ _logger.info("Plotting summary of evaluation results...") runs = cfg.run_ids @@ -531,9 +538,9 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): "baseline": eval_opt.get("baseline", None), } - plotter = LinePlots(plot_cfg, summary_dir) - sc_plotter = ScoreCards(plot_cfg, summary_dir) - br_plotter = BarPlots(plot_cfg, summary_dir) + plotter = LinePlots(plot_cfg, summary_dir, verbose=verbose) + sc_plotter = ScoreCards(plot_cfg, summary_dir, verbose=verbose) + br_plotter = BarPlots(plot_cfg, summary_dir, verbose=verbose) for region in regions: for metric in metrics: if eval_opt.get("summary_plots", True):