From c5c1a05fbc69ef3f369558db3982b3e8fc7a73a3 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 16 Jan 2026 16:56:48 +0100 Subject: [PATCH 01/44] latent_space evaluation scripts + propagate verbose --- config/default_config.yml | 2 +- config/default_forecast_config.yml | 2 +- .../src/weathergen/evaluate/io/csv_reader.py | 8 +- .../src/weathergen/evaluate/io/io_reader.py | 18 +- .../weathergen/evaluate/io/wegen_reader.py | 68 ++--- .../latent_space/latent_space_eval.py | 237 ++++++++++++++++++ .../latent_space/latent_space_eval.yaml | 125 +++++++++ .../evaluate/plotting/plot_utils.py | 43 ++-- .../weathergen/evaluate/plotting/plotter.py | 119 +++++---- .../src/weathergen/evaluate/run_evaluation.py | 27 +- .../src/weathergen/evaluate/utils/utils.py | 20 +- 11 files changed, 548 insertions(+), 121 deletions(-) create mode 100644 packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py create mode 100644 packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.yaml diff --git a/config/default_config.yml b/config/default_config.yml index 0de5c8846..ea5f4a98d 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -87,7 +87,7 @@ norm_type: "LayerNorm" ##################################### -streams_directory: "./config/streams/era5_1deg/" +streams_directory: "./config/streams/cerra_era5/" streams: ??? 
# type of zarr_store diff --git a/config/default_forecast_config.yml b/config/default_forecast_config.yml index 0080ed252..e74b9df23 100644 --- a/config/default_forecast_config.yml +++ b/config/default_forecast_config.yml @@ -1,4 +1,4 @@ -streams_directory: "./config/streams/era5_1deg/" +streams_directory: "./config/streams/cerra_era5/" embed_orientation: "channels" embed_unembed_mode: "block" diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 1916c3671..730e79655 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -20,8 +20,8 @@ # Local application / package from weathergen.evaluate.io.io_reader import Reader -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) +# _logger = logging.getLogger(__name__) +# _logger.setLevel(logging.INFO) class CsvReader(Reader): @@ -29,7 +29,7 @@ class CsvReader(Reader): Reader class to read evaluation data from CSV files and convert to xarray DataArray. """ - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose = True): """ Initialize the CsvReader. 
@@ -43,7 +43,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non list of private paths for the supported HPC """ - super().__init__(eval_cfg, run_id, private_paths) + super().__init__(eval_cfg, run_id, private_paths, verbose) self.metrics_dir = Path(self.eval_cfg.get("metrics_dir")) self.metrics_base_dir = self.metrics_dir diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index 92e6748fd..527d2e4d7 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -15,9 +15,6 @@ # Third-party import xarray as xr -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - @dataclass class ReaderOutput: @@ -62,7 +59,7 @@ class DataAvailability: class Reader: - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None): + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None, verbose = True): """ Generic data reader class. @@ -89,6 +86,11 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | "results_base_dir", None ) # base directory where results will be stored + self._logger = logging.getLogger(__name__) + + logger_level = logging.INFO if verbose else logging.WARNING + self._logger.setLevel(logger_level) + def get_stream(self, stream: str): """ returns the dictionary associated to a particular stream @@ -217,7 +219,7 @@ def check_availability( requested[name] = reader_data[name] # If file with metrics exists, must exactly match if available_data is not None and reader_data[name] != available[name]: - _logger.info( + self._logger.info( f"Requested all {name}s for {mode}, but previous config was a " "strict subset. Recomputing." 
) @@ -230,7 +232,7 @@ def check_availability( if name == "ensemble" and "mean" in missing: missing.remove("mean") if missing: - _logger.info( + self._logger.info( f"Requested {name}(s) {missing} is unavailable. " f"Removing missing {name}(s) for {mode}." ) @@ -240,14 +242,14 @@ def check_availability( # Must be a subset of available_data (if provided) if available_data is not None and not requested[name] <= available[name]: missing = requested[name] - available[name] - _logger.info( + self._logger.info( f"{name.capitalize()}(s) {missing} missing in previous evaluation. Recomputing." ) check_score = False if check_score and not corrected: scope = "metric file" if available_data is not None else "Zarr file" - _logger.info( + self._logger.info( f"All checks passed – All channels, samples, fsteps requested for {mode} are " f"present in {scope}..." ) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 30cecb80e..01cef980e 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -30,13 +30,13 @@ from weathergen.evaluate.scores.score_utils import to_list from weathergen.evaluate.utils.derived_channels import DeriveChannels -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) +# _logger = logging.getLogger(__name__) +# _logger.setLevel(logging.INFO) class WeatherGenReader(Reader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): - super().__init__(eval_cfg, run_id, private_paths) + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose = True): + super().__init__(eval_cfg, run_id, private_paths, verbose) # TODO: remove backwards compatibility to "epoch" in Feb. 
2026 self.mini_epoch = eval_cfg.get("mini_epoch", eval_cfg.get("epoch")) @@ -46,9 +46,9 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if not self.results_base_dir: self.results_base_dir = Path(get_shared_wg_path("results")) - _logger.info(f"Results directory obtained from private config: {self.results_base_dir}") + self._logger.info(f"Results directory obtained from private config: {self.results_base_dir}") else: - _logger.info(f"Results directory parsed: {self.results_base_dir}") + self._logger.info(f"Results directory parsed: {self.results_base_dir}") self.runplot_base_dir = Path( self.eval_cfg.get("runplot_base_dir", self.results_base_dir) @@ -80,18 +80,18 @@ def get_inference_config(self): configuration file from the inference run """ if self.private_paths: - _logger.info( + self._logger.info( f"Loading config for run {self.run_id} from private paths: {self.private_paths}" ) config = load_merge_configs(self.private_paths, self.run_id, self.mini_epoch) else: - _logger.info( + self._logger.info( f"Loading config for run {self.run_id} from model directory: {self.model_base_dir}" ) config = load_run_config(self.run_id, self.mini_epoch, self.model_base_dir) if type(config) not in [dict, oc.DictConfig]: - _logger.warning("Model config not found. inference config will be empty.") + self._logger.warning("Model config not found. inference config will be empty.") config = {} return config @@ -125,7 +125,7 @@ def get_climatology_filename(self, stream: str) -> str | None: if clim_base_dir and clim_fn: clim_data_path = Path(clim_base_dir).join(clim_fn) else: - _logger.warning( + self._logger.warning( f"No climatology path specified for stream {stream}. Setting climatology to " "NaN. Add 'climatology_path' to evaluation config to use metrics like ACC." ) @@ -145,9 +145,9 @@ def get_channels(self, stream: str) -> list[str]: ------- A list of channel names. 
""" - _logger.debug(f"Getting channels for stream {stream}...") + self._logger.debug(f"Getting channels for stream {stream}...") all_channels = self.get_inference_stream_attr(stream, "val_target_channels") - _logger.debug(f"Channels found in config: {all_channels}") + self._logger.debug(f"Channels found in config: {all_channels}") return all_channels def load_scores( @@ -207,7 +207,7 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr Path(self.metrics_dir) / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" ) - _logger.debug(f"Looking for: {score_path}") + self._logger.debug(f"Looking for: {score_path}") if score_path.exists(): with open(score_path) as f: data_dict = json.load(f) @@ -280,7 +280,7 @@ def __init__( message = [f"Some {name}(s) were not common among streams, regions and metrics:"] for val in skipped: message.append(f" {val} only in {provenance[name][val]}") - _logger.warning("\n".join(message)) + self.self._logger.warning("\n".join(message)) def get_samples(self) -> set[int]: return self.common_coords["sample"] @@ -320,7 +320,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.fname_zarr = fname_zarr_old if not self.fname_zarr.exists(): - _logger.error(f"Zarr file {self.fname_zarr} does not exist.") + self._logger.error(f"Zarr file {self.fname_zarr} does not exist.") raise FileNotFoundError(f"Zarr file {self.fname_zarr} does not exist") def get_data( @@ -367,7 +367,7 @@ def get_data( with zarrio_reader(self.fname_zarr) as zio: stream_cfg = self.get_stream(stream) all_channels = self.get_channels(stream) - _logger.info(f"RUN {self.run_id}: Processing stream {stream}...") + self._logger.info(f"RUN {self.run_id}: Processing stream {stream}...") fsteps = self.get_forecast_steps() if fsteps is None else fsteps @@ -401,14 +401,14 @@ def get_data( fsteps_final = [] for fstep in fsteps: - _logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...") + 
self._logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...") da_tars_fs, da_preds_fs, pps = [], [], [] for sample in tqdm(samples, desc=f"Processing {self.run_id} - {stream} - {fstep}"): out = zio.get_data(sample, stream, fstep) if out.target is None or out.prediction is None: - _logger.info( + self._logger.info( f"Skipping {stream} sample {sample} forecast step: {fstep}. " "No data found." ) @@ -420,31 +420,31 @@ def get_data( pps.append(npoints) if npoints == 0: - _logger.info( + self._logger.info( f"Skipping {stream} sample {sample} forecast step: {fstep}. " "Dataset is empty." ) continue if ensemble == ["mean"]: - _logger.debug("Averaging over ensemble members.") + self._logger.debug("Averaging over ensemble members.") pred = pred.mean("ens", keepdims=True) else: - _logger.debug(f"Selecting ensemble members {ensemble}.") + self._logger.debug(f"Selecting ensemble members {ensemble}.") pred = pred.sel(ens=ensemble) da_tars_fs.append(target.squeeze()) da_preds_fs.append(pred.squeeze()) if not da_tars_fs: - _logger.info( + self._logger.info( f"[{self.run_id} - {stream}] No valid data found for fstep {fstep}." ) continue fsteps_final.append(fstep) - _logger.debug( + self._logger.debug( f"Concatenating targets and predictions for stream {stream}, " f"forecast_step {fstep}..." ) @@ -469,7 +469,7 @@ def get_data( da_preds_fs = self.scale_z_channels(da_preds_fs, stream) if len(samples) == 1: - _logger.debug("Repeating sample coordinate for single-sample case.") + self._logger.debug("Repeating sample coordinate for single-sample case.") for da in (da_tars_fs, da_preds_fs): da.assign_coords( sample=( @@ -479,7 +479,7 @@ def get_data( ) if set(channels) != set(all_channels): - _logger.debug( + self._logger.debug( f"Restricting targets and predictions to channels {channels} " f"for stream {stream}..." ) @@ -585,7 +585,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: ------- A list of ensemble members. 
""" - _logger.debug(f"Getting ensembles for stream {stream}...") + self._logger.debug(f"Getting ensembles for stream {stream}...") # TODO: improve this to get ensemble from io class with zarrio_reader(self.fname_zarr) as zio: @@ -604,7 +604,7 @@ def is_regular(self, stream: str) -> bool: ------- True if the stream is regularly spaced. False otherwise. """ - _logger.debug(f"Checking regular spacing for stream {stream}...") + self._logger.debug(f"Checking regular spacing for stream {stream}...") with zarrio_reader(self.fname_zarr) as zio: dummy = zio.get_data(0, stream, zio.forecast_steps[0]) @@ -626,10 +626,10 @@ def is_regular(self, stream: str) -> bool: and np.allclose(sorted(da["lon"].values), sorted(da1["lon"].values)) ) ): - _logger.debug("Latitude and/or longitude coordinates are not regularly spaced.") + self._logger.debug("Latitude and/or longitude coordinates are not regularly spaced.") return False - _logger.debug("Latitude and longitude coordinates are regularly spaced.") + self._logger.debug("Latitude and longitude coordinates are regularly spaced.") return True @@ -680,7 +680,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non super().__init__(eval_cfg, run_id, private_paths) self.readers = [] - _logger.info(f"MERGE READERS: {self.run_ids} ...") + self._logger.info(f"MERGE READERS: {self.run_ids} ...") for run_id in self.run_ids: reader = WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths) @@ -732,7 +732,7 @@ def get_data( for reader in self.readers: da_tars, da_preds, da_fsteps = [], [], [] - _logger.info(f"MERGE READERS: Processing run_id {reader.run_id}...") + self._logger.info(f"MERGE READERS: Processing run_id {reader.run_id}...") out = reader.get_data( stream, @@ -743,7 +743,7 @@ def get_data( ) for fstep in out.target.keys(): - _logger.debug(f"MERGE READERS: Processing fstep {fstep}...") + self._logger.debug(f"MERGE READERS: Processing fstep {fstep}...") da_tars.append(out.target[fstep]) 
da_preds.append(out.prediction[fstep]) @@ -903,7 +903,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: ------- A range of ensemble members equal to the number of merged readers. """ - _logger.debug(f"Getting ensembles for stream {stream}...") + self._logger.debug(f"Getting ensembles for stream {stream}...") all_ensembles = [] for reader in self.readers: all_ensembles.append(reader.get_ensemble(stream)) @@ -928,5 +928,5 @@ def is_regular(self, stream: str) -> bool: ------- True if the stream is regularly spaced. False otherwise. """ - _logger.debug(f"Checking regular spacing for stream {stream}...") + self._logger.debug(f"Checking regular spacing for stream {stream}...") return all(reader.is_regular(stream) for reader in self.readers) diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py new file mode 100644 index 000000000..2bc2c5edc --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -0,0 +1,237 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Integration test for the Weather Generator with multiple streams and observations. +This test must run on a GPU machine. +It performs training and inference with multiple data sources including gridded and obs data. 
+ +Command: +uv run pytest ./integration_tests/small_multi_stream_test.py +""" +import argparse +import json +import logging +from pathlib import Path + +import omegaconf + +from weathergen.evaluate.run_evaluation import evaluate_from_config +from weathergen.run_train import inference_from_args, train_with_args +from weathergen.utils.metrics import get_train_metrics_path +from collections import defaultdict +from weathergen.evaluate.io.wegen_reader import ( + WeatherGenJSONReader, +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +#TODO: define WEATHERGEN_HOME properly to avoid partial paths + + +streams = ["ERA5", "SurfaceCombined", "NPPATMS"] + +def get_evaluation_config(run_id, verbose = False): + """Create evaluation configuration for multiple streams.""" + cfg = omegaconf.OmegaConf.create( + + { + "global_plotting_options": { + "image_format": "png", + "dpi_val": 300, + }, + "evaluation": { + "regions": ["global"], + "metrics": ["rmse", "froct"], + "summary_plots": True, + "summary_dir": f"./results/{run_id}/plots/summary/", + "print_summary": False, + "verbose": verbose, + }, + "run_ids": { + run_id: { + "streams": { + "ERA5": { + "channels": ["q_850", "z_500", "2t", "10u", "10v", "msl"], + "evaluation": {"forecast_steps": "all", "sample": "all"}, + "plotting": { + "sample": [0, 1], + "forecast_step": [0], + "plot_maps": True, + "plot_histograms": True, + "plot_animations": False, + }, + }, + "SurfaceCombined": { + "channels": ["obsvalue_t2m_0"], + "evaluation": {"forecast_steps": "all", "sample": "all"}, + "plotting": { + "sample": [0, 1], + "forecast_step": [0], + "plot_maps": True, + "plot_histograms": True, + "plot_animations": False, + }, + }, + "NPPATMS": { + "channels": ["obsvalue_rawbt_1"], + "evaluation": {"forecast_steps": "all", "sample": "all"}, + "plotting": { + "sample": [0, 1], + "forecast_step": [0], + "plot_maps": True, + "plot_histograms": True, + "plot_animations": False, + }, + }, + }, + "label": "Multi-Stream Test", 
+ "mini_epoch": 0, + "rank": 0, + } + }, } + ) + return cfg + +def infer_multi_stream(run_id): + """Run inference for multi-stream model.""" + logger.info("run multi-stream inference") + new_run_id = run_id + "_inf" #TODO: better naming + inference_from_args( + ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--options", "forecast_offset = 0"] + + [ + "--from_run_id", + run_id, + "--run_id", + new_run_id, + "--streams_output", + ",".join(streams), + "--config", + str("./config/evaluate/latent_space_eval_config.yaml") + ] + ) + return new_run_id + +def evaluate_multi_stream_results(run_id, verbose=False): + """Run evaluation for multiple streams.""" + + logger.info("run multi-stream evaluation") + cfg = get_evaluation_config(run_id, verbose=verbose) + try: + evaluate_from_config(cfg, None, None) + except FileNotFoundError as e: + logger.error(f"Error during evaluation: {e}") + + +def load_metrics(run_id): + """Helper function to load metrics""" + + file_path = get_train_metrics_path(base_path=Path("./results"), run_id=run_id) + + if not file_path.is_file(): + raise FileNotFoundError(f"Metrics file not found for run_id: {run_id}") + with open(file_path) as f: + json_str = f.readlines() + return json.loads("[" + "".join([s.replace("\n", ",") for s in json_str])[:-1] + "]") + + +def load_scores(eval_cfg, run_id): + """Helper function to load metrics""" + + + run_cfg = eval_cfg.run_ids[run_id] + + metrics = list(eval_cfg.evaluation.get("metrics")) + regions = list(eval_cfg.evaluation.get("regions")) + + reader = WeatherGenJSONReader(run_cfg, run_id, None, regions, metrics) + + scores = {} + + for stream_name in streams: + stream_loaded_scores, _ = reader.load_scores( + stream_name , + regions, + metrics, + ) + + scores[stream_name] = stream_loaded_scores + + return scores + + +def print_losses(run_id, stage="val"): + """Print validation losses for specified streams.""" + print(f"\n{stage.capitalize()} Losses for run_id: {run_id}") + metrics = 
load_metrics(run_id) + + losses = {} + + for stream_name in streams: + + loss = next( + ( + metric.get(f"LossPhysical.{stream_name}.mse.avg") + for metric in reversed(metrics) + if metric.get("stage") == stage + ), + None, + ) + + losses[stream_name] = loss + stage_label = "\nTrain" if stage == "train" else "Validation" + #TODO: understand why logger is not working + print(f"{stage_label} losses – " + ", ".join(f"{k}: {v:.4f}" for k, v in losses.items())) + +def print_evaluation_results(run_id, verbose=False): + """Print evaluation results for specified streams.""" + + eval_cfg = get_evaluation_config(run_id, verbose=verbose) + + try: + scores = load_scores(eval_cfg, run_id) + except FileNotFoundError as e: + print(f"Error loading scores: {e}") + return + + metrics = list(eval_cfg.evaluation.get("metrics")) + regions = list(eval_cfg.evaluation.get("regions")) + for stream_name in streams: + stream_scores = scores[stream_name] + + for metric in metrics: + print("------------------------------------------") + for region in regions: + + da = stream_scores[metric][region][stream_name][run_id] + print(f"\nEvaluation scores for {region} {stream_name} {metric}:") + + mean_da = da.mean(dim=["sample", "forecast_step", "ens"]) + print(mean_da.to_dataframe(name=f"{metric} {region} {stream_name}")) + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run multi-stream latent space evaluation") + parser.add_argument("--run_id", type=str, required=True, help="Run identifier for the model to evaluate") + parser.add_argument("--verbose", action="store_true", help="Enable verbose output", default=False) + args = parser.parse_args() + + run_id = args.run_id + verbose = args.verbose + + infer_run_id = run_id + "_inf" #infer_multi_stream(run_id) + + # Evaluate results + evaluate_multi_stream_results(infer_run_id, verbose=verbose) + print("\n\nFinal Results Summary: \n") + print_losses(run_id,stage="train") + print_losses(infer_run_id,stage="val") + + 
print_evaluation_results(infer_run_id, verbose=verbose) \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.yaml b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.yaml new file mode 100644 index 000000000..a4ab555bc --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.yaml @@ -0,0 +1,125 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# forecast_freeze_model: False +# forecast_att_dense_rate: 1.0 +# fe_num_blocks: 2 +# fe_num_heads: 16 +# fe_dropout_rate: 0.1 +# fe_with_qk_lnorm: True +# fe_layer_norm_after_blocks: [] +# impute_latent_noise_std: 0.0 + +# healpix_level: 4 + +################ + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # paths in the private config + # data_path_*, + # model_path, + # run_path, + # path_shared_ + # multiprocessing_method + + desc: "" + + run_id: ??? 
+ run_history: [] + +train_log_freq: + terminal: 10 + metrics: 20 + checkpoint: 100 + +# config for training +training_config: + + # training_mode: ["masking"] + + # num_mini_epochs: 1 + # samples_per_mini_epoch: 512 #128 #256 + # shuffle: True + + # start_date: 2012-01-01T00:00 + # end_date: 2020-12-31T00:00 + + # time_window_step: 06:00:00 + # time_window_len: 06:00:00 + + # window_offset_prediction : 1 + + # learning_rate_scheduling : + # lr_start: 1e-6 + # lr_max: 0.00005 + # lr_final_decay: 1e-6 + # lr_final: 0.0 + # num_steps_warmup: 20 + # num_steps_cooldown: 10 + # policy_warmup: "cosine" + # policy_decay: "constant" + # policy_cooldown: "linear" + # parallel_scaling_policy: "sqrt" + + # optimizer: + # grad_clip: 1.0 + # weight_decay: 0.1 + # log_grad_norms: False + # adamw : + # # parameters are scaled by number of DDP workers + # beta1 : 0.975 + # beta2 : 0.9875 + # eps : 2e-08 + + # losses : { + # "physical": { + # type: LossPhysical, + # loss_fcts: { "mse": { }, }, + # }, + # } + + model_input: { + "forecasting" : { + masking_strategy: "forecast", + } + } + + forecast : + time_step: 06:00:00 + num_steps: 2 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +# validation_config: + +# samples_per_mini_epoch: 32 +# shuffle: False + +# start_date: 2021-10-10T00:00 +# end_date: 2022-10-11T00:00 + +# validate_with_ema: +# enabled : True +# ema_ramp_up_ratio: 0.09 +# ema_halflife_in_thousands: 1e-3 + +# write_num_samples: 0 + +# test_config: + +# write_num_samples: 2 diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 27e5d1710..02df4b38a 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -13,7 +13,7 @@ import numpy as np import xarray as xr -_logger = logging.getLogger(__name__) +# _logger = 
logging.getLogger(__name__) def collect_streams(runs: dict): @@ -105,13 +105,19 @@ def plot_metric_region( run_ids.append(run_id) if selected_data: - _logger.info(f"Creating plot for {metric} - {region} - {stream} - {ch}.") + plotter._logger.info(f"Creating plot for {metric} - {region} - {stream} - {ch}.") - name = create_filename( + name = create_filename(plotter, prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream, ch] ) selected_data, time_dim = _assign_time_coord(selected_data) + + if time_dim != "lead_time": + plotter._logger.warning( + "lead_time coordinate not found for all plotted data; " + "using forecast_step as x-axis." + ) plotter.plot( selected_data, @@ -149,10 +155,6 @@ def _assign_time_coord(selected_data: list[xr.DataArray]) -> tuple[xr.DataArray, ) if "lead_time" not in data.coords and "lead_time" not in data.dims: - _logger.warning( - "lead_time coordinate not found for all plotted data; " - "using forecast_step as x-axis." - ) return selected_data, time_dim # Swap forecast_step with lead_time if all available run_ids have lead_time coord @@ -209,9 +211,10 @@ def ratio_plot_metric_region( run_ids.append(run_id) if len(selected_data) > 0: - _logger.info(f"Creating Ratio plot for {metric} - {stream}") + plotter._logger.info(f"Creating Ratio plot for {metric} - {stream}") name = create_filename( + plotter, prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream] ) plotter.ratio_plot( @@ -269,12 +272,19 @@ def heat_maps_metric_region( run_ids.append(run_id) if len(selected_data) > 0: - _logger.info(f"Creating Heat maps for {metric} - {stream}") + plotter._logger.info(f"Creating Heat maps for {metric} - {stream}") name = create_filename( + plotter, prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream] ) selected_data, time_dim = _assign_time_coord(selected_data) - + + if time_dim != "lead_time": + plotter._logger.warning( + "lead_time coordinate not found for all plotted data; " + "using forecast_step as 
x-axis." + ) + plotter.heat_map( selected_data, labels, @@ -319,11 +329,11 @@ def score_card_metric_region( run_ids.append(run_id) if selected_data and len(selected_data) > 1.0: - _logger.info(f"Creating score cards for {metric} - {region} - {stream}.") + sc_plotter._logger.info(f"Creating score cards for {metric} - {region} - {stream}.") name = "_".join([metric, region, stream]) sc_plotter.plot(selected_data, run_ids, metric, channels_set, name) else: - _logger.info( + sc_plotter._logger.info( f"Only one run_id for ({region}) region under stream : {stream}. " "Creating bar plot is skipped..." ) @@ -365,11 +375,11 @@ def bar_plot_metric_region( run_ids.append(run_id) if selected_data and len(selected_data) > 1.0: - _logger.info(f"Creating bar plots for {metric} - {region} - {stream}.") + br_plotter._logger.info(f"Creating bar plots for {metric} - {region} - {stream}.") name = "_".join([metric, region, stream]) br_plotter.plot(selected_data, run_ids, metric, channels_set, name) else: - _logger.info( + br_plotter._logger.info( f"Only one run_id for ({region}) region under stream : {stream}. " "Creating bar plot is skipped..." ) @@ -421,6 +431,7 @@ def list_streams(cls): def create_filename( *, + plotter, prefix: Sequence[str] = (), middle: Iterable[str] = (), suffix: Sequence[str] = (), @@ -433,6 +444,8 @@ def create_filename( Parameters ---------- + plotter: + Plotter object to handle the plotting part prefix : Sequence[str] Parts that must appear before the truncated section. middle : Iterable[str] @@ -469,7 +482,7 @@ def create_filename( used += d if len(truncated_middle) < len(mid): - _logger.warning( + plotter._logger.warning( f"Filename truncated: only {len(truncated_middle)} of {len(mid)} middle parts used " f"to keep length <= {max_len}." 
) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index f24b0e755..e14fa34b3 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -4,6 +4,7 @@ import os import re from pathlib import Path +from tabnanny import verbose import cartopy import cartopy.crs as ccrs @@ -33,18 +34,13 @@ logging.getLogger("matplotlib.category").setLevel(logging.ERROR) -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - -_logger.debug(f"Taking cartopy paths from {work_dir}") - class Plotter: """ Contains all basic plotting functions. """ - def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None): + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None, verbose = True): """ Initialize the Plotter class. @@ -64,8 +60,8 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | Stream identifier for which the plots will be created. It can also be set later via update_data_selection. 
""" - - _logger.info(f"Taking cartopy paths from {work_dir}") + self._logger = setup_logger(__name__, verbose) + self._logger.info(f"Taking cartopy paths from {work_dir}") self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") @@ -79,8 +75,10 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | self.out_plot_basedir = Path(output_basedir) / "plots" + self._logger.debug(f"Taking cartopy paths from {work_dir}") + if not os.path.exists(self.out_plot_basedir): - _logger.info(f"Creating dir {self.out_plot_basedir}") + self._logger.info(f"Creating dir {self.out_plot_basedir}") os.makedirs(self.out_plot_basedir, exist_ok=True) self.sample = None @@ -103,17 +101,17 @@ def update_data_selection(self, select: dict): self.select = select if "sample" not in select: - _logger.warning("No sample in the selection. Might lead to unexpected results.") + self._logger.warning("No sample in the selection. Might lead to unexpected results.") else: self.sample = select["sample"] if "stream" not in select: - _logger.warning("No stream in the selection. Might lead to unexpected results.") + self._logger.warning("No stream in the selection. Might lead to unexpected results.") else: self.stream = select["stream"] if "forecast_step" not in select: - _logger.warning("No forecast_step in the selection. Might lead to unexpected results.") + self._logger.warning("No forecast_step in the selection. 
Might lead to unexpected results.") else: self.fstep = select["forecast_step"] @@ -191,7 +189,7 @@ def create_histograms_per_sample( hist_output_dir = self.out_plot_basedir / self.stream / "histograms" if not os.path.exists(hist_output_dir): - _logger.info(f"Creating dir {hist_output_dir}") + self._logger.info(f"Creating dir {hist_output_dir}") os.makedirs(hist_output_dir) for var in variables: @@ -210,19 +208,19 @@ def create_histograms_per_sample( if self.plot_subtimesteps: ntimes_unique = len(np.unique(targ.valid_time)) - _logger.info( + self._logger.info( f"Creating histograms for {ntimes_unique} valid times of variable {var}." ) groups = zip(targ.groupby("valid_time"), prd.groupby("valid_time"), strict=False) else: - _logger.info(f"Plotting histogram for all valid times of {var}") + self._logger.info(f"Plotting histogram for all valid times of {var}") groups = [((None, targ), (None, prd))] # wrap once with dummy valid_time for (valid_time, targ_t), (_, prd_t) in groups: if valid_time is not None: - _logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}") + self._logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}") name = self.plot_histogram(targ_t, prd_t, hist_output_dir, var, tag=tag) plot_names.append(name) @@ -297,7 +295,7 @@ def plot_histogram( name = "_".join(filter(None, parts)) fname = hist_output_dir / f"{name}.{self.image_format}" - _logger.debug(f"Saving histogram to {fname}") + self._logger.debug(f"Saving histogram to {fname}") plt.savefig(fname) plt.close() @@ -352,7 +350,7 @@ def create_maps_per_sample( map_output_dir = self.get_map_output_dir(tag) if not os.path.exists(map_output_dir): - _logger.info(f"Creating dir {map_output_dir}") + self._logger.info(f"Creating dir {map_output_dir}") os.makedirs(map_output_dir) for region in self.regions: @@ -369,22 +367,22 @@ def create_maps_per_sample( if self.plot_subtimesteps: ntimes_unique = len(np.unique(da.valid_time)) - _logger.info( + self._logger.info( f"Creating 
maps for {ntimes_unique} valid times of variable {var} - {tag}" ) if ntimes_unique == 0: - _logger.warning( + self._logger.warning( f"No valid times found for variable {var} - {tag}. Skipping." ) continue groups = da.groupby("valid_time") else: - _logger.info(f"Creating maps for all valid times of {var} - {tag}") + self._logger.info(f"Creating maps for all valid times of {var} - {tag}") groups = [(None, da)] # single dummy group for valid_time, da_t in groups: if valid_time is not None: - _logger.debug(f"Plotting map for {var} at valid_time {valid_time}") + self._logger.debug(f"Plotting map for {var} at valid_time {valid_time}") da_t = da_t.dropna(dim="ipoint") assert da_t.size > 0, "Data array must not be empty or contain only NAs" @@ -540,7 +538,7 @@ def scatter_plot( name = "_".join(filter(None, parts)) fname = f"{map_output_dir.joinpath(name)}.{self.image_format}" - _logger.debug(f"Saving map to {fname}") + self._logger.debug(f"Saving map to {fname}") plt.savefig(fname) plt.close() @@ -578,7 +576,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: for region in self.regions: for _, sa in enumerate(samples): for _, var in enumerate(variables): - _logger.info(f"Creating animation for {var} sample: {sa} - {tag}") + self._logger.info(f"Creating animation for {var} sample: {sa} - {tag}") image_paths = [] for _, fstep in enumerate(fsteps): # TODO: refactor to avoid code duplication with scatter_plot @@ -612,7 +610,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: ) else: - _logger.warning(f"No images found for animation {var} sample {sa}") + self._logger.warning(f"No images found for animation {var} sample {sa}") return image_paths @@ -621,7 +619,7 @@ def get_map_output_dir(self, tag): class LinePlots: - def __init__(self, plotter_cfg: dict, output_basedir: str | Path): + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose = True): """ Initialize the LinePlots class. 
@@ -643,6 +641,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path): Base directory under which the plots will be saved. Expected scheme `/`. """ + self._logger = setup_logger(__name__, verbose) self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") @@ -653,10 +652,10 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path): self.baseline = plotter_cfg.get("baseline") self.out_plot_dir = Path(output_basedir) / "line_plots" if not os.path.exists(self.out_plot_dir): - _logger.info(f"Creating dir {self.out_plot_dir}") + self._logger.info(f"Creating dir {self.out_plot_dir}") os.makedirs(self.out_plot_dir, exist_ok=True) - _logger.info(f"Saving summary plots to: {self.out_plot_dir}") + self._logger.info(f"Saving summary plots to: {self.out_plot_dir}") def _check_lengths(self, data: xr.DataArray | list, labels: str | list) -> tuple[list, list]: """ @@ -695,10 +694,10 @@ def print_all_points_from_graph(self, fig: plt.Figure) -> None: ydata = line.get_ydata() xdata = line.get_xdata() label = line.get_label() - _logger.info(f"Summary for {label} plot:") + self._logger.info(f"Summary for {label} plot:") for xi, yi in zip(xdata, ydata, strict=False): - _logger.info(f" x: {xi:.3f}, y: {yi:.3f}") - _logger.info("--------------------------") + self._logger.info(f" x: {xi:.3f}, y: {yi:.3f}") + self._logger.info("--------------------------") return def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: @@ -768,7 +767,7 @@ def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: alpha=0.2, ) else: - _logger.warning( + self._logger.warning( f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. " "Skipping ensemble plotting." 
) @@ -800,7 +799,7 @@ def _preprocess_data( non_x_dims = [dim for dim in data.dims if dim not in x_dims] if any(data.sizes.get(dim, 1) > 1 for dim in non_x_dims) and verbose: - logging.info(f"Averaging over dimensions: {non_x_dims}") + self._logger.info(f"Averaging over dimensions: {non_x_dims}") out = data.mean(dim=non_x_dims, skipna=True) @@ -854,7 +853,7 @@ def plot( non_zero_dims = [dim for dim in data.dims if dim != x_dim and data[dim].shape[0] > 1] if self.plot_ensemble and "ens" in non_zero_dims: - _logger.info(f"LinePlot:: Plotting ensemble with option {self.plot_ensemble}.") + self._logger.info(f"LinePlot:: Plotting ensemble with option {self.plot_ensemble}.") self._plot_ensemble(data, x_dim, label_list[i]) else: averaged = self._preprocess_data(data, x_dim) @@ -918,7 +917,7 @@ def _plot_base( plt.yscale("log") if print_summary: - _logger.info(f"Summary values for {name}") + self._logger.info(f"Summary values for {name}") self.print_all_points_from_graph(fig) if line: @@ -979,13 +978,13 @@ def ratio_plot( data_list, label_list = self._check_lengths(data, labels) if len(data_list) < 2: - _logger.warning("Ratio plot requires at least two datasets to compare. Skipping.") + self._logger.warning("Ratio plot requires at least two datasets to compare. Skipping.") return baseline_name = self.baseline baseline_idx = run_ids.index(self.baseline) if self.baseline in run_ids else None if baseline_idx is not None: - _logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") + self._logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") baseline = data_list[baseline_idx] else: @@ -1084,7 +1083,7 @@ def heat_map( ref = self._preprocess_data(ref, "channel", verbose=False) if ref.isnull().all(): - _logger.warning( + self._logger.warning( f"Heatmap:: Reference data for metric {metric} and label {label} contains " "only NaNs. Skipping heatmap." ) @@ -1141,14 +1140,17 @@ class ScoreCards: Base directory under which the score cards will be saved. 
""" - def __init__(self, plotter_cfg: dict, output_basedir: str | Path) -> None: + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose=True) -> None: self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") self.improvement = plotter_cfg.get("improvement_scale", 0.2) self.out_plot_dir = Path(output_basedir) / "score_cards" self.baseline = plotter_cfg.get("baseline") + + self._logger = setup_logger(__name__, verbose) + if not os.path.exists(self.out_plot_dir): - _logger.info(f"Creating dir {self.out_plot_dir}") + self._logger.info(f"Creating dir {self.out_plot_dir}") os.makedirs(self.out_plot_dir, exist_ok=True) def plot( @@ -1268,7 +1270,7 @@ def plot( ] plt.legend(handles=legend, loc="upper left", bbox_to_anchor=(1.02, 1.0)) - _logger.info(f"Saving scorecards to: {self.out_plot_dir}") + self._logger.info(f"Saving scorecards to: {self.out_plot_dir}") parts = ["score_card", tag] + runs name = "_".join(filter(None, parts)) @@ -1475,15 +1477,18 @@ class BarPlots: Base directory under which the score cards will be saved. 
""" - def __init__(self, plotter_cfg: dict, output_basedir: str | Path) -> None: + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose: bool = False) -> None: + self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") self.cmap = plotter_cfg.get("cmap", "bwr") self.out_plot_dir = Path(output_basedir) / "bar_plots" self.baseline = plotter_cfg.get("baseline") - _logger.info(f"Saving bar plots to: {self.out_plot_dir}") + + self._logger = setup_logger(__name__, verbose) + self._logger.info(f"Saving bar plots to: {self.out_plot_dir}") if not os.path.exists(self.out_plot_dir): - _logger.info(f"Creating dir {self.out_plot_dir}") + self._logger.info(f"Creating dir {self.out_plot_dir}") os.makedirs(self.out_plot_dir, exist_ok=True) def plot( @@ -1559,7 +1564,7 @@ def plot( transform=ax[run_index - 1].transAxes, ) - _logger.info(f"Saving bar plots to: {self.out_plot_dir}") + self._logger.info(f"Saving bar plots to: {self.out_plot_dir}") parts = ["bar_plot_compare", tag] + runs name = "_".join(filter(None, parts)) plt.savefig( @@ -1667,7 +1672,7 @@ def calculate_average_over_dim( ] if non_zero_dims: - _logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") + self._logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") baseline_score = baseline_var.mean( dim=[dim for dim in baseline_var.dims if dim != x_dim], skipna=True @@ -1718,3 +1723,25 @@ def channel_sort_key(name: str) -> tuple[int, str, int]: return (0, prefix, int(number)) else: return (1, name, float("inf")) + +def setup_logger(name: str, verbose: bool) -> logging.Logger: + """ + Set up a logger with the specified name and verbosity level. + + Parameters + ---------- + name : str + Name of the logger. + verbose : bool + If True, set logging level to INFO; otherwise, set to WARNING. + + Returns + ------- + logging.Logger + Configured logger instance. 
+ """ + + logger = logging.getLogger(name) + logging_level = logging.INFO if verbose else logging.CRITICAL + 1 + logger.setLevel(logging_level) + return logger \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index dce9a6725..5baca2648 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -83,11 +83,26 @@ def setup_main_logger(log_file: str | None, log_queue: mp.Queue) -> QueueListene return listener -def setup_worker_logger(log_queue: mp.Queue) -> logging.Logger: - """""" +def setup_worker_logger(log_queue: mp.Queue, verbose: bool) -> logging.Logger: + """ + Set up worker process logger with QueueHandler. + Parameters + ---------- + log_queue: + Multiprocessing queue for logging. + verbose: + Verbosity flag. + Returns + ------- + Configured logger. + """ + if verbose: + logging_level = logging.INFO + else: + logging_level = logging.CRITICAL qh = QueueHandler(log_queue) logger = logging.getLogger() - logger.setLevel(logging.INFO) + logger.setLevel(logging_level) logger.handlers.clear() logger.addHandler(qh) return logger @@ -279,6 +294,8 @@ def evaluate_from_config( plot_score_maps = cfg.evaluation.get("plot_score_maps", False) global_plotting_opts = cfg.get("global_plotting_options", {}) use_parallel = cfg.evaluation.get("num_processes", 0) + verbose = cfg.evaluation.get("verbose", True) + if use_parallel == "auto": num_processes = mp.cpu_count() elif isinstance(use_parallel, int): @@ -328,7 +345,7 @@ def evaluate_from_config( with mp.Pool( processes=num_processes, initializer=setup_worker_logger, - initargs=(log_queue,), + initargs=(log_queue,verbose), ) as pool: results = pool.map( _process_stream_wrapper, @@ -375,7 +392,7 @@ def evaluate_from_config( # summary plots if scores_dict: _logger.info("Started creating summary plots...") - plot_summary(cfg, scores_dict, 
summary_dir) + plot_summary(cfg, scores_dict, summary_dir, verbose=verbose) if __name__ == "__main__": diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index 0ab0390da..a194b0545 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -156,7 +156,7 @@ def calc_scores_per_stream( _logger.debug( f"Applying bounding box mask for region '{region}' to targets and predictions." ) - + # breakpoint() tars, preds, tars_next, preds_next = [ bbox.apply_mask(x) if x is not None else None for x in (tars, preds, tars_next, preds_next) @@ -316,7 +316,7 @@ def _plot_score_maps_per_stream( plotter.scatter_plot(data, map_dir, channel, region, tag=tag, title=title) -def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: +def plot_data(reader: Reader, stream: str, global_plotting_opts: dict, verbose: bool = True) -> None: """ Plot the data for a given run and stream. @@ -328,6 +328,8 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: Stream name to plot data for. 
global_plotting_opts: dict Dictionary containing all plotting options that apply globally to all run_ids + verbose: bool + Option to print verbose log messages """ run_id = reader.run_id @@ -356,7 +358,7 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: "regions": global_plotting_opts.get("regions", ["global"]), "plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False), } - plotter = Plotter(plotter_cfg, reader.runplot_dir) + plotter = Plotter(plotter_cfg, reader.runplot_dir, verbose = verbose) available_data = reader.check_availability(stream, mode="plotting") @@ -501,7 +503,7 @@ def metric_list_to_json( ) -def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): +def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path, verbose: bool = True) -> None: """ Plot summary of the evaluation results. This function is a placeholder for future implementation. @@ -512,6 +514,10 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): Configuration dictionary containing all information for the evaluation. scores_dict : Dictionary containing scores for each metric and stream. + summary_dir : + Directory where the summary plots will be saved. 
+ verbose: bool + Option to print verbose log messages """ _logger.info("Plotting summary of evaluation results...") runs = cfg.run_ids @@ -531,9 +537,9 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): "baseline": eval_opt.get("baseline", None), } - plotter = LinePlots(plot_cfg, summary_dir) - sc_plotter = ScoreCards(plot_cfg, summary_dir) - br_plotter = BarPlots(plot_cfg, summary_dir) + plotter = LinePlots(plot_cfg, summary_dir, verbose = verbose) + sc_plotter = ScoreCards(plot_cfg, summary_dir, verbose = verbose) + br_plotter = BarPlots(plot_cfg, summary_dir, verbose = verbose) for region in regions: for metric in metrics: if eval_opt.get("summary_plots", True): From a4b0f12db0e99e8a5753e924ad0ba2ddf77803f5 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 16 Jan 2026 18:13:30 +0100 Subject: [PATCH 02/44] lint --- .../src/weathergen/evaluate/io/csv_reader.py | 5 +- .../src/weathergen/evaluate/io/io_reader.py | 8 +- .../weathergen/evaluate/io/wegen_reader.py | 9 +- .../latent_space/latent_space_eval.py | 219 ++++++++++-------- .../evaluate/plotting/plot_utils.py | 26 ++- .../weathergen/evaluate/plotting/plotter.py | 37 +-- .../src/weathergen/evaluate/run_evaluation.py | 2 +- .../src/weathergen/evaluate/utils/utils.py | 12 +- 8 files changed, 186 insertions(+), 132 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 730e79655..14a2548bf 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -8,7 +8,6 @@ # nor does it submit to any jurisdiction. # Standard library -import logging import re from pathlib import Path @@ -29,7 +28,9 @@ class CsvReader(Reader): Reader class to read evaluation data from CSV files and convert to xarray DataArray. 
""" - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose = True): + def __init__( + self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True + ): """ Initialize the CsvReader. diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index 527d2e4d7..ffeb7eeb7 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -59,7 +59,9 @@ class DataAvailability: class Reader: - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None, verbose = True): + def __init__( + self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None, verbose=True + ): """ Generic data reader class. @@ -87,8 +89,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | ) # base directory where results will be stored self._logger = logging.getLogger(__name__) - - logger_level = logging.INFO if verbose else logging.WARNING + + logger_level = logging.INFO if verbose else logging.WARNING self._logger.setLevel(logger_level) def get_stream(self, stream: str): diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 01cef980e..c51227164 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -9,7 +9,6 @@ # Standard library import json -import logging from collections import defaultdict from pathlib import Path @@ -35,7 +34,9 @@ class WeatherGenReader(Reader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose = True): + def __init__( + self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True + ): super().__init__(eval_cfg, run_id, private_paths, verbose) # TODO: remove 
backwards compatibility to "epoch" in Feb. 2026 @@ -46,7 +47,9 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if not self.results_base_dir: self.results_base_dir = Path(get_shared_wg_path("results")) - self._logger.info(f"Results directory obtained from private config: {self.results_base_dir}") + self._logger.info( + f"Results directory obtained from private config: {self.results_base_dir}" + ) else: self._logger.info(f"Results directory parsed: {self.results_base_dir}") diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py index 2bc2c5edc..aea47b62a 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -15,31 +15,71 @@ Command: uv run pytest ./integration_tests/small_multi_stream_test.py """ + import argparse import json import logging +import sys from pathlib import Path import omegaconf -from weathergen.evaluate.run_evaluation import evaluate_from_config -from weathergen.run_train import inference_from_args, train_with_args -from weathergen.utils.metrics import get_train_metrics_path -from collections import defaultdict from weathergen.evaluate.io.wegen_reader import ( WeatherGenJSONReader, ) +from weathergen.evaluate.run_evaluation import evaluate_from_config +from weathergen.run_train import inference_from_args +from weathergen.utils.metrics import get_train_metrics_path + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -#TODO: define WEATHERGEN_HOME properly to avoid partial paths +if not logger.handlers: + h = logging.StreamHandler(sys.stdout) + h.setLevel(logging.INFO) + logger.addHandler(h) + +logger.propagate = False + +# TODO: define WEATHERGEN_HOME properly to avoid partial paths streams = ["ERA5", "SurfaceCombined", "NPPATMS"] -def get_evaluation_config(run_id, verbose 
= False): +############## INFERENCE ################# + + +def infer_multi_stream(run_id): + """Run inference for multi-stream model.""" + logger.info("run multi-stream inference") + new_run_id = run_id + "_inf" # TODO: better naming + inference_from_args( + [ + "-start", + "2021-10-10", + "-end", + "2022-10-11", + "--samples", + "10", + "--options", + "forecast_offset=0", + "zarr_store=zip", + ] + + ["--from_run_id", run_id, "--run_id", new_run_id, "--streams_output"] + + streams + + [ + "--config", + "./config/evaluate/latent_space_eval_config.yaml", + ] + ) + return new_run_id + + +############## EVALUATION ################# + + +def get_evaluation_config(run_id, verbose=False): """Create evaluation configuration for multiple streams.""" cfg = omegaconf.OmegaConf.create( - { "global_plotting_options": { "image_format": "png", @@ -94,32 +134,15 @@ def get_evaluation_config(run_id, verbose = False): "mini_epoch": 0, "rank": 0, } - }, } + }, + } ) return cfg -def infer_multi_stream(run_id): - """Run inference for multi-stream model.""" - logger.info("run multi-stream inference") - new_run_id = run_id + "_inf" #TODO: better naming - inference_from_args( - ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--options", "forecast_offset = 0"] - + [ - "--from_run_id", - run_id, - "--run_id", - new_run_id, - "--streams_output", - ",".join(streams), - "--config", - str("./config/evaluate/latent_space_eval_config.yaml") - ] - ) - return new_run_id def evaluate_multi_stream_results(run_id, verbose=False): """Run evaluation for multiple streams.""" - + logger.info("run multi-stream evaluation") cfg = get_evaluation_config(run_id, verbose=verbose) try: @@ -128,11 +151,63 @@ def evaluate_multi_stream_results(run_id, verbose=False): logger.error(f"Error during evaluation: {e}") +############## PRINT FUNCTIONS ################# + + +def print_losses(run_id, stage="val"): + """Print validation losses for specified streams.""" + logger.info(f"{stage.capitalize()} 
Losses for run_id: {run_id}") + metrics = load_metrics(run_id) + + losses = {} + + for stream_name in streams: + loss = next( + ( + metric.get(f"LossPhysical.{stream_name}.mse.avg") + for metric in reversed(metrics) + if metric.get("stage") == stage + ), + None, + ) + + losses[stream_name] = loss + stage_label = "Train" if stage == "train" else "Validation" + # TODO: understand why logger is not working + logger.info( + f"{stage_label} losses – " + ", ".join(f"{k}: {v:.4f}" for k, v in losses.items()) + "\n" + ) + + +def print_evaluation_results(run_id, verbose=False): + """Print evaluation results for specified streams.""" + + eval_cfg = get_evaluation_config(run_id, verbose=verbose) + scores = load_scores(eval_cfg, run_id) + + metrics = list(eval_cfg.evaluation.get("metrics")) + regions = list(eval_cfg.evaluation.get("regions")) + for stream_name in streams: + stream_scores = scores[stream_name] + + for metric in metrics: + logger.info("------------------------------------------") + for region in regions: + da = stream_scores[metric][region][stream_name][run_id] + logger.info(f"\nEvaluation scores for {region} {stream_name} {metric}:") + + mean_da = da.mean(dim=["sample", "forecast_step", "ens"]) + logger.info(mean_da.to_dataframe(name=f"{metric} {region} {stream_name}")) + + +############## HELPERS ################# + + def load_metrics(run_id): """Helper function to load metrics""" - + file_path = get_train_metrics_path(base_path=Path("./results"), run_id=run_id) - + if not file_path.is_file(): raise FileNotFoundError(f"Metrics file not found for run_id: {run_id}") with open(file_path) as f: @@ -142,96 +217,50 @@ def load_metrics(run_id): def load_scores(eval_cfg, run_id): """Helper function to load metrics""" - - + run_cfg = eval_cfg.run_ids[run_id] metrics = list(eval_cfg.evaluation.get("metrics")) regions = list(eval_cfg.evaluation.get("regions")) reader = WeatherGenJSONReader(run_cfg, run_id, None, regions, metrics) - + scores = {} for stream_name in 
streams: stream_loaded_scores, _ = reader.load_scores( - stream_name , + stream_name, regions, metrics, ) scores[stream_name] = stream_loaded_scores - - return scores - - -def print_losses(run_id, stage="val"): - """Print validation losses for specified streams.""" - print(f"\n{stage.capitalize()} Losses for run_id: {run_id}") - metrics = load_metrics(run_id) - - losses = {} - - for stream_name in streams: - - loss = next( - ( - metric.get(f"LossPhysical.{stream_name}.mse.avg") - for metric in reversed(metrics) - if metric.get("stage") == stage - ), - None, - ) - - losses[stream_name] = loss - stage_label = "\nTrain" if stage == "train" else "Validation" - #TODO: understand why logger is not working - print(f"{stage_label} losses – " + ", ".join(f"{k}: {v:.4f}" for k, v in losses.items())) -def print_evaluation_results(run_id, verbose=False): - """Print evaluation results for specified streams.""" - - eval_cfg = get_evaluation_config(run_id, verbose=verbose) - - try: - scores = load_scores(eval_cfg, run_id) - except FileNotFoundError as e: - print(f"Error loading scores: {e}") - return - - metrics = list(eval_cfg.evaluation.get("metrics")) - regions = list(eval_cfg.evaluation.get("regions")) - for stream_name in streams: - stream_scores = scores[stream_name] - - for metric in metrics: - print("------------------------------------------") - for region in regions: - - da = stream_scores[metric][region][stream_name][run_id] - print(f"\nEvaluation scores for {region} {stream_name} {metric}:") + return scores - mean_da = da.mean(dim=["sample", "forecast_step", "ens"]) - print(mean_da.to_dataframe(name=f"{metric} {region} {stream_name}")) - - +############## MAIN ################# if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run multi-stream latent space evaluation") - parser.add_argument("--run_id", type=str, required=True, help="Run identifier for the model to evaluate") - parser.add_argument("--verbose", action="store_true", help="Enable 
verbose output", default=False) + parser.add_argument( + "--run_id", type=str, required=True, help="Run identifier for the model to evaluate" + ) + parser.add_argument( + "--verbose", action="store_true", help="Enable verbose output", default=False + ) args = parser.parse_args() - + run_id = args.run_id verbose = args.verbose - - infer_run_id = run_id + "_inf" #infer_multi_stream(run_id) + + infer_run_id = infer_multi_stream(run_id) # Evaluate results evaluate_multi_stream_results(infer_run_id, verbose=verbose) - print("\n\nFinal Results Summary: \n") - print_losses(run_id,stage="train") - print_losses(infer_run_id,stage="val") - - print_evaluation_results(infer_run_id, verbose=verbose) \ No newline at end of file + logger.info("\n\nFinal Results Summary: \n") + logger.info("TRAINING & INFERENCE LOSSES: \n") + print_losses(run_id, stage="train") + print_losses(infer_run_id, stage="val") + logger.info("EVALUATION: \n") + print_evaluation_results(infer_run_id, verbose=verbose) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 02df4b38a..54e3d8572 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -7,7 +7,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-import logging from collections.abc import Iterable, Sequence import numpy as np @@ -107,12 +106,15 @@ def plot_metric_region( if selected_data: plotter._logger.info(f"Creating plot for {metric} - {region} - {stream} - {ch}.") - name = create_filename(plotter, - prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream, ch] + name = create_filename( + plotter=plotter, + prefix=[metric, region], + middle=sorted(set(run_ids)), + suffix=[stream, ch], ) selected_data, time_dim = _assign_time_coord(selected_data) - + if time_dim != "lead_time": plotter._logger.warning( "lead_time coordinate not found for all plotted data; " @@ -214,8 +216,10 @@ def ratio_plot_metric_region( plotter._logger.info(f"Creating Ratio plot for {metric} - {stream}") name = create_filename( - plotter, - prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream] + plotter=plotter, + prefix=[metric, region], + middle=sorted(set(run_ids)), + suffix=[stream], ) plotter.ratio_plot( selected_data, @@ -274,17 +278,19 @@ def heat_maps_metric_region( if len(selected_data) > 0: plotter._logger.info(f"Creating Heat maps for {metric} - {stream}") name = create_filename( - plotter, - prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream] + plotter=plotter, + prefix=[metric, region], + middle=sorted(set(run_ids)), + suffix=[stream], ) selected_data, time_dim = _assign_time_coord(selected_data) - + if time_dim != "lead_time": plotter._logger.warning( "lead_time coordinate not found for all plotted data; " "using forecast_step as x-axis." 
) - + plotter.heat_map( selected_data, labels, diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index e14fa34b3..ba07df48a 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -4,7 +4,6 @@ import os import re from pathlib import Path -from tabnanny import verbose import cartopy import cartopy.crs as ccrs @@ -40,7 +39,9 @@ class Plotter: Contains all basic plotting functions. """ - def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None, verbose = True): + def __init__( + self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None, verbose=True + ): """ Initialize the Plotter class. @@ -111,7 +112,9 @@ def update_data_selection(self, select: dict): self.stream = select["stream"] if "forecast_step" not in select: - self._logger.warning("No forecast_step in the selection. Might lead to unexpected results.") + self._logger.warning( + "No forecast_step in the selection. Might lead to unexpected results." + ) else: self.fstep = select["forecast_step"] @@ -619,7 +622,7 @@ def get_map_output_dir(self, tag): class LinePlots: - def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose = True): + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose=True): """ Initialize the LinePlots class. 
@@ -1336,7 +1339,9 @@ def compare_models( baseline_var = baseline.sel({"channel": var}) data_var = data[run_index].sel({"channel": var}) - baseline_score, model_score = calculate_average_over_dim(x_dim, baseline_var, data_var) + baseline_score, model_score = calculate_average_over_dim( + x_dim, baseline_var, data_var, self._logger + ) diff = baseline_score - model_score skill = self.get_skill_score(model_score, baseline_score, metric) @@ -1477,8 +1482,9 @@ class BarPlots: Base directory under which the score cards will be saved. """ - def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose: bool = False) -> None: - + def __init__( + self, plotter_cfg: dict, output_basedir: str | Path, verbose: bool = False + ) -> None: self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") self.cmap = plotter_cfg.get("cmap", "bwr") @@ -1612,7 +1618,9 @@ def calc_ratio_per_run_id( data_var = data[run_index].sel({"channel": var}) channels_per_comparison.append(var) - baseline_score, model_score = calculate_average_over_dim(x_dim, baseline_var, data_var) + baseline_score, model_score = calculate_average_over_dim( + x_dim, baseline_var, data_var, self._logger + ) ratio_score.append(model_score / baseline_score) @@ -1645,7 +1653,7 @@ def colors(self, ratio_score: np.array, metric: str) -> list[tuple]: def calculate_average_over_dim( - x_dim: str, baseline_var: xr.DataArray, data_var: xr.DataArray + x_dim: str, baseline_var: xr.DataArray, data_var: xr.DataArray, logger ) -> tuple[xr.DataArray, xr.DataArray]: """ Calculate average over xarray dimensions that are larger than 1. 
Those might be the @@ -1659,6 +1667,8 @@ def calculate_average_over_dim( xarray DataArray with the scores of the baseline model for a specific channel/variable data_var: xr.DataArray xarray DataArray with the scores of the comparison model for a specific channel/variable + logger: logging.Logger + Logger instance for logging information Returns ------- @@ -1672,7 +1682,7 @@ def calculate_average_over_dim( ] if non_zero_dims: - self._logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") + logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") baseline_score = baseline_var.mean( dim=[dim for dim in baseline_var.dims if dim != x_dim], skipna=True @@ -1724,6 +1734,7 @@ def channel_sort_key(name: str) -> tuple[int, str, int]: else: return (1, name, float("inf")) + def setup_logger(name: str, verbose: bool) -> logging.Logger: """ Set up a logger with the specified name and verbosity level. @@ -1740,8 +1751,8 @@ def setup_logger(name: str, verbose: bool) -> logging.Logger: logging.Logger Configured logger instance. 
""" - + logger = logging.getLogger(name) - logging_level = logging.INFO if verbose else logging.CRITICAL + 1 + logging_level = logging.INFO if verbose else logging.CRITICAL + 1 logger.setLevel(logging_level) - return logger \ No newline at end of file + return logger diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 5baca2648..cb0c10061 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -345,7 +345,7 @@ def evaluate_from_config( with mp.Pool( processes=num_processes, initializer=setup_worker_logger, - initargs=(log_queue,verbose), + initargs=(log_queue, verbose), ) as pool: results = pool.map( _process_stream_wrapper, diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index a194b0545..cf460b694 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -316,7 +316,9 @@ def _plot_score_maps_per_stream( plotter.scatter_plot(data, map_dir, channel, region, tag=tag, title=title) -def plot_data(reader: Reader, stream: str, global_plotting_opts: dict, verbose: bool = True) -> None: +def plot_data( + reader: Reader, stream: str, global_plotting_opts: dict, verbose: bool = True +) -> None: """ Plot the data for a given run and stream. 
@@ -358,7 +360,7 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict, verbose: "regions": global_plotting_opts.get("regions", ["global"]), "plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False), } - plotter = Plotter(plotter_cfg, reader.runplot_dir, verbose = verbose) + plotter = Plotter(plotter_cfg, reader.runplot_dir, verbose=verbose) available_data = reader.check_availability(stream, mode="plotting") @@ -537,9 +539,9 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path, verbose: bool "baseline": eval_opt.get("baseline", None), } - plotter = LinePlots(plot_cfg, summary_dir, verbose = verbose) - sc_plotter = ScoreCards(plot_cfg, summary_dir, verbose = verbose) - br_plotter = BarPlots(plot_cfg, summary_dir, verbose = verbose) + plotter = LinePlots(plot_cfg, summary_dir, verbose=verbose) + sc_plotter = ScoreCards(plot_cfg, summary_dir, verbose=verbose) + br_plotter = BarPlots(plot_cfg, summary_dir, verbose=verbose) for region in regions: for metric in metrics: if eval_opt.get("summary_plots", True): From 40eccbdf7822197ae39091cb0884bb254b5ed000 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 16 Jan 2026 18:14:19 +0100 Subject: [PATCH 03/44] add usage --- .../src/weathergen/evaluate/latent_space/latent_space_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py index aea47b62a..60c620918 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -13,7 +13,7 @@ It performs training and inference with multiple data sources including gridded and obs data. 
Command: -uv run pytest ./integration_tests/small_multi_stream_test.py +uv run --offline packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py --run_id j2dkivn8 """ import argparse From 3ed01fc052b4802676eac613546754a892185f5d Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 21 Jan 2026 11:15:18 +0100 Subject: [PATCH 04/44] fix log --- packages/evaluate/src/weathergen/evaluate/run_evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index cb0c10061..f8d5f8a43 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -339,7 +339,7 @@ def evaluate_from_config( scores_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) if num_processes == 0: if log_queue is not None: - setup_worker_logger(log_queue) + setup_worker_logger(log_queue, verbose) results = [_process_stream(**task) for task in tasks] else: with mp.Pool( From 2f9f12587c889a3bd2f453fb6b721225ad67d047 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Fri, 16 Jan 2026 17:24:50 +0100 Subject: [PATCH 05/44] Jk/develop/1639 fix shard val forward (#1642) * rm model_forward assignment in val * rm clutter from diffusion branch * reverse if order --- src/weathergen/train/trainer.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 6a9137f60..3cbe9f05a 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -520,16 +520,18 @@ def validate(self, mini_epoch, mode_cfg, batch_size): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - model_forward = ( - self.model.forward - if self.ema_model is None - else self.ema_model.forward_eval - ) - preds = model_forward( - self.model_params, - 
batch.get_source_samples(), - mode_cfg.window_offset_prediction, - ) + if self.ema_model is None: + preds = self.model( + self.model_params, + batch.get_source_samples(), + mode_cfg.window_offset_prediction, + ) + else: + preds = self.ema_model.forward_eval( + self.model_params, + batch.get_source_samples(), + mode_cfg.window_offset_prediction, + ) targets_and_auxs = {} for loss_name, target_aux in self.svalidate_with_ema_cfg.items(): From 7c4bb828319bf678e4c2791cd212105190d5346e Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 16 Jan 2026 17:25:07 +0100 Subject: [PATCH 06/44] Clessig/develop/fix finetuning 1640 (#1641) * Fix bug with diagnostic streams * Avoid that empty decoders are allocated --- src/weathergen/model/engines.py | 2 +- src/weathergen/model/model.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 1cd34a442..55c985a12 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -100,7 +100,7 @@ def forward(self, batch, pe_embed): for sample in batch.get_samples(): sdata += [sample.streams_data[stream_name].source_tokens_cells[istep]] - sdata = torch.cat(sdata) + sdata = torch.cat(sdata).to(tokens_all.dtype) # skip empty stream if len(sdata) == 0: continue diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 600c122c7..9d6bbb98f 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -313,6 +313,10 @@ def create(self) -> "Model": for i_stream, si in enumerate(cf.streams): stream_name = self.stream_names[i_stream] + # skip decoder if channels are empty + if len(si.train_target_channels) == 0 and len(si.val_target_channels) == 0: + continue + # extract and setup relevant parameters etc = si["embed_target_coords"] tro_type = ( @@ -494,16 +498,19 @@ def print_num_parameters(self) -> None: num_params_fe = get_num_parameters(self.forecast_engine.fe_blocks) + mdict 
= self.embed_target_coords num_params_embed_tcs = [ - get_num_parameters(self.embed_target_coords[name]) if self.embed_target_coords else 0 + get_num_parameters(mdict[name]) if mdict and name in mdict else 0 for name in self.stream_names ] + mdict = self.target_token_engines num_params_tte = [ - get_num_parameters(self.target_token_engines[name]) if self.target_token_engines else 0 + get_num_parameters(mdict[name]) if mdict and name in mdict else 0 for name in self.stream_names ] + mdict = self.pred_heads num_params_preds = [ - get_num_parameters(self.pred_heads[name]) if self.pred_heads else 0 + get_num_parameters(mdict[name]) if mdict and name in mdict else 0 for name in self.stream_names ] From 9144c64d9ea253bc463fb4c7c16f71dc138b0756 Mon Sep 17 00:00:00 2001 From: Sophie X <24638638+sophie-xhonneux@users.noreply.github.com> Date: Fri, 16 Jan 2026 11:38:04 -0500 Subject: [PATCH 07/44] Sophiex/dev/synop nppatms finetuning configs (#1644) * Doing something wrong * Make fine-tuning work * Rename sensibly --- config/config_finetuning_synop.yml | 301 +++++++++++++++++++ config/streams/era5_decoding_synop/era5.yml | 39 +++ config/streams/era5_decoding_synop/synop.yml | 33 ++ src/weathergen/model/model_interface.py | 2 + src/weathergen/train/trainer.py | 12 +- 5 files changed, 381 insertions(+), 6 deletions(-) create mode 100644 config/config_finetuning_synop.yml create mode 100644 config/streams/era5_decoding_synop/era5.yml create mode 100644 config/streams/era5_decoding_synop/synop.yml diff --git a/config/config_finetuning_synop.yml b/config/config_finetuning_synop.yml new file mode 100644 index 000000000..b4a8bcc50 --- /dev/null +++ b/config/config_finetuning_synop.yml @@ -0,0 +1,301 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 1024 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 8 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 2 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: Linear # CrossAttentionAdaNormConditioning +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 1 +num_register_tokens: 7 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [] # Index starts at 0. 
Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False + +healpix_level: 5 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +freeze_modules: "(^(?!Bilinear|embed_target_coords|pred_heads|target_token_engines).*)" + +norm_type: "LayerNorm" + + +##################################### + +streams_directory: "./config/streams/era5_decoding_nppatms_synop/" +streams: ??? + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_log_freq: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? 
+ + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 16 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2020-12-31T00:00 + end_date: 2022-07-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + window_offset_prediction : 0 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 512 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 00:00:00 + num_steps: 0 + policy: null #"fixed" + # losses : { + # "physical": { + # type: LossPhysical, + # weight: 1.0, + # loss_fcts: { + # "mse": { + # weight: 1.0, + # target_source_correspondence: { 0 : { 0 : "identity"} }, + # }, + # }, + # target_and_aux_calc: "Physical", + # }, + # } + # + # model_input: { + # "source" : { + # # masking strategy: "random", "forecast" + # masking_strategy: "random", + # num_samples: 1, + # num_steps_input: 1, + # masking_strategy_config : { + # diffusion_rn : True, + # rate : 0.0, + # rate_sampling: False + # }, + # }, + # } + + # target_input: { + # "target" : { + # masking_strategy: "healpix", + # num_samples: 1, + # masking_strategy_config : { rate : 0.0, hl_mask: 0, rate_sampling: False }, + # }, + # } + + # forecast : + # time_step: 06:00:00 + # num_steps: 2 + # policy: "fixed" #null + + +# validation config; full validation config is merge of training and validation 
config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2022-10-01T00:00 + end_date: 2022-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # number of validation samples that are written to disk + write_num_samples: 0 + # output streams to write; default all + output_streams: null + + # run validation before training starts (mainly for model development) + validate_before_training: False # 8 #False + + # losses: { + # "physical": { + # type: LossPhysical, + # weight: 1.0, + # loss_fcts: { + # "mse": { + # weight: 1.0, + # }, + # }, + # }, + # } + + # # Requires enabled flags + # model_input: { + # "source" : { + # enabled : True, + # }, + # } + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. 
+ issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null + diff --git a/config/streams/era5_decoding_synop/era5.yml b/config/streams/era5_decoding_synop/era5.yml new file mode 100644 index 000000000..2108a2cf9 --- /dev/null +++ b/config/streams/era5_decoding_synop/era5.yml @@ -0,0 +1,39 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + stream_id : 0 + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + loss_weight : 1. 
+ source_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp'] + target: [] + diagnostic : False + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 4 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 4 + target_readout : + type : 'obs_value' + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + diff --git a/config/streams/era5_decoding_synop/synop.yml b/config/streams/era5_decoding_synop/synop.yml new file mode 100644 index 000000000..32692cd0f --- /dev/null +++ b/config/streams/era5_decoding_synop/synop.yml @@ -0,0 +1,33 @@ +# obs_types +# 0 : polar orbiting satellites +# 1 : geostationay satellites +# 2 : conventional observations + +SurfaceCombined : + type : obs + stream_id : 2 + # source: [] + is_diagnostic: True + filenames : ['observations-ea-ofb-0001-1979-2023-combined-surface-v2.zarr'] + loss_weight : 1.0 + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 64 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 128 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 128 + target_readout : + type : 'obs_value' # token or obs_value + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index f330b6ed6..df483982a 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -60,7 +60,9 @@ def init_model_and_shard( if name == "": continue if re.fullmatch(cf.freeze_modules, name) is not None: + logger.info(f"Froze weights {name}") freeze_weights(module) + # TODO: this should be handled in the encoder to be close where q_cells is defined if "q_cells" in cf.freeze_modules: model.encoder.q_cells.requires_grad = False diff 
--git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3cbe9f05a..4d4db7e4c 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -69,7 +69,7 @@ def __init__(self, train_log_freq: Config): self.perf_mem = None self.t_start: float = 0 self.target_and_aux_calculators = None - self.svalidate_with_ema_cfg = None + self.validate_with_ema_cfg = None self.validate_with_ema: bool = False self.batch_size_per_gpu = -1 self.batch_size_validation_per_gpu = -1 @@ -196,7 +196,7 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): ) # get target_aux calculators for different loss terms - self.svalidate_with_ema_cfg = self.get_target_aux_calculators(self.test_cfg) + self.validate_with_ema_cfg = self.get_target_aux_calculators(self.test_cfg) self.loss_calculator_val = LossCalculator(cf, self.test_cfg, VAL, device=self.devices[0]) @@ -269,7 +269,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # get target_aux calculators for different loss terms self.target_and_aux_calculators = self.get_target_aux_calculators(self.training_cfg) - self.svalidate_with_ema_cfg = self.get_target_aux_calculators(self.validation_cfg) + self.validate_with_ema_cfg = self.get_target_aux_calculators(self.validation_cfg) # if with_fsdp then parameter count is unreliable if is_root(): @@ -441,7 +441,7 @@ def train(self, mini_epoch): ] [ target_aux.update_state_pre_backward(self.cf.general.istep, batch, self.model) - for _, target_aux in self.svalidate_with_ema_cfg.items() + for _, target_aux in self.validate_with_ema_cfg.items() ] # backward pass @@ -476,7 +476,7 @@ def train(self, mini_epoch): ] [ target_aux.update_state_post_opt_step(step, batch, self.model) - for _, target_aux in self.svalidate_with_ema_cfg.items() + for _, target_aux in self.validate_with_ema_cfg.items() ] # EMA update if self.validate_with_ema: @@ -534,7 +534,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): ) targets_and_auxs = {} 
- for loss_name, target_aux in self.svalidate_with_ema_cfg.items(): + for loss_name, target_aux in self.validate_with_ema_cfg.items(): target_idxs = get_target_idxs_from_cfg(self.training_cfg, loss_name) targets_and_auxs[loss_name] = target_aux.compute( self.cf.general.istep, From 699a8aa844813ae63143142c786446e3396f86de Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 16 Jan 2026 17:39:36 +0100 Subject: [PATCH 08/44] Enable multiple student views for one target for JEPA (#1617) * Enable multiple student views for one target * Improved readability --- src/weathergen/train/loss_modules/loss_module_ssl.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/weathergen/train/loss_modules/loss_module_ssl.py b/src/weathergen/train/loss_modules/loss_module_ssl.py index 826b732e9..95f610bdc 100644 --- a/src/weathergen/train/loss_modules/loss_module_ssl.py +++ b/src/weathergen/train/loss_modules/loss_module_ssl.py @@ -221,6 +221,7 @@ def gather_targets_for_loss(self, name, targets, metadata, target2source_matchin def jepa_loss(student_patches_masked, student_masks, teacher_patches_masked, teacher_masks): # TODO remove as we deal with batch dimension + assert teacher_masks.shape[0] == 1 or teacher_masks.shape[0] == student_masks.shape[0] student_masks = student_masks.squeeze(dim=1) teacher_masks = teacher_masks.squeeze(dim=1) masks_weight = ( @@ -233,7 +234,13 @@ def jepa_loss(student_patches_masked, student_masks, teacher_patches_masked, tea if mask.sum() == 0: logger.warning("jepa_loss mask is all true, likely incorrect masking config.") - loss = F.l1_loss(student_patches_masked[mask], teacher_patches_masked[mask]) + assert mask.shape[0] == student_patches_masked.shape[0], ( + "mask.shape[0], batch dimension, has to match batch dimension for student_patches_masked." 
+ ) + # expand/repeat teacher_masks to match number of student samples + teacher_patches = teacher_patches_masked.expand((mask.shape[0], -1, -1)) + # compute loss + loss = F.l1_loss(student_patches_masked[mask], teacher_patches[mask]) loss = loss * masks_weight[mask] return loss.sum() # / student_masks.shape[0] From 88e809d0ddade2f7a7677f77c77bb570e4a1abeb Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 17 Jan 2026 12:36:04 +0100 Subject: [PATCH 09/44] Fix test for empty targets in decoder creation (#1646) --- src/weathergen/model/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 9d6bbb98f..a4d179443 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -314,7 +314,10 @@ def create(self) -> "Model": stream_name = self.stream_names[i_stream] # skip decoder if channels are empty - if len(si.train_target_channels) == 0 and len(si.val_target_channels) == 0: + if ( + len(si.get("train_target_channels", [])) == 0 + and len(si.get("val_target_channels", [])) == 0 + ): continue # extract and setup relevant parameters From 9bdd7d0a090052b1ee51c754651733c82d11dc7a Mon Sep 17 00:00:00 2001 From: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Date: Sat, 17 Jan 2026 17:34:54 +0100 Subject: [PATCH 10/44] add regions to integration tests (#1648) --- integration_tests/small1_test.py | 3 ++- integration_tests/small_multi_stream_test.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py index cbd5a3d7b..797c6d501 100644 --- a/integration_tests/small1_test.py +++ b/integration_tests/small1_test.py @@ -15,8 +15,8 @@ import omegaconf import pytest - from weathergen.evaluate.run_evaluation import evaluate_from_config + from weathergen.run_train import inference_from_args, train_with_args from weathergen.utils.metrics import get_train_metrics_path @@ 
-105,6 +105,7 @@ def evaluate_results(run_id): "dpi_val": 300, }, "evaluation": { + "regions": ["global"], "metrics": ["rmse", "l1", "mse"], "verbose": True, "summary_plots": True, diff --git a/integration_tests/small_multi_stream_test.py b/integration_tests/small_multi_stream_test.py index f4bdb584e..dbcc3a3bc 100644 --- a/integration_tests/small_multi_stream_test.py +++ b/integration_tests/small_multi_stream_test.py @@ -23,8 +23,8 @@ import omegaconf import pytest - from weathergen.evaluate.run_evaluation import evaluate_from_config + from weathergen.run_train import inference_from_args, train_with_args from weathergen.utils.metrics import get_train_metrics_path @@ -105,6 +105,7 @@ def evaluate_multi_stream_results(run_id): "dpi_val": 300, }, "evaluation": { + "regions": ["global"], "metrics": ["rmse", "l1", "mse"], "verbose": True, "summary_plots": True, From 75df669ac630e3e064b5bdff4344bc5692c67558 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 19 Jan 2026 07:57:12 +0100 Subject: [PATCH 11/44] Memory pinning (#1615) * add pin mem to IOReaderData * add pin mem to sample & modelbatch class * add pin mem to stream data * add pin mem to training loop * run /scripts/actions.sh lint * run ./scripts/actions.sh unit-test * ignore check torch import in package * move pinning to MultiStreamDataSampler * add _pin_tensor & _pin_tensor_list helper func * ruff the code * move back pin mem. 
to train loop * Remove the ignore-import-error rule and revert to the state before the change * create protocol for pinnable obj * remove pin_mem from IOReaderData class * add pin_memory to Trainer.validate * remove pin_memory from loader_params * Rever export/export_inference.py to state before c3fc9a78 * change name * revise Pinnable class description * add memory_pinning in config, train & va loop * use getattr to avoid CICD warning * use setattr to avoid CICD warning * disable pylint for self.source_tokens_lens * Fixed issues with memory pinning due to rebasing and also adjusted config position of flag * Reverting unadvert changes --------- Co-authored-by: Javad Kasravi Co-authored-by: Javad Kasravi Co-authored-by: Javad kasravi --- config/config_physical_jepa.yml | 5 ++ config/default_config.yml | 5 ++ src/weathergen/datasets/batch.py | 43 +++++++++++++++++ src/weathergen/datasets/memory_pinning.py | 42 +++++++++++++++++ src/weathergen/datasets/stream_data.py | 56 +++++++++++++++++++++++ src/weathergen/train/trainer.py | 10 +++- uv.lock | 4 +- 7 files changed, 161 insertions(+), 4 deletions(-) create mode 100644 src/weathergen/datasets/memory_pinning.py diff --git a/config/config_physical_jepa.yml b/config/config_physical_jepa.yml index 5b7812cd4..b2b9d7267 100644 --- a/config/config_physical_jepa.yml +++ b/config/config_physical_jepa.yml @@ -126,6 +126,11 @@ data_loading : num_workers: 12 rng_seed: ??? + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + # config for training training_config: diff --git a/config/default_config.yml b/config/default_config.yml index ea5f4a98d..89056d133 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -126,6 +126,11 @@ data_loading : rng_seed: ??? 
repeat_data_in_mini_epoch : False + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + # config for training training_config: diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index efb3f08e2..e0dc59d62 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -34,6 +34,25 @@ class Sample: # keys: stream_name, values: StreamData streams_data: dict[str, StreamData | None] + def pin_memory(self): + """Pin all tensors in this Sample to CPU pinned memory""" + + # Pin StreamData objects in streams_data dict + if hasattr(self, "streams_data") and isinstance(self.streams_data, dict): + for _stream_name, stream_data in self.streams_data.items(): + if stream_data is not None and hasattr(stream_data, "pin_memory"): + stream_data.pin_memory() + + # Pin tensors in meta_info + if hasattr(self, "meta_info") and isinstance(self.meta_info, dict): + for _key, meta_data in self.meta_info.items(): + if isinstance(meta_data, SampleMetaData): + # Pin mask tensor + if meta_data.mask is not None and isinstance(meta_data.mask, torch.Tensor): + meta_data.mask = meta_data.mask.pin_memory() + + return self + def __init__(self, streams: dict) -> None: self.meta_info = {} @@ -156,6 +175,19 @@ def get_device(self) -> str | torch.device: """ return self.device + def pin_memory(self): + """Pin all tensors in this batch to CPU pinned memory""" + + # pin all samples + for sample in self.samples: + sample.pin_memory() + + # pin source_tokens_lens + if isinstance(self.tokens_lens, torch.Tensor): + self.tokens_lens = self.tokens_lens.pin_memory() + + return self + class ModelBatch: """ @@ -186,6 +218,17 @@ def __init__(self, streams: dict, num_source_samples: int, num_target_samples: i self.source2target_matching_idxs = 
np.full(num_source_samples, -1, dtype=np.int32) self.target2source_matching_idxs = [[] for _ in range(num_target_samples)] + def pin_memory(self): + """Pin all tensors in this batch to CPU pinned memory""" + + # pin source samples + self.source_samples.pin_memory() + + # pin target samples + self.target_samples.pin_memory() + + return self + def to_device(self, device): # -> ModelBatch """ Move batch to device diff --git a/src/weathergen/datasets/memory_pinning.py b/src/weathergen/datasets/memory_pinning.py new file mode 100644 index 000000000..da69e6e48 --- /dev/null +++ b/src/weathergen/datasets/memory_pinning.py @@ -0,0 +1,42 @@ +from typing import Protocol, runtime_checkable + +import torch + +from weathergen.common.io import IOReaderData + + +@runtime_checkable +class Pinnable(Protocol): + """ + Protocol that allows the pytorch content of a data structure + to be pinned to the memory of the current accelerator. + + This extends the pin_memory() capability of a torch Tensor + to other classes. + + It is blocking. + """ + + def pin_memory(self): ... + + +def pin_object(obj: Pinnable | torch.Tensor | IOReaderData | list | dict | None): + if obj is None: + return + elif isinstance(obj, torch.Tensor | Pinnable): + obj.pin_memory() + elif isinstance(obj, IOReaderData): + # Special case: IOReaderData is in common package and can't have torch deps + # Note: These SHOULD be numpy arrays per the type hints, but might be tensors + pin_object(obj.coords) + pin_object(obj.data) + pin_object(obj.geoinfos) + + elif isinstance(obj, list): + # Assume the list is a list of potentially pinnable objects and traverse it. + for e in obj: + pin_object(e) + elif isinstance(obj, dict): + # Assume the values are pinnable. 
+ for e in obj.values(): + pin_object(e) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 407f38354..da6ab67ca 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -14,6 +14,37 @@ from weathergen.common.io import IOReaderData +def _pin_tensor(tensor: torch.Tensor) -> torch.Tensor: + """Pin a tensor to CPU pinned memory. + + Parameters + ---------- + tensor : torch.Tensor + + Returns + ------- + torch.Tensor + The pinned tensor. + """ + return tensor.pin_memory() if isinstance(tensor, torch.Tensor) else tensor + + +def _pin_tensor_list(tensor_list: list) -> list: + """Pin all tensors in a list to CPU pinned memory. + + Parameters + ---------- + tensor_list : list + List of tensors (or other objects) to pin. + + Returns + ------- + list + List with all torch.Tensor elements pinned to CPU pinned memory. + """ + return [_pin_tensor(t) for t in tensor_list] + + class StreamData: """ StreamData object that encapsulates all data the model ingests for one batch item @@ -75,6 +106,31 @@ def __init__(self, idx: int, input_steps: int, forecast_steps: int, healpix_cell self.source_idxs_embed = [torch.tensor([]) for _ in range(self.input_steps)] self.source_idxs_embed_pe = [torch.tensor([]) for _ in range(self.input_steps)] + def pin_memory(self): + """Pin all tensors in this StreamData object to CPU pinned memory""" + + # Pin target tensors + self.target_coords = _pin_tensor_list(self.target_coords) + self.target_coords_lens = _pin_tensor_list(self.target_coords_lens) + self.target_tokens = _pin_tensor_list(self.target_tokens) + self.target_tokens_lens = _pin_tensor_list(self.target_tokens_lens) + self.idxs_inv = _pin_tensor_list(self.idxs_inv) + self.target_coords_raw = _pin_tensor_list(self.target_coords_raw) + + # Pin source tensors + self.source_tokens_cells = _pin_tensor_list(self.source_tokens_cells) + self.source_tokens_lens = _pin_tensor_list(self.source_tokens_lens) + 
self.source_idxs_embed = _pin_tensor_list(self.source_idxs_embed) + self.source_idxs_embed_pe = _pin_tensor_list(self.source_idxs_embed_pe) + + # Pin source_raw (list of IOReaderData objects) + if hasattr(self, "source_raw"): + for raw_data in self.source_raw: + if raw_data is not None and hasattr(raw_data, "pin_memory"): + raw_data.pin_memory() + + return self + def to_device(self, device: str) -> None: """ Move data to GPU diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 4d4db7e4c..da284a01b 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -180,7 +180,6 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): "batch_sampler": None, "shuffle": False, "num_workers": loader_num_workers, - "pin_memory": True, } self.data_loader_validation = torch.utils.data.DataLoader( self.dataset, **loader_params, sampler=None @@ -226,7 +225,6 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): "batch_sampler": None, "shuffle": False, "num_workers": cf.data_loading.num_workers, - "pin_memory": True, } self.data_loader = torch.utils.data.DataLoader(self.dataset, **loader_params, sampler=None) self.data_loader_validation = torch.utils.data.DataLoader( @@ -398,6 +396,10 @@ def train(self, mini_epoch): # training loop self.t_start = time.time() for bidx, batch in enumerate(dataset_iter): + if cf.data_loading.get("memory_pinning", False): + # pin memory for faster CPU-GPU transfer + batch = batch.pin_memory() + batch.to_device(self.device) with torch.autocast( @@ -512,6 +514,10 @@ def validate(self, mini_epoch, mode_cfg, batch_size): # print progress bar but only in interactive mode, i.e. 
when without ddp with tqdm.tqdm(total=mode_cfg.samples_per_mini_epoch, disable=self.cf.with_ddp) as pbar: for bidx, batch in enumerate(dataset_val_iter): + if cf.data_loading.get("memory_pinning", False): + # pin memory for faster CPU-GPU transfer + batch = batch.pin_memory() + batch.to_device(self.device) # evaluate model diff --git a/uv.lock b/uv.lock index 658b977d4..491002ade 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = "==3.12.*" resolution-markers = [ "platform_machine == 'aarch64' and sys_platform == 'linux'", @@ -1103,7 +1103,7 @@ name = "jinja2" version = "3.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markupsafe", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "markupsafe", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } wheels = [ From 15a8c291bb00f50e790677b305cbb4bc298f9052 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 20 Jan 2026 12:05:27 +0100 Subject: [PATCH 12/44] Allows for writing normalized samples; 
fixed config to keep it well-structured (#1653) --- config/default_config.yml | 13 +++++++++---- src/weathergen/run_train.py | 2 +- src/weathergen/train/trainer.py | 15 +++++++++++---- src/weathergen/utils/validation_io.py | 2 +- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 89056d133..4f6eacb5e 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -207,10 +207,15 @@ validation_config: ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 - # number of validation samples that are written to disk - write_num_samples: 0 - # output streams to write; default all - output_streams: null + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 8, + # write samples in normalized model space + normalized_samples: True, + # output streams to write; default all + streams: null, + } # run validation before training starts (mainly for model development) validate_before_training: False diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 0189f9766..2af04edfe 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -47,7 +47,7 @@ def inference_from_args(argl: list[str]): start_date=args.start_date, end_date=args.end_date, samples_per_mini_epoch=args.samples, - write_num_samples=args.samples if args.save_samples else 0, + output=dict(write_num_samples=args.samples if args.save_samples else 0), streams_output=args.streams_output, ) } diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index da284a01b..ad56e1367 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -510,6 +510,8 @@ def validate(self, mini_epoch, mode_cfg, batch_size): dataset_val_iter = iter(self.data_loader_validation) + num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size + with torch.no_grad(): # print progress 
bar but only in interactive mode, i.e. when without ddp with tqdm.tqdm(total=mode_cfg.samples_per_mini_epoch, disable=self.cf.with_ddp) as pbar: @@ -557,16 +559,21 @@ def validate(self, mini_epoch, mode_cfg, batch_size): ) # log output - num_samples = mode_cfg.get("write_num_samples", 0) * batch_size - if bidx < num_samples: - dn_data = self.dataset_val.denormalize_target_channels + if bidx < num_samples_write: + # denormalization function for data + denormalize_data_fct = ( + (lambda x0, x1: x1) + if mode_cfg.get("output", {}).get("normalized_samples", False) + else self.dataset_val.denormalize_target_channels + ) + # write output write_output( self.cf, mode_cfg, batch_size, mini_epoch, bidx, - dn_data, + denormalize_data_fct, batch, preds, targets_and_auxs, diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 72813ca51..69be9854b 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -113,7 +113,7 @@ def write_output( # output stream names to be written, use specified ones or all if nothing specified stream_names = [stream.name for stream in cf.streams] - if val_cfg.get("streams_output") is not None: + if val_cfg.get("output").get("streams") is not None: output_stream_names = val_cfg.streams_output else: output_stream_names = stream_names From 9e912201cc7c7386da73600d7e73c2416992ce0b Mon Sep 17 00:00:00 2001 From: s6sebusc <49226935+s6sebusc@users.noreply.github.com> Date: Tue, 20 Jan 2026 14:44:57 +0100 Subject: [PATCH 13/44] Skipping missing scores in JSONreader (#1655) * split WeatherGenReader functionality to allow reading only JSON adding weathergen JSON reader to develop * informative error when metrics are not there * restore JSONreader after rebase * JSONreader mostly restored * MLFlow logging independent of JSON/zarr * linting, properly cheking fsteps, ens, samples in JSONreader * tiny change to restore the MergeReader * lint * enabling JSONreader to skip plots and 
missing scores gracefully * required reformatting * move skipping of metrics to the reader class * slighly more explicit formulations --------- Co-authored-by: Sebastian Buschow Co-authored-by: Sebastian Buschow Co-authored-by: iluise <72020169+iluise@users.noreply.github.com> Co-authored-by: Ilaria Luise --- .../src/weathergen/evaluate/io/io_reader.py | 5 ++-- .../weathergen/evaluate/io/wegen_reader.py | 28 +++++++++++-------- .../src/weathergen/evaluate/run_evaluation.py | 15 +++++----- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index ffeb7eeb7..b218d2615 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -223,7 +223,7 @@ def check_availability( if available_data is not None and reader_data[name] != available[name]: self._logger.info( f"Requested all {name}s for {mode}, but previous config was a " - "strict subset. Recomputing." + "strict subset. Recomputation required." ) check_score = False @@ -245,7 +245,8 @@ def check_availability( if available_data is not None and not requested[name] <= available[name]: missing = requested[name] - available[name] self._logger.info( - f"{name.capitalize()}(s) {missing} missing in previous evaluation. Recomputing." + f"{name.capitalize()}(s) {missing} missing in previous evaluation." + "Recomputation required." ) check_score = False diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index c51227164..4e2ee77c0 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -174,8 +174,9 @@ def load_scores( ------- xr.DataArray The metric DataArray. - missing_metrics: - dictionary of missing regions and metrics that need to be recomputed. 
+ computable_metrics: + dictionary of regions and metrics that can be recomputed + (empty for JSONreader). """ local_scores = {} @@ -199,8 +200,8 @@ def load_scores( # all other cases: recompute scores missing_metrics.setdefault(region, []).append(metric) continue - - return local_scores, missing_metrics + recomputable_missing_metrics = self.get_recomputable_metrics(missing_metrics) + return local_scores, recomputable_missing_metrics def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None: """ @@ -214,11 +215,15 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr if score_path.exists(): with open(score_path) as f: data_dict = json.load(f) - score = xr.DataArray.from_dict(data_dict) # not a dict though + score = xr.DataArray.from_dict(data_dict) else: score = None return score + def get_recomputable_metrics(self, metrics): + """determine whether given metrics can be re-computed.""" + return metrics + def get_inference_stream_attr(self, stream_name: str, key: str, default=None): """ Get the value of a key for a specific stream from the a model config. @@ -264,12 +269,7 @@ def __init__( for region in regions: for metric in metrics: score = self.load_single_score(stream, region, metric) - if score is None: - raise ValueError( - f"JSONreader couldn't find {metric} for {run_id}, stream {stream}, " - f"region {region}. Use type: zarr instead if possible." - ) - else: + if score is not None: for name in coord_names: vals = set(score[name].values) all_coords[name].append(vals) @@ -299,6 +299,12 @@ def get_data(self, *args, **kwargs): # it can still happen when a particular score was available for a different channel raise ValueError(f"Missing JSON data for run {self.run_id}.") + def get_recomputable_metrics(self, metrics): + _logger.info( + f"The following metrics have not yet been computed:{metrics}. Use type: zarr for that." 
+ ) + return {} + class WeatherGenZarrReader(WeatherGenReader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index f8d5f8a43..8c73c93e7 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -232,7 +232,7 @@ def _process_stream( plot_score_maps: Bool to define if the score maps need to be plotted or not. """ - # try: + type_ = run.get("type", "zarr") reader = get_reader(type_, run, run_id, private_paths, regions, metrics) @@ -241,28 +241,29 @@ def _process_stream( return run_id, stream, {} # Parallel plotting - if stream_dict.get("plotting"): + if stream_dict.get("plotting") and type_ == "zarr": plot_data(reader, stream, global_plotting_opts) # Scoring per stream if not stream_dict.get("evaluation"): return run_id, stream, {} - stream_loaded_scores, missing_metrics = reader.load_scores( + stream_loaded_scores, recomputable_metrics = reader.load_scores( stream, regions, metrics, ) scores_dict = stream_loaded_scores - if missing_metrics or plot_score_maps: - regions_to_compute = list(set(missing_metrics.keys())) if missing_metrics else regions - metrics_to_compute = missing_metrics if missing_metrics else metrics + if recomputable_metrics or (plot_score_maps and type_ == "zarr"): + regions_to_compute = ( + list(set(recomputable_metrics.keys())) if recomputable_metrics else regions + ) + metrics_to_compute = recomputable_metrics if recomputable_metrics else metrics stream_computed_scores = calc_scores_per_stream( reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps ) - metric_list_to_json(reader, stream, stream_computed_scores, regions) scores_dict = merge(stream_loaded_scores, stream_computed_scores) From c191c4f5e698a7ff72a04d843f3b815ec5ab5e32 Mon Sep 17 00:00:00 2001 From: Kacper Nowak 
Date: Sat, 24 Jan 2026 00:07:16 +0100 Subject: [PATCH 14/44] Remove target type config (#1651) * Add target type value error * Remove type * Remove unused code * Commit what shall have been committed * Remove target readout type from config * Add computing stream names to embedding engine --------- Co-authored-by: Christian Lessig --- config/streams/cams/cams_an.yml | 1 - config/streams/cams/cams_eac4.yml | 1 - config/streams/cerra_seviri/cerra.yml | 1 - config/streams/cerra_seviri/seviri.yml | 1 - config/streams/era5_1deg/era5.yml | 1 - config/streams/era5_decoding_synop/era5.yml | 1 - config/streams/era5_decoding_synop/synop.yml | 1 - config/streams/era5_nppatms_synop/era5.yml | 1 - .../streams/era5_nppatms_synop/npp_atms.yml | 1 - config/streams/era5_nppatms_synop/synop.yml | 1 - config/streams/fesom/fesom.yml | 1 - config/streams/fesom/fesom_elem.yml | 1 - config/streams/fesom/ifs.yml | 1 - .../icon_esm_historical_day/icon_esm_Oday.yml | 1 - .../icon_esm_SIday.yml | 1 - .../icon_esm_historical_day/icon_esm_day.yml | 1 - .../icon_esm_AERmon.yml | 1 - .../icon_esm_historical_mon/icon_esm_Amon.yml | 1 - .../icon_esm_historical_mon/icon_esm_Emon.yml | 1 - .../icon_esm_LImon.yml | 1 - .../icon_esm_historical_mon/icon_esm_Lmon.yml | 1 - .../icon_esm_historical_mon/icon_esm_Omon.yml | 1 - .../icon_esm_SImon.yml | 1 - config/streams/iconart_6h/icon.yml | 1 - config/streams/igra/igra.yml | 1 - integration_tests/streams/era5_small.yml | 1 - .../streams_multi/era5_small.yml | 1 - integration_tests/streams_multi/npp_atms.yml | 1 - integration_tests/streams_multi/synop.yml | 1 - src/weathergen/model/encoder.py | 2 +- src/weathergen/model/engines.py | 18 ++---------- src/weathergen/model/model.py | 29 +------------------ 32 files changed, 5 insertions(+), 73 deletions(-) diff --git a/config/streams/cams/cams_an.yml b/config/streams/cams/cams_an.yml index baafe5b93..b2d727087 100644 --- a/config/streams/cams/cams_an.yml +++ b/config/streams/cams/cams_an.yml @@ -68,7 +68,6 @@ 
CAMSANALYSIS : net : linear dim_embed : 512 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/cams/cams_eac4.yml b/config/streams/cams/cams_eac4.yml index 35607b7d8..c123f5d23 100644 --- a/config/streams/cams/cams_eac4.yml +++ b/config/streams/cams/cams_eac4.yml @@ -87,7 +87,6 @@ CAMSEAC4 : net : linear dim_embed : 512 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/cerra_seviri/cerra.yml b/config/streams/cerra_seviri/cerra.yml index ea6d17e1d..1f2899d82 100644 --- a/config/streams/cerra_seviri/cerra.yml +++ b/config/streams/cerra_seviri/cerra.yml @@ -24,7 +24,6 @@ CERRA : net : linear dim_embed : 256 target_readout : - type : 'obs_value' num_layers : 2 num_heads : 4 pred_head : diff --git a/config/streams/cerra_seviri/seviri.yml b/config/streams/cerra_seviri/seviri.yml index 4dabf68cd..e860f1da5 100644 --- a/config/streams/cerra_seviri/seviri.yml +++ b/config/streams/cerra_seviri/seviri.yml @@ -24,7 +24,6 @@ SEVIRI : net : linear dim_embed : 256 target_readout : - type : 'obs_value' num_layers : 2 num_heads : 8 pred_head : diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 429682d85..26455bf4e 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -30,7 +30,6 @@ ERA5 : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/era5_decoding_synop/era5.yml b/config/streams/era5_decoding_synop/era5.yml index 2108a2cf9..c04f77a24 100644 --- a/config/streams/era5_decoding_synop/era5.yml +++ b/config/streams/era5_decoding_synop/era5.yml @@ -30,7 +30,6 @@ ERA5 : net : linear dim_embed : 4 target_readout : - type : 'obs_value' num_layers : 2 num_heads : 4 pred_head : diff --git 
a/config/streams/era5_decoding_synop/synop.yml b/config/streams/era5_decoding_synop/synop.yml index 32692cd0f..60ab8134a 100644 --- a/config/streams/era5_decoding_synop/synop.yml +++ b/config/streams/era5_decoding_synop/synop.yml @@ -25,7 +25,6 @@ SurfaceCombined : net : linear dim_embed : 128 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 pred_head : diff --git a/config/streams/era5_nppatms_synop/era5.yml b/config/streams/era5_nppatms_synop/era5.yml index 90d0b9790..6f22898f0 100644 --- a/config/streams/era5_nppatms_synop/era5.yml +++ b/config/streams/era5_nppatms_synop/era5.yml @@ -30,7 +30,6 @@ ERA5 : net : linear dim_embed : 128 target_readout : - type : 'obs_value' num_layers : 2 num_heads : 4 pred_head : diff --git a/config/streams/era5_nppatms_synop/npp_atms.yml b/config/streams/era5_nppatms_synop/npp_atms.yml index 75302f443..e638b9bc7 100644 --- a/config/streams/era5_nppatms_synop/npp_atms.yml +++ b/config/streams/era5_nppatms_synop/npp_atms.yml @@ -24,7 +24,6 @@ NPPATMS : net : linear dim_embed : 128 target_readout : - type : 'obs_value' num_layers : 1 num_heads : 4 pred_head : diff --git a/config/streams/era5_nppatms_synop/synop.yml b/config/streams/era5_nppatms_synop/synop.yml index ce9adfa44..754d9f113 100644 --- a/config/streams/era5_nppatms_synop/synop.yml +++ b/config/streams/era5_nppatms_synop/synop.yml @@ -23,7 +23,6 @@ SurfaceCombined : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 pred_head : diff --git a/config/streams/fesom/fesom.yml b/config/streams/fesom/fesom.yml index f3c1c85d9..511560445 100644 --- a/config/streams/fesom/fesom.yml +++ b/config/streams/fesom/fesom.yml @@ -27,7 +27,6 @@ FESOM_NODE : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/fesom/fesom_elem.yml b/config/streams/fesom/fesom_elem.yml index 
7126d7e52..3d9451d55 100644 --- a/config/streams/fesom/fesom_elem.yml +++ b/config/streams/fesom/fesom_elem.yml @@ -27,7 +27,6 @@ FESOM_ELEM : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/fesom/ifs.yml b/config/streams/fesom/ifs.yml index 594d5c013..227f552ce 100644 --- a/config/streams/fesom/ifs.yml +++ b/config/streams/fesom/ifs.yml @@ -28,7 +28,6 @@ IFS_ATMO : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_day/icon_esm_Oday.yml b/config/streams/icon_esm_historical_day/icon_esm_Oday.yml index e43ca55cd..07e941469 100644 --- a/config/streams/icon_esm_historical_day/icon_esm_Oday.yml +++ b/config/streams/icon_esm_historical_day/icon_esm_Oday.yml @@ -31,7 +31,6 @@ ICONESMOday : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_day/icon_esm_SIday.yml b/config/streams/icon_esm_historical_day/icon_esm_SIday.yml index 36c6a6816..87ec16e3f 100644 --- a/config/streams/icon_esm_historical_day/icon_esm_SIday.yml +++ b/config/streams/icon_esm_historical_day/icon_esm_SIday.yml @@ -31,7 +31,6 @@ ICONESMSIday : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_day/icon_esm_day.yml b/config/streams/icon_esm_historical_day/icon_esm_day.yml index 31e564268..2bdb69c10 100644 --- a/config/streams/icon_esm_historical_day/icon_esm_day.yml +++ b/config/streams/icon_esm_historical_day/icon_esm_day.yml @@ -33,7 +33,6 @@ ICONESMday : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # 
sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_mon/icon_esm_AERmon.yml b/config/streams/icon_esm_historical_mon/icon_esm_AERmon.yml index 8aa04508a..fb5ace82b 100644 --- a/config/streams/icon_esm_historical_mon/icon_esm_AERmon.yml +++ b/config/streams/icon_esm_historical_mon/icon_esm_AERmon.yml @@ -31,7 +31,6 @@ ICONESMAERmon : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_mon/icon_esm_Amon.yml b/config/streams/icon_esm_historical_mon/icon_esm_Amon.yml index bd531d2bb..66ced4973 100644 --- a/config/streams/icon_esm_historical_mon/icon_esm_Amon.yml +++ b/config/streams/icon_esm_historical_mon/icon_esm_Amon.yml @@ -55,7 +55,6 @@ ICONESMAmon : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_mon/icon_esm_Emon.yml b/config/streams/icon_esm_historical_mon/icon_esm_Emon.yml index 8948d49e8..bae69d6bf 100644 --- a/config/streams/icon_esm_historical_mon/icon_esm_Emon.yml +++ b/config/streams/icon_esm_historical_mon/icon_esm_Emon.yml @@ -31,7 +31,6 @@ ICONESMEmon : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_mon/icon_esm_LImon.yml b/config/streams/icon_esm_historical_mon/icon_esm_LImon.yml index cb0de80e3..4e21f31a4 100644 --- a/config/streams/icon_esm_historical_mon/icon_esm_LImon.yml +++ b/config/streams/icon_esm_historical_mon/icon_esm_LImon.yml @@ -31,7 +31,6 @@ ICONESMLImon : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_mon/icon_esm_Lmon.yml 
b/config/streams/icon_esm_historical_mon/icon_esm_Lmon.yml index a17e64022..cd14f5f7b 100644 --- a/config/streams/icon_esm_historical_mon/icon_esm_Lmon.yml +++ b/config/streams/icon_esm_historical_mon/icon_esm_Lmon.yml @@ -31,7 +31,6 @@ ICONESMLmon : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_mon/icon_esm_Omon.yml b/config/streams/icon_esm_historical_mon/icon_esm_Omon.yml index f6d7a5010..d401c5534 100644 --- a/config/streams/icon_esm_historical_mon/icon_esm_Omon.yml +++ b/config/streams/icon_esm_historical_mon/icon_esm_Omon.yml @@ -46,7 +46,6 @@ ICONESMOmon : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/icon_esm_historical_mon/icon_esm_SImon.yml b/config/streams/icon_esm_historical_mon/icon_esm_SImon.yml index 3db96f552..baecdc81f 100644 --- a/config/streams/icon_esm_historical_mon/icon_esm_SImon.yml +++ b/config/streams/icon_esm_historical_mon/icon_esm_SImon.yml @@ -33,7 +33,6 @@ ICONESMSImon : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/iconart_6h/icon.yml b/config/streams/iconart_6h/icon.yml index 9dcd9e11c..3e5abf00f 100644 --- a/config/streams/iconart_6h/icon.yml +++ b/config/streams/iconart_6h/icon.yml @@ -35,7 +35,6 @@ ICONART : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 # sampling_rate : 0.2 diff --git a/config/streams/igra/igra.yml b/config/streams/igra/igra.yml index a7016e71c..791c46253 100644 --- a/config/streams/igra/igra.yml +++ b/config/streams/igra/igra.yml @@ -26,7 +26,6 @@ IGRA : net : linear dim_embed : 256 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 
4 pred_head : diff --git a/integration_tests/streams/era5_small.yml b/integration_tests/streams/era5_small.yml index 6910a427a..1d06d5308 100644 --- a/integration_tests/streams/era5_small.yml +++ b/integration_tests/streams/era5_small.yml @@ -32,7 +32,6 @@ ERA5 : net : linear dim_embed : 16 target_readout : - type : 'obs_value' num_layers : 2 num_heads : 2 pred_head : diff --git a/integration_tests/streams_multi/era5_small.yml b/integration_tests/streams_multi/era5_small.yml index 04d47ca99..b08ab0fa3 100644 --- a/integration_tests/streams_multi/era5_small.yml +++ b/integration_tests/streams_multi/era5_small.yml @@ -23,7 +23,6 @@ ERA5: net: linear dim_embed: 16 target_readout: - type: 'obs_value' num_layers: 2 num_heads: 2 pred_head: diff --git a/integration_tests/streams_multi/npp_atms.yml b/integration_tests/streams_multi/npp_atms.yml index 6affb1da1..f933b7453 100644 --- a/integration_tests/streams_multi/npp_atms.yml +++ b/integration_tests/streams_multi/npp_atms.yml @@ -18,7 +18,6 @@ NPPATMS : net : linear dim_embed : 32 target_readout : - type : 'obs_value' num_layers : 2 num_heads : 4 pred_head : diff --git a/integration_tests/streams_multi/synop.yml b/integration_tests/streams_multi/synop.yml index 461bde9ab..73eb55be1 100644 --- a/integration_tests/streams_multi/synop.yml +++ b/integration_tests/streams_multi/synop.yml @@ -18,7 +18,6 @@ SurfaceCombined : net : linear dim_embed : 32 target_readout : - type : 'obs_value' # token or obs_value num_layers : 2 num_heads : 4 pred_head : diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 11c6901fd..f7cb0ea22 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -58,7 +58,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # determine stream names once so downstream components use consistent keys self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] # separate embedding networks for differnt 
observation types - self.embed_engine = EmbeddingEngine(cf, self.sources_size, self.stream_names) + self.embed_engine = EmbeddingEngine(cf, self.sources_size) assert cf.ae_global_att_dense_rate == 1.0, "Local attention not adapted for register tokens" self.num_register_tokens = cf.num_register_tokens diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 55c985a12..a75331c6e 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -35,24 +35,19 @@ class EmbeddingEngine(torch.nn.Module): name: "EmbeddingEngine" - def __init__(self, cf: Config, sources_size, stream_names: list[str]) -> None: + def __init__(self, cf: Config, sources_size) -> None: """ Initialize the EmbeddingEngine with the configuration. :param cf: Configuration object containing parameters for the engine. :param sources_size: List of source sizes for each stream. - :param stream_names: Ordered list of stream identifiers aligned with cf.streams. """ super(EmbeddingEngine, self).__init__() self.cf = cf self.dtype = get_dtype(self.cf.mixed_precision_dtype) self.sources_size = sources_size # KCT:iss130, what is this? self.embeds = torch.nn.ModuleDict() - self.stream_names = list(stream_names) - - assert len(self.stream_names) == len(self.cf.streams), ( - "stream_names must align with cf.streams" - ) + self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] for i, (si, stream_name) in enumerate(zip(self.cf.streams, self.stream_names, strict=True)): if si.get("diagnostic", False) or self.sources_size[i] == 0: @@ -538,7 +533,6 @@ def __init__( tr_dim_head_proj, tr_mlp_hidden_factor, softcap, - tro_type, stream_name: str, ): """ @@ -550,7 +544,6 @@ def __init__( :param tr_dim_head_proj: Dimension for head projection. :param tr_mlp_hidden_factor: Hidden factor for the MLP layers. :param softcap: Softcap value for the attention layers. - :param tro_type: Type of target readout (e.g., "obs_value"). 
""" super(TargetPredictionEngineClassic, self).__init__() self.name = f"TargetPredictionEngine_{stream_name}" @@ -561,7 +554,6 @@ def __init__( self.tr_dim_head_proj = tr_dim_head_proj self.tr_mlp_hidden_factor = tr_mlp_hidden_factor self.softcap = softcap - self.tro_type = tro_type self.tte = torch.nn.ModuleList() for i in range(len(self.dims_embed) - 1): @@ -605,7 +597,7 @@ def __init__( MLP( self.dims_embed[i], self.dims_embed[i + 1], - with_residual=(self.cf.pred_dyadic_dims or self.tro_type == "obs_value"), + with_residual=True, hidden_factor=self.tr_mlp_hidden_factor, dropout_rate=0.1, # Assuming dropout_rate is 0.1 norm_type=self.cf.norm_type, @@ -646,7 +638,6 @@ def __init__( tr_dim_head_proj, tr_mlp_hidden_factor, softcap, - tro_type, stream_name: str, ): """ @@ -658,7 +649,6 @@ def __init__( :param tr_dim_head_proj: Dimension for head projection. :param tr_mlp_hidden_factor: Hidden factor for the MLP layers. :param softcap: Softcap value for the attention layers. - :param tro_type: Type of target readout (e.g., "obs_value"). 
the decoder_type decides the how the conditioning is done @@ -679,7 +669,6 @@ def __init__( self.tr_dim_head_proj = tr_dim_head_proj self.tr_mlp_hidden_factor = tr_mlp_hidden_factor self.softcap = softcap - self.tro_type = tro_type # For backwards compatibility @@ -771,7 +760,6 @@ def __init__( attention_kwargs=attention_kwargs, tr_dim_head_proj=tr_dim_head_proj, tr_mlp_hidden_factor=tr_mlp_hidden_factor, - tro_type=tro_type, mlp_norm_eps=self.cf.mlp_norm_eps, ) ) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a4d179443..fc08a31eb 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -322,39 +322,13 @@ def create(self) -> "Model": # extract and setup relevant parameters etc = si["embed_target_coords"] - tro_type = ( - si["target_readout"]["type"] if "type" in si["target_readout"] else "token" - ) - dim_embed = si["embed_target_coords"]["dim_embed"] - dim_out = max( - dim_embed, - si["token_size"] * self.targets_num_channels[i_stream], - ) tr = si["target_readout"] num_layers = tr["num_layers"] tr_mlp_hidden_factor = tr["mlp_hidden_factor"] if "mlp_hidden_factor" in tr else 2 tr_dim_head_proj = tr["dim_head_proj"] if "dim_head_proj" in tr else None softcap = tr["softcap"] if "softcap" in tr else 0.0 - if tro_type == "obs_value": - # fixed dimension for obs_value type - dims_embed = [ - si["embed_target_coords"]["dim_embed"] for _ in range(num_layers + 1) - ] - else: - if cf.pred_dyadic_dims: - coord_dim = self.geoinfo_sizes[i_stream] * si["token_size"] - dims_embed = torch.tensor( - [dim_out // 2**i for i in range(num_layers - 1, -1, -1)] + [dim_out] - ) - dims_embed[dims_embed < coord_dim] = dims_embed[ - torch.where(dims_embed >= coord_dim)[0][0] - ] - dims_embed = dims_embed.tolist() - else: - dims_embed = torch.linspace( - dim_embed, dim_out, num_layers + 1, dtype=torch.int32 - ).tolist() + dims_embed = [si["embed_target_coords"]["dim_embed"] for _ in range(num_layers + 1)] if is_root(): 
logger.info("{} :: coord embed: :: {}".format(si["name"], dims_embed)) @@ -403,7 +377,6 @@ def create(self) -> "Model": tr_dim_head_proj, tr_mlp_hidden_factor, softcap, - tro_type, stream_name=stream_name, ) From c0545520b6a063611d2e9ec90eefc5c3d12b9e6f Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Mon, 26 Jan 2026 10:21:34 +0100 Subject: [PATCH 15/44] add default streams + fix lead time error (#1670) * add default streams + fix lead time error * update config * Correct a bug creating aggr issues on scores (#1685) --------- Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> --- config/evaluate/eval_config.yml | 100 +++++++++++------- .../weathergen/evaluate/io/wegen_reader.py | 31 +++++- .../src/weathergen/evaluate/run_evaluation.py | 12 ++- src/weathergen/model/model_interface.py | 2 +- 4 files changed, 95 insertions(+), 50 deletions(-) diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index 24bca712a..36f235cd4 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -34,49 +34,59 @@ evaluation: num_processes: 0 #options: int, "auto", 0 means no parallelism (default) # baseline: "ar40mckx" + +default_streams: + ERA5: + channels: ["2t", "10u"] #, "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" #supported: "all", "mean", [0,1,2] + plotting: + sample: [1, 3] + forecast_step: [1,3, 2] + ensemble: [0,2,5] #supported: "all", "mean", [0,1,2] + plot_maps: true + plot_target: false + plot_histograms: true + plot_animations: true + CERRA: + channels: ["z_500", "t_850", "u_850"] #, "blah"] + evaluation: + forecast_step: "all" + sample: "all" + plotting: + sample: [2, 3, 0] + forecast_step: [1,3, 4, 5] + plot_maps: true + plot_target: false + plot_histograms: true + plot_animations: true + run_ids : ar40mckx: label: "pretrained model ar40mckx" results_base_dir : "./results/" 
- mini_epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u"] #, "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] - evaluation: - forecast_step: "all" - sample: "all" - ensemble: "all" #supported: "all", "mean", [0,1,2] - plotting: - sample: [1, 3] - forecast_step: [1,3, 2] #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) - ensemble: [0,2,5] #supported: "all", "mean", [0,1,2] - plot_maps: true - plot_target: false - plot_histograms: true - plot_animations: true - CERRA: - channels: ["z_500", "t_850", "u_850"] #, "blah"] - evaluation: - forecast_step: "all" - sample: "all" - plotting: - sample: [2, 3, 0] - forecast_step: [1,3, 4, 5] - plot_maps: true - plot_target: false - plot_histograms: true - plot_animations: true + #NEW: if "streams" is not specified, the default streams are used + c8g5katp: label: "2 steps window" results_base_dir : "./results/" - mini_epoch: 0 - rank: 0 + #NEW: if "streams" is not specified, the default streams are used + + ############################### + + #How to define custom variables different than default_streams (default_streams will be ignored for this run_id): + + jjc9ym62: + label: "2 steps window" + results_base_dir : "./results/" + epoch: 1 #optional: if not specified epoch 0 (in inference it is always 0) is used + rank: 2 #optional: if not specified rank 0 is used streams: ERA5: - #climatology_path: "/aifs-ea-an-oper-0001-mars-o96-1980-2020-6h-v6_climatology.zarr" - channels: ["2t", "10u", "10v"] #, "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + channels: ["2t", "10u", "10v"] evaluation: forecast_step: "all" sample: "all" @@ -90,17 +100,20 @@ run_ids : plot_histograms: true plot_animations: true + + + #WeatherGeneratorMerge example: + ############################### + #Example of syntax to stack multiple runs over the ensemble dimension - merge_test: <------ This is the new run_id name of the merged dataset. NB. 
you always need to specify one - type: "merge" <------- VERY IMPORTANT + merge_test: #<------ This is the new run_id name of the merged dataset. NB. you always need to specify one + type: "merge" # <------- VERY IMPORTANT merge_run_ids: - so67dku4 - c9cg8ql3 - metrics_dir: "./merge_test/metrics/" <------- VERY IMPORTANT + metrics_dir: "./merge_test/metrics/" #<------- VERY IMPORTANT label: "Merged Results" results_base_dir : "./results/" - epoch: 0 - rank: 0 streams: ERA5: channels: ["z_500", "t_850", "u_850", "v_850", "q_850"] @@ -109,13 +122,14 @@ run_ids : sample: [0, 1, 2, 3] ensemble: "all" + #WeatherGeneratorJSON example: + ############################## + #Example of syntax to run over pre-computed scores when the .zarr output is not available anymore so67dku1: - type: "json" + type: "json" # <------- VERY IMPORTANT label: "WeatherGenerator" results_base_dir : "./results/" - epoch: 0 - rank: 0 streams: ERA5: channels: ["z_500", "t_850", "u_850", "v_850", "q_850"] @@ -124,6 +138,10 @@ run_ids : sample: [0, 1, 2, 3] ensemble: "all" #supported + + # CSV Reader example: + ############################### + #ADVANCED (please handle with care): example of how to use the csv reader to Plot PanguWeather scores computed with quaver pangu: type: "csv" diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 4e2ee77c0..4c71bfa94 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -40,8 +40,8 @@ def __init__( super().__init__(eval_cfg, run_id, private_paths, verbose) # TODO: remove backwards compatibility to "epoch" in Feb. 
2026 - self.mini_epoch = eval_cfg.get("mini_epoch", eval_cfg.get("epoch")) - self.rank = eval_cfg.rank + self.mini_epoch = eval_cfg.get("mini_epoch", 0) + self.rank = eval_cfg.get("rank", 0) # Load model configuration and set (run-id specific) directories self.inference_cfg = self.get_inference_config() @@ -515,15 +515,32 @@ def get_data( ######## reader utils ######## - def add_lead_time_coord(self, da: xr.DataArray) -> xr.DataArray: + def add_lead_time_coord(self, da: xr.DataArray, sample_dim="sample") -> xr.DataArray: """ Add lead_time coordinate computed as: valid_time - source_interval_end lead_time has dims (sample, ipoint) and dtype timedelta64[ns]. + + Parameters + ---------- + da : + Input DataArray + sample_dim : + The name of the sample dimension (default is "sample") which should be kept. + Collapse over the others. + Returns + ------- + Returns a Dataset with an added lead_time coordinate. """ - lead_time = np.unique(da["valid_time"]) - da["source_interval_start"] + vt = da["valid_time"] + sis = da["source_interval_start"] + + vt_reduced = vt.min(dim=[d for d in vt.dims if d != sample_dim]) + + lead_time = vt_reduced - sis + return da.assign_coords(lead_time=lead_time) def scale_z_channels(self, data: xr.DataArray, stream: str) -> xr.DataArray: @@ -665,7 +682,7 @@ def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: sort_idx = np.lexsort((ref_lon.values, ref_lat.values)) npoints = sort_idx.size aligned = [] - for a in ref: + for i, a in enumerate(ref): a_sorted = a.isel(ipoint=sort_idx) a_sorted = a_sorted.assign_coords( @@ -673,6 +690,10 @@ def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: lat=("ipoint", ref_lat.values[sort_idx]), lon=("ipoint", ref_lon.values[sort_idx]), ) + + if "sample" not in a_sorted.dims: + a_sorted = a_sorted.expand_dims(sample=[i]) + aligned.append(a_sorted) return xr.concat(aligned, dim="sample") diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py 
b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 8c73c93e7..3382ba24d 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -296,6 +296,7 @@ def evaluate_from_config( global_plotting_opts = cfg.get("global_plotting_options", {}) use_parallel = cfg.evaluation.get("num_processes", 0) verbose = cfg.evaluation.get("verbose", True) + default_streams = cfg.get("default_streams", {}) if use_parallel == "auto": num_processes = mp.cpu_count() @@ -319,9 +320,14 @@ def evaluate_from_config( # Build tasks per stream for run_id, run in runs.items(): type_ = run.get("type", "zarr") - reader = get_reader( - type_, run, run_id, private_paths, cfg.evaluation.regions, cfg.evaluation.metrics - ) + + if "streams" not in run: + run["streams"] = default_streams + + regions = cfg.evaluation.regions + metrics = cfg.evaluation.metrics + + reader = get_reader(type_, run, run_id, private_paths, regions, metrics) for stream in reader.streams: tasks.append( diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index df483982a..a0e05d24f 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -62,7 +62,7 @@ def init_model_and_shard( if re.fullmatch(cf.freeze_modules, name) is not None: logger.info(f"Froze weights {name}") freeze_weights(module) - + # TODO: this should be handled in the encoder to be close where q_cells is defined if "q_cells" in cf.freeze_modules: model.encoder.q_cells.requires_grad = False From 3aeb324706253b50ba0797792292f338a71a4b4d Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Mon, 26 Jan 2026 10:43:15 +0100 Subject: [PATCH 16/44] Update normlise output flag (#1681) --- config/default_config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 4f6eacb5e..d5eb4067e 
100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -212,7 +212,7 @@ validation_config: # number of samples that are written num_samples: 8, # write samples in normalized model space - normalized_samples: True, + normalized_samples: False, # output streams to write; default all streams: null, } @@ -250,4 +250,4 @@ wgtags: # *** Experiment-specific tags *** # All extra tags (including lists, dictionaries, etc.) are treated # as strings by mlflow, so treat all extra tags as simple string key: value pairs. - grid: null \ No newline at end of file + grid: null From 3c183dfffa2829d413f5e56aa1019d984b4d7bf6 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Mon, 26 Jan 2026 13:00:45 +0100 Subject: [PATCH 17/44] Make ratio plot and bar plots run with just 1 run_id (#1672) * add default streams + fix lead time error * update config * update ratio plots and bar plots for single run * fix title * Update config Added support information for forecast_step configuration. 
--------- Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> --- config/evaluate/eval_config.yml | 2 +- .../evaluate/plotting/plot_utils.py | 4 +- .../weathergen/evaluate/plotting/plotter.py | 55 +++++++++++++------ 3 files changed, 42 insertions(+), 19 deletions(-) diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index 36f235cd4..83e0a7b84 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -44,7 +44,7 @@ default_streams: ensemble: "all" #supported: "all", "mean", [0,1,2] plotting: sample: [1, 3] - forecast_step: [1,3, 2] + forecast_step: [1,3, 2] #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) ensemble: [0,2,5] #supported: "all", "mean", [0,1,2] plot_maps: true plot_target: false diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 54e3d8572..933aa7ea1 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -334,7 +334,7 @@ def score_card_metric_region( selected_data.append(data) run_ids.append(run_id) - if selected_data and len(selected_data) > 1.0: + if selected_data: sc_plotter._logger.info(f"Creating score cards for {metric} - {region} - {stream}.") name = "_".join([metric, region, stream]) sc_plotter.plot(selected_data, run_ids, metric, channels_set, name) @@ -380,7 +380,7 @@ def bar_plot_metric_region( selected_data.append(data) run_ids.append(run_id) - if selected_data and len(selected_data) > 1.0: + if selected_data: br_plotter._logger.info(f"Creating bar plots for {metric} - {region} - {stream}.") name = "_".join([metric, region, stream]) br_plotter.plot(selected_data, run_ids, metric, channels_set, name) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 
ba07df48a..cfbe29037 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -981,18 +981,20 @@ def ratio_plot( data_list, label_list = self._check_lengths(data, labels) if len(data_list) < 2: - self._logger.warning("Ratio plot requires at least two datasets to compare. Skipping.") - return - - baseline_name = self.baseline - baseline_idx = run_ids.index(self.baseline) if self.baseline in run_ids else None - if baseline_idx is not None: - self._logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") - baseline = data_list[baseline_idx] - + baseline = xr.full_like(data_list[0], 1.0) + baseline_name = "ones" + descr = "scores" else: - baseline_name = run_ids[0] - baseline = data_list[0] + descr = "ratio_plot" + baseline_name = self.baseline + baseline_idx = run_ids.index(self.baseline) if self.baseline in run_ids else None + if baseline_idx is not None: + self._logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") + baseline = data_list[baseline_idx] + + else: + baseline_name = run_ids[0] + baseline = data_list[0] ref_raw = self._preprocess_data(baseline, x_dim, verbose=False) @@ -1025,11 +1027,14 @@ def ratio_plot( linestyle="-", ) - parts = ["ratio_plot", tag] + parts = [descr, tag] name = "_".join(filter(None, parts)) plt.xticks(rotation=90, ha="right") plt.grid(True, linestyle="--", color="gray", alpha=0.2) - title = f"Ratio plot {tag.split('_')[0]} - {tag.split('_')[-1]} (baseline: {baseline_name})" + title = ( + f"{descr.replace('_', ' ')} {tag.split('_')[0]} -" + f" {tag.split('_')[-1]} (baseline: {baseline_name})" + ) self._plot_base(fig, name, x_dim, y_dim, print_summary, line=1.0, vlines=True, title=title) def heat_map( @@ -1525,7 +1530,7 @@ def plot( fig, ax = plt.subplots( 1, - len(runs) - 1, + len(runs) - 1 if len(runs) > 1 else 1, figsize=(5 * len(runs), 2 * len(channels)), dpi=self.dpi_val, squeeze=False, @@ -1536,6 +1541,13 @@ 
def plot( baseline_idx = runs.index(self.baseline) runs = [runs[baseline_idx]] + runs[:baseline_idx] + runs[baseline_idx + 1 :] data = [data[baseline_idx]] + data[:baseline_idx] + data[baseline_idx + 1 :] + elif len(runs) < 2: + _logger.warning( + "BarPlots:: Less than two runs provided. Generating bar plot against ones." + ) + ones_array = xr.full_like(data[0], 1.0) + runs = [""] + runs + data = [ones_array] + data for run_index in range(1, len(runs)): ratio_score, channels_per_comparison = self.calc_ratio_per_run_id( @@ -1554,10 +1566,20 @@ def plot( np.arange(len(ratio_score)), labels=channels_per_comparison ) ax[run_index - 1].invert_yaxis() - ax[run_index - 1].set_xlabel( + + xlabel = ( f"Relative {data[0].coords['metric'].item().upper()}: " f"Target Model ({runs[run_index]}) / Reference Model ({runs[0]})" ) + + if len(runs) == 2 and runs[0] == "": + xlabel = xlabel.replace("Relative ", "") + xlabel = xlabel.replace( + f"Target Model ({runs[run_index]}) / Reference Model ({runs[0]})", + f"Model ({runs[run_index]})", + ) + + ax[run_index - 1].set_xlabel(xlabel) else: ax[run_index - 1].set_visible(False) # or annotate as missing # Or show a message: @@ -1571,7 +1593,7 @@ def plot( ) self._logger.info(f"Saving bar plots to: {self.out_plot_dir}") - parts = ["bar_plot_compare", tag] + runs + parts = ["bar_plot", tag] + runs name = "_".join(filter(None, parts)) plt.savefig( f"{self.out_plot_dir.joinpath(name)}.{self.image_format}", @@ -1611,6 +1633,7 @@ def calc_ratio_per_run_id( """ ratio_score = [] channels_per_comparison = [] + for _, var in enumerate(channels): if var not in data[0].channel.values or var not in data[run_index].channel.values: continue From 25948f3cdab57976c967cb3c1ed0fb93829df5a4 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa <18899420+mtar@users.noreply.github.com> Date: Mon, 26 Jan 2026 14:51:19 +0100 Subject: [PATCH 18/44] slurm script inference (#1675) * add argument * check stage argument * removed unnecessary code * arbitrary position 
arguments * Fix error text * get stage info from environment variable. * Update run_train.py --------- Co-authored-by: Simon Grasse --- src/weathergen/run_train.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 2af04edfe..a4ffb1f5b 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -12,6 +12,7 @@ """ import logging +import os import pdb import sys import time @@ -190,9 +191,20 @@ def train_with_args(argl: list[str], stream_dir: str | None): if __name__ == "__main__": - # Entry point for slurm script. - # Check whether --from_run_id passed as argument. - if any("--from_run_id" in arg for arg in sys.argv): - train_continue() + try: + stage = os.environ.get("WEATHERGEN_STAGE") + except KeyError as e: + msg = "missing environment variable 'WEATHERGEN_STAGE'" + raise ValueError(msg) from e + + if stage == "train": + # Entry point for slurm script. + # Check whether --from_run_id passed as argument. 
+ if any("--from_run_id" in arg for arg in sys.argv): + train_continue() + else: + train() + elif stage == "inference": + inference() else: - train() + logger.error("No stage was found.") From 333f662e12b225cedd992a1f97a89500ecb9ac61 Mon Sep 17 00:00:00 2001 From: Belkis Asma SEMCHEDDINE Date: Mon, 26 Jan 2026 15:31:27 +0100 Subject: [PATCH 19/44] clean-up in config.py focusing on shared path (#1579) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * caching get_shared_wg_path() * renaming get_path_output to get_path_results * model and results paths from get_shared_wg_path() and removed _get_config_attribute() * marking get_shared_wg_path() as private * removing set_path() * fixed call to _get_shared_wg_path * fixed import, code clean-up, change caching decorator * changed way of caching _get_shared_wg_base_path * fixed typing error * changes in Refactor shared WG path handling and model config I/O - Simplify get_path_model/get_path_run to always resolve via _get_shared_wg_path() - Change _get_shared_wg_path() to cached, argument-free helper returning the shared working dir from private config - Adjust model config save/load to build filenames relative to the run’s model directory instead of passing parent paths around - Update load_run_config and load_merge_configs to use new path helpers and improve assertion/log messages - Replace internal _get_shared_wg_path("results") usages with get_path_run() in wegen_reader and train_logger * fixed base_path in metrics_path * fixed forgotten config.general * fixed lint raised issues * Improve path handling and add missing docstrings - Add docstrings to 10+ utility functions for better documentation - Refactor load_run_config to improve path construction logic - Move mini_epoch string formatting from _get_model_config_file_read_name to caller for better separation of concerns - Add validation for mini_epoch_str format with descriptive error messages - Fix multi-line docstring format 
in _load_private_conf * fixed line too long * reverting to previous _get_model_config_file_read_name() * pretty fix for _get_model_config_file_read_name * pretty fix for _get_model_config_file_read_name * removed unused/undefined path --- .../common/src/weathergen/common/config.py | 137 +++++++----------- .../weathergen/evaluate/io/wegen_reader.py | 8 +- src/weathergen/utils/train_logger.py | 4 +- src/weathergen/utils/validation_io.py | 2 +- 4 files changed, 56 insertions(+), 95 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 2926e2037..58c4131b9 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import functools import io import json import logging @@ -84,6 +85,7 @@ def str_to_datetime64(s: str | int | np.datetime64) -> np.datetime64: def _sanitize_start_end_time_keys(sub_conf): + """Convert start_date and end_date keys to datetime resolvers.""" time_keys = ["start_date", "end_date"] for key in time_keys: if key in sub_conf: @@ -93,6 +95,7 @@ def _sanitize_start_end_time_keys(sub_conf): def _sanitize_delta_time_keys(sub_conf): + """Convert time delta keys to timedelta resolvers.""" delta_keys = ["time_window_step", "time_window_len"] for key in delta_keys: if key in sub_conf: @@ -134,6 +137,7 @@ def _sanitize_time_keys(conf: Config) -> Config: def _strip_interpolation(conf: Config) -> Config: + """Remove OmegaConf interpolations and convert timedelta/datetime objects to strings.""" stripped = OmegaConf.create() for key in list(conf.keys()): if key.startswith("_"): @@ -165,12 +169,14 @@ def _strip_interpolation(conf: Config) -> Config: def get_run_id(): + """Generate a random 8-character run ID.""" s1 = string.ascii_lowercase s2 = string.ascii_lowercase + string.digits return 
"".join(random.sample(s1, 1)) + "".join(random.sample(s2, 7)) def format_cf(config: Config) -> str: + """Format config as a human-readable string.""" stream = io.StringIO() clean_cf = _strip_interpolation(config) for key, value in clean_cf.items(): @@ -188,15 +194,14 @@ def format_cf(config: Config) -> str: def save(config: Config, mini_epoch: int | None): """Save current config into the current runs model directory.""" - path_models = Path(config.model_path) # save in directory with model files - dirname = path_models / config.general.run_id + dirname = get_path_model(config) dirname.mkdir(exist_ok=True, parents=True) - fname = _get_model_config_file_write_name(path_models, config.general.run_id, mini_epoch) + fname = _get_model_config_file_write_name(config.general.run_id, mini_epoch) json_str = json.dumps(OmegaConf.to_container(_strip_interpolation(config))) - with fname.open("w") as f: + with (dirname / fname).open("w") as f: f.write(json_str) @@ -213,24 +218,27 @@ def load_run_config(run_id: str, mini_epoch: int | None, model_path: str | None) Returns: Configuration object loaded from the specified run and mini_epoch. """ + # Loading path if Path(run_id).exists(): # load from the full path if a full path is provided fname = Path(run_id) _logger.info(f"Loading config from provided full run_id path: {fname}") else: # Load model config here. 
In case model_path is not provided, get it from private conf if model_path is None: - pconf = _load_private_conf() - model_path = _get_config_attribute( - config=pconf, attribute_name="model_path", fallback="models" - ) - path = Path(model_path) - fname = _get_model_config_file_read_name(path, run_id, mini_epoch) + path = get_path_model(run_id=run_id) + else: + path = Path(model_path) + + fname = path / _get_model_config_file_read_name(run_id, mini_epoch) + if not fname.exists(): + # Fallback for old naming convention + # TODO remove compatibility + fname = path / _get_model_config_file_read_name(run_id, mini_epoch, use_old_name=True) assert fname.exists(), ( "The fallback path to the model does not exist. Please provide a `model_path`.", fname, ) - - _logger.info(f"Loading config from specified run_id and mini_epoch: {fname}") + _logger.info(f"Loading config from specified run_id and mini_epoch: {fname}") with fname.open() as f: json_str = f.read() @@ -241,7 +249,8 @@ def load_run_config(run_id: str, mini_epoch: int | None, model_path: str | None) return _apply_fixes(config) -def _get_model_config_file_write_name(path: Path, run_id: str, mini_epoch: int | None): +def _get_model_config_file_write_name(run_id: str, mini_epoch: int | None): + """Generate the filename for writing a model config file.""" if mini_epoch is None: mini_epoch_str = "" elif mini_epoch == -1: @@ -249,21 +258,20 @@ def _get_model_config_file_write_name(path: Path, run_id: str, mini_epoch: int | else: mini_epoch_str = f"_chkpt{mini_epoch:05d}" - return path / run_id / f"model_{run_id}{mini_epoch_str}.json" - + return f"model_{run_id}{mini_epoch_str}.json" -def _get_model_config_file_read_name(path: Path, run_id: str, mini_epoch: int | None): +def _get_model_config_file_read_name(run_id: str, mini_epoch: int | None, use_old_name=False): + """Generate the filename for reading a model config file.""" if mini_epoch is None: mini_epoch_str = "" elif mini_epoch == -1: mini_epoch_str = "_latest" - 
elif (path / run_id / f"model_{run_id}_epoch{mini_epoch:05d}.json").exists(): - mini_epoch_str = f"_epoch{mini_epoch:05d}" + elif use_old_name: + mini_epoch_str = f"_epoch{mini_epoch:05d}" # TODO remove compatibility else: mini_epoch_str = f"_chkpt{mini_epoch:05d}" - return path / run_id / f"model_{run_id}{mini_epoch_str}.json" - + return f"model_{run_id}{mini_epoch_str}.json" def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path: """ @@ -361,16 +369,12 @@ def load_merge_configs( c = _load_overwrite_conf(overwrite) c = _load_streams_in_config(c) overwrite_configs.append(c) - private_config = set_paths(private_config) if from_run_id is None: base_config = _load_base_conf(base) else: - base_config = load_run_config( - from_run_id, mini_epoch, private_config.get("model_path", None) - ) + base_config = load_run_config(from_run_id, mini_epoch, None) from_run_id = base_config.general.run_id - with open_dict(base_config): base_config.from_run_id = from_run_id # use OmegaConf.unsafe_merge if too slow @@ -478,9 +482,9 @@ def _load_overwrite_conf(overwrite: Path | dict | DictConfig) -> DictConfig: def _load_private_conf(private_home: Path | None = None) -> DictConfig: - "Return the private configuration." - "If none, take it from the environment variable WEATHERGEN_PRIVATE_CONF." - + """ + Return the private configuration from file or environment variable WEATHERGEN_PRIVATE_CONF. + """ env_script_path = _REPO_ROOT.parent / "WeatherGenerator-private" / "hpc" / "platform-env.py" if private_home is not None and private_home.is_file(): @@ -549,6 +553,7 @@ def _load_base_conf(base: Path | Config | None) -> Config: def load_streams(streams_directory: Path) -> list[Config]: + """Load all stream configurations from a directory.""" # TODO: might want to put this into config later instead of hardcoding it here... 
streams_history = { "streams_anemoi": "era5_1deg", @@ -613,43 +618,23 @@ def load_streams(streams_directory: Path) -> list[Config]: return list(streams.values()) -def set_paths(config: Config) -> Config: - """Set the configs run_path model_path attributes to default values if not present.""" - config = config.copy() - config.run_path = _get_config_attribute( - config=config, attribute_name="run_path", fallback="results" - ) - config.model_path = _get_config_attribute( - config=config, attribute_name="model_path", fallback="models" - ) - - return config - - -def _get_config_attribute(config: Config, attribute_name: str, fallback: str) -> str: - """Get an attribute from a Config. If not available, fall back to path_shared_working_dir - concatenated with the desired fallback path. Raise an error if neither the attribute nor a - fallback is specified.""" - attribute = OmegaConf.select(config, attribute_name) - fallback_root = OmegaConf.select(config, "path_shared_working_dir") - assert attribute is not None or fallback_root is not None, ( - f"Must specify `{attribute_name}` in config if `path_shared_working_dir` is None in config" - ) - attribute = attribute if attribute else fallback_root + fallback - return attribute - - def get_path_run(config: Config) -> Path: - """Get the current runs run_path for storing run results and logs.""" - return Path(config.run_path) / config.general.run_id + """Get the current runs results_path for storing run results and logs.""" + return _get_shared_wg_path() / "results" / config.general.run_id -def get_path_model(config: Config) -> Path: +def get_path_model(config: Config | None = None, run_id: str | None = None) -> Path: """Get the current runs model_path for storing model checkpoints.""" - return Path(config.model_path) / config.general.run_id + if config or run_id: + run_id = run_id if run_id else config.general.run_id + else: + msg = f"Missing run_id and cannot infer it from config: {config}" + raise ValueError(msg) + return 
_get_shared_wg_path() / "models" / run_id -def get_path_output(config: Config, mini_epoch: int) -> Path: +def get_path_results(config: Config, mini_epoch: int) -> Path: + """Get the path to validation results for a specific mini_epoch and rank.""" ext = StoreType(config.zarr_store).value # validate extension base_path = get_path_run(config) fname = f"validation_chkpt{mini_epoch:05d}_rank{config.rank:04d}.{ext}" @@ -657,33 +642,11 @@ def get_path_output(config: Config, mini_epoch: int) -> Path: return base_path / fname -def get_shared_wg_path(local_path: str | Path) -> Path: - """ - Resolves a local, relative path to an absolute path within the configured shared working - directory. - - This utility function retrieves the base path defined for the shared WeatherGenerator (WG) - working directory from the private configuration and appends the provided local path segment. - - Parameters - ---------- - local_path : str or Path - The local or relative path segment (e.g., 'results', 'models', 'output') that needs - to be located within the shared working directory structure. - - Returns - ------- - Path - The absolute pathlib.Path object pointing to the specified location - within the shared working directory. - - Notes - ----- - The shared working directory base is retrieved from the 'path_shared_working_dir' - key found in the private configuration loaded by `_load_private_conf()`. 
- """ - pcfg = _load_private_conf() - return Path(pcfg.get("path_shared_working_dir")) / local_path +@functools.cache +def _get_shared_wg_path() -> Path: + """Get the shared working directory for WeatherGenerator.""" + private_config = _load_private_conf() + return Path(private_config.get("path_shared_working_dir")) def validate_forecast_policy_and_steps(cf: OmegaConf): @@ -733,4 +696,4 @@ def validate_forecast_policy_and_steps(cf: OmegaConf): else True ), provide_forecast_policy + valid_forecast_policies + valid_forecast_steps else: - raise TypeError(valid_forecast_steps) + raise TypeError(valid_forecast_steps) \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 4c71bfa94..6721fed8f 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -20,7 +20,7 @@ # Local application / package from weathergen.common.config import ( - get_shared_wg_path, + get_path_run, load_merge_configs, load_run_config, ) @@ -46,10 +46,8 @@ def __init__( self.inference_cfg = self.get_inference_config() if not self.results_base_dir: - self.results_base_dir = Path(get_shared_wg_path("results")) - self._logger.info( - f"Results directory obtained from private config: {self.results_base_dir}" - ) + self.results_base_dir = get_path_run(self.inference_cfg) + self._logger.info(f"Results directory obtained from private config: {self.results_base_dir}") else: self._logger.info(f"Results directory parsed: {self.results_base_dir}") diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 929a78ea9..2812134a0 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -91,7 +91,7 @@ def log_metrics(self, stage: Stage, metrics: dict[str, float]) -> None: # but we can probably do better and rely for example on the logging 
module. metrics_path = get_train_metrics_path( - base_path=Path(self.cf.run_path), run_id=self.cf.general.run_id + base_path=config.get_path_run(self.cf).parent, run_id=self.cf.general.run_id ) with open(metrics_path, "ab") as f: s = json.dumps(clean_metrics) + "\n" @@ -144,7 +144,7 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: ) run_id = cf.general.run_id - result_dir_base = Path(cf.run_path) + result_dir_base = config.get_path_run(cf) result_dir = result_dir_base / run_id fname_log_train = result_dir / f"{run_id}_train_log.txt" fname_log_val = result_dir / f"{run_id}_val_log.txt" diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 69be9854b..4365492e0 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -158,6 +158,6 @@ def write_output( sample_start, val_cfg.get("window_offset_prediction", 0), ) - with zarrio_writer(config.get_path_output(cf, mini_epoch)) as zio: + with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: for subset in data.items(): zio.write_zarr(subset) From b37706c18e2812d6227f3cd68c186e2c93ed9dee Mon Sep 17 00:00:00 2001 From: Michael Tarnawa <18899420+mtar@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:46:50 +0100 Subject: [PATCH 20/44] [infra] consistent cli options (#1668) * replace '_' with '-' * cli options underscore to dash * change underscores to hyphens * rename options in cli unit test --- integration_tests/jepa1_test.py | 2 +- integration_tests/small1_test.py | 14 ++++----- integration_tests/small_multi_stream_test.py | 10 +++---- src/weathergen/run_train.py | 4 +-- src/weathergen/utils/cli.py | 24 ++++++++-------- tests/test_cli.py | 30 ++++++++++---------- 6 files changed, 42 insertions(+), 42 deletions(-) diff --git a/integration_tests/jepa1_test.py b/integration_tests/jepa1_test.py index 21df79626..f2959f3c9 100644 --- a/integration_tests/jepa1_test.py +++ 
b/integration_tests/jepa1_test.py @@ -52,7 +52,7 @@ def test_train(setup, test_run_id): train_with_args( [ f"--config={WEATHERGEN_HOME}/integration_tests/jepa1.yaml" ] + [ - "--run_id", + "--run-id", test_run_id, ], f"{WEATHERGEN_HOME}/config/streams/streams_test/", diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py index 797c6d501..d3c6e4024 100644 --- a/integration_tests/small1_test.py +++ b/integration_tests/small1_test.py @@ -52,7 +52,7 @@ def test_train(setup, test_run_id): train_with_args( f"--config={WEATHERGEN_HOME}/integration_tests/small1.yaml".split() + [ - "--run_id", + "--run-id", test_run_id, ], f"{WEATHERGEN_HOME}/config/streams/streams_test/", @@ -69,11 +69,11 @@ def test_train(setup, test_run_id): def infer(run_id): logger.info("run inference") inference_from_args( - ["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"] + ["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] + [ - "--from_run_id", + "--from-run-id", run_id, - "--run_id", + "--run-id", run_id, "--config", f"{WEATHERGEN_HOME}/integration_tests/small1.yaml", @@ -84,11 +84,11 @@ def infer(run_id): def infer_with_missing(run_id): logger.info("run inference") inference_from_args( - ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"] + ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] + [ - "--from_run_id", + "--from-run-id", run_id, - "--run_id", + "--run-id", run_id, "--config", f"{WEATHERGEN_HOME}/integration_tests/small1.yaml", diff --git a/integration_tests/small_multi_stream_test.py b/integration_tests/small_multi_stream_test.py index dbcc3a3bc..3c7a1e4d6 100644 --- a/integration_tests/small_multi_stream_test.py +++ b/integration_tests/small_multi_stream_test.py @@ -61,7 +61,7 @@ def test_train_multi_stream(setup, test_run_id): train_with_args( 
f"--config={WEATHERGEN_HOME}/integration_tests/small_multi_stream.yaml".split() + [ - "--run_id", + "--run-id", test_run_id, ], f"{WEATHERGEN_HOME}/integration_tests/streams_multi/", @@ -79,13 +79,13 @@ def infer_multi_stream(run_id): """Run inference for multi-stream model.""" logger.info("run multi-stream inference") inference_from_args( - ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"] + ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] + [ - "--from_run_id", + "--from-run-id", run_id, - "--run_id", + "--run-id", run_id, - "--streams_output", + "--streams-output", "ERA5", "SurfaceCombined", "NPPATMS", diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index a4ffb1f5b..58655c17b 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -199,8 +199,8 @@ def train_with_args(argl: list[str], stream_dir: str | None): if stage == "train": # Entry point for slurm script. - # Check whether --from_run_id passed as argument. - if any("--from_run_id" in arg for arg in sys.argv): + # Check whether --from-run-id passed as argument. + if any("--from-run-id" in arg for arg in sys.argv): train_continue() else: train() diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py index e7fb2039d..d8f116431 100644 --- a/src/weathergen/utils/cli.py +++ b/src/weathergen/utils/cli.py @@ -18,7 +18,7 @@ def get_continue_parser() -> argparse.ArgumentParser: _add_model_loading_params(parser) parser.add_argument( - "--finetune_forecast", + "--finetune-forecast", action="store_true", help=( "Fine tune for forecasting. It overwrites some of the Config settings. " @@ -36,14 +36,14 @@ def get_inference_parser() -> argparse.ArgumentParser: _add_general_arguments(parser) parser.add_argument( - "--start_date", + "--start-date", "-start", type=_format_date, default="2022-10-01", help="Start date for inference. 
Format must be parsable with pd.to_datetime.", ) parser.add_argument( - "--end_date", + "--end-date", "-end", type=_format_date, default="2022-12-01", @@ -53,13 +53,13 @@ def get_inference_parser() -> argparse.ArgumentParser: "--samples", type=int, default=10000000, help="Number of inference samples." ) parser.add_argument( # behaviour changed => implies default=False - "--save_samples", + "--save-samples", type=bool, default=True, help="Toggle saving of samples from inference. Default True", ) parser.add_argument( - "--streams_output", + "--streams-output", nargs="+", help="Output streams during inference.", ) @@ -79,7 +79,7 @@ def _format_date(date: str) -> str: def _add_general_arguments(parser: argparse.ArgumentParser): parser.add_argument( - "--private_config", + "--private-config", type=Path, default=None, help=( @@ -95,7 +95,7 @@ def _add_general_arguments(parser: argparse.ArgumentParser): help="Optional experiment specfic configuration files in ascending order of precedence.", ) parser.add_argument( - "--run_id", + "--run-id", type=str, help=( "The run id for this run." @@ -126,7 +126,7 @@ def _add_general_arguments(parser: argparse.ArgumentParser): def _add_model_loading_params(parser: argparse.ArgumentParser): parser.add_argument( "-id", - "--from_run_id", + "--from-run-id", required=True, help=( "Start inference or continue training from the WeatherGenerator" @@ -135,18 +135,18 @@ def _add_model_loading_params(parser: argparse.ArgumentParser): ) parser.add_argument( "-e", - "--mini_epoch", + "--mini-epoch", type=int, default=-1, help=( - "Mini_epoch of pretrained WeatherGenerator model used" + "Mini-epoch of pretrained WeatherGenerator model used" " (Default -1 corresponds to the last checkpoint)." ), ) parser.add_argument( - "--reuse_run_id", + "--reuse-run-id", action="store_true", - help="Use the id given via --from_run_id also for the current run. " + help="Use the id given via --from-run-id also for the current run. 
" "The storage location for artifacts will be reused as well. " "This might overwrite artifacts from previous runs.", ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 8e47df184..e652c7924 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,10 +6,10 @@ DATE_FORMATS = ["2022-12-01T00:00:00", "20221201", "2022-12-01", "12.01.2022"] EXPECTED_DATE_STR = "202212010000" -MODEL_LOADING_ARGS = ["from_run_id", "mini_epoch", "reuse_run_id"] -GENERAL_ARGS = ["config", "private_config", "options", "run_id"] +MODEL_LOADING_ARGS = ["from-run-id", "mini-epoch", "reuse-run-id"] +GENERAL_ARGS = ["config", "private-config", "options", "run-id"] MODEL_LOADING_PARSERS = [cli.get_continue_parser(), cli.get_inference_parser()] -BASIC_ARGLIST = ["--from_run_id", "test123"] +BASIC_ARGLIST = ["--from-run-id", "test123"] @pytest.fixture @@ -18,7 +18,7 @@ def inference_parser(): def test_private_config_is_path(): - argl = ["--private_config", "foo/bar"] + argl = ["--private-config", "foo/bar"] args = cli.get_train_parser().parse_args(argl) @@ -61,14 +61,14 @@ def test_model_loading_has_params(parser): @pytest.mark.parametrize("streams", [["ERA5", "FOO"], ["BAR"]]) def test_inference_streams_output(inference_parser, streams): - arglist = BASIC_ARGLIST + ["--streams_output", *streams] + arglist = BASIC_ARGLIST + ["--streams-output", *streams] args = inference_parser.parse_args(arglist) assert args.streams_output == streams def test_inference_streams_output_empty(inference_parser): - arglist = BASIC_ARGLIST + ["--streams_output", *[]] + arglist = BASIC_ARGLIST + ["--streams-output", *[]] with pytest.raises(SystemExit): inference_parser.parse_args(arglist) @@ -76,12 +76,12 @@ def test_inference_streams_output_empty(inference_parser): def test_inference_defaults(inference_parser): default_args = [ - "start_date", - "end_date", + "start-date", + "end-date", "samples", - "streams_output", - "mini_epoch", - "private_config", + "streams-output", + "mini-epoch", + "private-config", ] 
default_values = [inference_parser.get_default(arg) for arg in default_args] # apply custom type @@ -99,23 +99,23 @@ def test_inference_defaults(inference_parser): @pytest.mark.parametrize("date", DATE_FORMATS) def test_inference_start_date(inference_parser, date): - args = inference_parser.parse_args(BASIC_ARGLIST + ["--start_date", date]) + args = inference_parser.parse_args(BASIC_ARGLIST + ["--start-date", date]) assert args.start_date == EXPECTED_DATE_STR def test_inference_start_date_invalid(inference_parser): with pytest.raises(SystemExit): - inference_parser.parse_args(BASIC_ARGLIST + ["--start_date", "foobar"]) + inference_parser.parse_args(BASIC_ARGLIST + ["--start-date", "foobar"]) @pytest.mark.parametrize("date", DATE_FORMATS) def test_inference_end_date(inference_parser, date): - args = inference_parser.parse_args(BASIC_ARGLIST + ["--end_date", date]) + args = inference_parser.parse_args(BASIC_ARGLIST + ["--end-date", date]) assert args.end_date == EXPECTED_DATE_STR def test_inference_end_date_invalid(inference_parser): with pytest.raises(SystemExit): - inference_parser.parse_args(BASIC_ARGLIST + ["--end_date", "foobar"]) + inference_parser.parse_args(BASIC_ARGLIST + ["--end-date", "foobar"]) From 9710f810b65b23998f5b417627c89e6ee22997d5 Mon Sep 17 00:00:00 2001 From: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:50:26 +0100 Subject: [PATCH 21/44] Fix bug for missing run_id path in model path (#1704) --- packages/common/src/weathergen/common/config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 58c4131b9..ad81ee84e 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -227,7 +227,7 @@ def load_run_config(run_id: str, mini_epoch: int | None, model_path: str | None) if model_path is None: path = 
get_path_model(run_id=run_id) else: - path = Path(model_path) + path = Path(model_path) / run_id fname = path / _get_model_config_file_read_name(run_id, mini_epoch) if not fname.exists(): @@ -260,6 +260,7 @@ def _get_model_config_file_write_name(run_id: str, mini_epoch: int | None): return f"model_{run_id}{mini_epoch_str}.json" + def _get_model_config_file_read_name(run_id: str, mini_epoch: int | None, use_old_name=False): """Generate the filename for reading a model config file.""" if mini_epoch is None: @@ -267,12 +268,13 @@ def _get_model_config_file_read_name(run_id: str, mini_epoch: int | None, use_ol elif mini_epoch == -1: mini_epoch_str = "_latest" elif use_old_name: - mini_epoch_str = f"_epoch{mini_epoch:05d}" # TODO remove compatibility + mini_epoch_str = f"_epoch{mini_epoch:05d}" # TODO remove compatibility else: mini_epoch_str = f"_chkpt{mini_epoch:05d}" return f"model_{run_id}{mini_epoch_str}.json" + def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path: """ Get the path to the model results zarr store from a given run_id and mini_epoch. 
@@ -696,4 +698,4 @@ def validate_forecast_policy_and_steps(cf: OmegaConf): else True ), provide_forecast_policy + valid_forecast_policies + valid_forecast_steps else: - raise TypeError(valid_forecast_steps) \ No newline at end of file + raise TypeError(valid_forecast_steps) From fa952ff0b8cc8432304249199342009dcbc28f06 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:55:42 +0100 Subject: [PATCH 22/44] fix bar plot (#1698) Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> --- .../src/weathergen/evaluate/io/csv_reader.py | 3 ++ .../weathergen/evaluate/io/wegen_reader.py | 7 +++-- .../weathergen/evaluate/plotting/plotter.py | 31 ++++++++++++------- .../src/weathergen/evaluate/utils/utils.py | 10 +++--- 4 files changed, 32 insertions(+), 19 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 14a2548bf..83ce2664b 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -70,6 +70,9 @@ def __init__( else: self.data = pd.concat([self.data, data], ignore_index=True) + self.data = self.data.dropna(subset=["step", "level"]) + self.data["level"] = self.data["level"].astype(int) + self.data["channel"] = ( self.data["parameter"].astype(str) + "_" + self.data["level"].astype(str) if "level" in self.data.columns diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 6721fed8f..ecf50ed56 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -680,9 +680,10 @@ def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: sort_idx = np.lexsort((ref_lon.values, ref_lat.values)) npoints = sort_idx.size aligned = [] + samples = [] for i, a 
in enumerate(ref): a_sorted = a.isel(ipoint=sort_idx) - + samples.append(a_sorted.sample.values) a_sorted = a_sorted.assign_coords( ipoint=np.arange(npoints), lat=("ipoint", ref_lat.values[sort_idx]), @@ -693,8 +694,8 @@ def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: a_sorted = a_sorted.expand_dims(sample=[i]) aligned.append(a_sorted) - - return xr.concat(aligned, dim="sample") + + return xr.concat(aligned, dim="sample").assign_coords({"sample": samples}) class WeatherGenMergeReader(Reader): diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index cfbe29037..6480d81c8 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -699,7 +699,9 @@ def print_all_points_from_graph(self, fig: plt.Figure) -> None: label = line.get_label() self._logger.info(f"Summary for {label} plot:") for xi, yi in zip(xdata, ydata, strict=False): - self._logger.info(f" x: {xi:.3f}, y: {yi:.3f}") + xi = xi if isinstance(xi, str) else f"{float(xi):.3f}" + yi = yi if isinstance(yi, str) else f"{float(yi):.3f}" + self._logger.info(f" x: {xi}, y: {yi}") self._logger.info("--------------------------") return @@ -1536,7 +1538,7 @@ def plot( squeeze=False, ) ax = ax.flatten() - + single_run = False if self.baseline and self.baseline in runs: baseline_idx = runs.index(self.baseline) runs = [runs[baseline_idx]] + runs[:baseline_idx] + runs[baseline_idx + 1 :] @@ -1548,22 +1550,25 @@ def plot( ones_array = xr.full_like(data[0], 1.0) runs = [""] + runs data = [ones_array] + data + single_run = True + for run_index in range(1, len(runs)): - ratio_score, channels_per_comparison = self.calc_ratio_per_run_id( + + score, channels_per_comparison = self.calc_ratio_per_run_id( data, channels, run_index ) - if len(ratio_score) > 0: + if len(score) > 0: ax[run_index - 1].barh( - np.arange(len(ratio_score)), - 
ratio_score, - color=self.colors(ratio_score, metric), + np.arange(len(score)), + score, + color=self.colors(score, metric), align="center", edgecolor="black", linewidth=0.5, ) ax[run_index - 1].set_yticks( - np.arange(len(ratio_score)), labels=channels_per_comparison + np.arange(len(score)), labels=channels_per_comparison ) ax[run_index - 1].invert_yaxis() @@ -1633,7 +1638,7 @@ def calc_ratio_per_run_id( """ ratio_score = [] channels_per_comparison = [] - + for _, var in enumerate(channels): if var not in data[0].channel.values or var not in data[run_index].channel.values: continue @@ -1646,8 +1651,12 @@ def calc_ratio_per_run_id( ) ratio_score.append(model_score / baseline_score) - - ratio_score = np.array(ratio_score) - 1 + + if np.allclose(baseline_score, 1.0, atol=1e-6): + ratio_score = np.array(ratio_score) + else: + ratio_score = np.array(ratio_score) - 1 + return ratio_score, channels_per_comparison def colors(self, ratio_score: np.array, metric: str) -> list[tuple]: diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index cf460b694..871e3f6a0 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -193,13 +193,13 @@ def calc_scores_per_stream( criteria = { "forecast_step": int(fstep), - "sample": combined_metrics.sample, - "channel": combined_metrics.channel, - "metric": combined_metrics.metric, + "sample": combined_metrics.sample.values, + "channel": combined_metrics.channel.values, + "metric": combined_metrics.metric.values, } if "ens" in combined_metrics.dims: - criteria["ens"] = combined_metrics.ens - + criteria["ens"] = combined_metrics.ens.values + metric_stream.loc[criteria] = combined_metrics lead_time_map[fstep] = ( From 34fa89aa7e7d823c4493a3e0ffe3f1826874d427 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Tue, 27 Jan 2026 12:07:23 +0100 Subject: [PATCH 23/44] Fix output generation 
during inference (#1707) * rename write_num_samples to num_samples * Fixing linting --------- Co-authored-by: Christian Lessig --- packages/evaluate/src/weathergen/evaluate/plotting/plotter.py | 2 -- src/weathergen/run_train.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 6480d81c8..da68924d5 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1538,7 +1538,6 @@ def plot( squeeze=False, ) ax = ax.flatten() - single_run = False if self.baseline and self.baseline in runs: baseline_idx = runs.index(self.baseline) runs = [runs[baseline_idx]] + runs[:baseline_idx] + runs[baseline_idx + 1 :] @@ -1550,7 +1549,6 @@ def plot( ones_array = xr.full_like(data[0], 1.0) runs = [""] + runs data = [ones_array] + data - single_run = True for run_index in range(1, len(runs)): diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 58655c17b..65551fccb 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -48,7 +48,7 @@ def inference_from_args(argl: list[str]): start_date=args.start_date, end_date=args.end_date, samples_per_mini_epoch=args.samples, - output=dict(write_num_samples=args.samples if args.save_samples else 0), + output=dict(num_samples=args.samples if args.save_samples else 0), streams_output=args.streams_output, ) } From d3684000ecfa0817ccea41c48673e91afbc86814 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Tue, 27 Jan 2026 14:27:10 +0100 Subject: [PATCH 24/44] backwards compatilble run_id look up (#1715) --- packages/common/src/weathergen/common/config.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index ad81ee84e..6416158d0 100644 --- 
a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -175,6 +175,11 @@ def get_run_id(): return "".join(random.sample(s1, 1)) + "".join(random.sample(s2, 7)) +def get_run_id_from_config(config: Config) -> str: + general_cfg = config.get("general", None) + return general_cfg.run_id if general_cfg else config.run_id + + def format_cf(config: Config) -> str: """Format config as a human-readable string.""" stream = io.StringIO() @@ -198,7 +203,7 @@ def save(config: Config, mini_epoch: int | None): dirname = get_path_model(config) dirname.mkdir(exist_ok=True, parents=True) - fname = _get_model_config_file_write_name(config.general.run_id, mini_epoch) + fname = _get_model_config_file_write_name(get_run_id_from_config(config), mini_epoch) json_str = json.dumps(OmegaConf.to_container(_strip_interpolation(config))) with (dirname / fname).open("w") as f: @@ -376,7 +381,7 @@ def load_merge_configs( base_config = _load_base_conf(base) else: base_config = load_run_config(from_run_id, mini_epoch, None) - from_run_id = base_config.general.run_id + from_run_id = get_run_id_from_config(base_config) with open_dict(base_config): base_config.from_run_id = from_run_id # use OmegaConf.unsafe_merge if too slow @@ -429,8 +434,8 @@ def set_run_id(config: Config, run_id: str | None, reuse_run_id: bool) -> Config """ config = config.copy() if reuse_run_id: - assert config.general.run_id is not None, "Loaded run_id should not be None." - _logger.info(f"reusing run_id from previous run: {config.general.run_id}") + assert get_run_id_from_config(config) is not None, "Loaded run_id should not be None." 
+ _logger.info(f"reusing run_id from previous run: {get_run_id_from_config(config)}") else: if run_id is None: # generate new id if run_id is None @@ -622,13 +627,13 @@ def load_streams(streams_directory: Path) -> list[Config]: def get_path_run(config: Config) -> Path: """Get the current runs results_path for storing run results and logs.""" - return _get_shared_wg_path() / "results" / config.general.run_id + return _get_shared_wg_path() / "results" / get_run_id_from_config(config) def get_path_model(config: Config | None = None, run_id: str | None = None) -> Path: """Get the current runs model_path for storing model checkpoints.""" if config or run_id: - run_id = run_id if run_id else config.general.run_id + run_id = run_id if run_id else get_run_id_from_config(config) else: msg = f"Missing run_id and cannot infer it from config: {config}" raise ValueError(msg) From e84d8d831af3b6106ef62e87ce0cdb86d52b4742 Mon Sep 17 00:00:00 2001 From: Till Hauer Date: Tue, 27 Jan 2026 14:47:06 +0100 Subject: [PATCH 25/44] remove misleading logging of `mini_epoch` (#1679) * remove misleading logging of mini_epoch * add forecast_steps logging --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 29723811f..bf3a6ed7d 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -286,7 +286,7 @@ def reset(self): else self.forecast_steps.max() ) if fsm > 0: - logger.info(f"forecast_steps at mini_epoch={self.mini_epoch} : {fsm}") + logger.info(f"forecast_steps : {fsm}") # data index_range = self.time_window_handler.get_index_range() From cfcad2eece0f268e45a336de0bedee0ffa9802d9 Mon Sep 17 00:00:00 2001 From: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Date: Tue, 27 Jan 2026 15:06:11 +0100 Subject: [PATCH 26/44] Fix 
duplicate run_id in results and runplots paths (#1716) * Fix duplicate run_id in results and runplots paths. Linting. * remove duplicate run_id also from metrics directory * Linting --- .../src/weathergen/evaluate/io/wegen_reader.py | 8 ++++---- .../src/weathergen/evaluate/plotting/plotter.py | 16 +++++----------- .../src/weathergen/evaluate/utils/utils.py | 2 +- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index ecf50ed56..097507ca2 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -62,12 +62,12 @@ def __init__( self.step_hrs = self.inference_cfg.get("step_hrs", 1) self.results_dir, self.runplot_dir = ( - Path(self.results_base_dir) / self.run_id, - Path(self.runplot_base_dir) / self.run_id, + Path(self.results_base_dir), + Path(self.runplot_base_dir), ) # for backward compatibility allow metric_dir to be specified in the run config self.metrics_dir = Path( - self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation") + self.eval_cfg.get("metrics_dir", self.metrics_base_dir / "evaluation") ) def get_inference_config(self): @@ -694,7 +694,7 @@ def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: a_sorted = a_sorted.expand_dims(sample=[i]) aligned.append(a_sorted) - + return xr.concat(aligned, dim="sample").assign_coords({"sample": samples}) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index da68924d5..ffa0393d2 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1550,12 +1550,8 @@ def plot( runs = [""] + runs data = [ones_array] + data - for run_index in range(1, len(runs)): - - score, 
channels_per_comparison = self.calc_ratio_per_run_id( - data, channels, run_index - ) + score, channels_per_comparison = self.calc_ratio_per_run_id(data, channels, run_index) if len(score) > 0: ax[run_index - 1].barh( np.arange(len(score)), @@ -1565,9 +1561,7 @@ def plot( edgecolor="black", linewidth=0.5, ) - ax[run_index - 1].set_yticks( - np.arange(len(score)), labels=channels_per_comparison - ) + ax[run_index - 1].set_yticks(np.arange(len(score)), labels=channels_per_comparison) ax[run_index - 1].invert_yaxis() xlabel = ( @@ -1636,7 +1630,7 @@ def calc_ratio_per_run_id( """ ratio_score = [] channels_per_comparison = [] - + for _, var in enumerate(channels): if var not in data[0].channel.values or var not in data[run_index].channel.values: continue @@ -1649,12 +1643,12 @@ def calc_ratio_per_run_id( ) ratio_score.append(model_score / baseline_score) - + if np.allclose(baseline_score, 1.0, atol=1e-6): ratio_score = np.array(ratio_score) else: ratio_score = np.array(ratio_score) - 1 - + return ratio_score, channels_per_comparison def colors(self, ratio_score: np.array, metric: str) -> list[tuple]: diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index 871e3f6a0..cab8b9e5f 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -199,7 +199,7 @@ def calc_scores_per_stream( } if "ens" in combined_metrics.dims: criteria["ens"] = combined_metrics.ens.values - + metric_stream.loc[criteria] = combined_metrics lead_time_map[fstep] = ( From 16f790d3c6fe5517ca80ae11009046e68772fe0b Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 16 Jan 2026 16:56:48 +0100 Subject: [PATCH 27/44] latent_space evaluation scripts + propagate verbose --- .../src/weathergen/evaluate/io/csv_reader.py | 4 ++++ .../evaluate/latent_space/latent_space_eval.py | 2 -- .../weathergen/evaluate/plotting/plot_utils.py | 8 +++++++- 
.../src/weathergen/evaluate/plotting/plotter.py | 16 +++++++++++++--- .../src/weathergen/evaluate/run_evaluation.py | 4 ++++ 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 83ce2664b..13c929954 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -28,9 +28,13 @@ class CsvReader(Reader): Reader class to read evaluation data from CSV files and convert to xarray DataArray. """ +<<<<<<< HEAD def __init__( self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True ): +======= + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose = True): +>>>>>>> e62a22f2 (latent_space evaluation scripts + propagate verbose) """ Initialize the CsvReader. diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py index 60c620918..34770d565 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -217,14 +217,12 @@ def load_metrics(run_id): def load_scores(eval_cfg, run_id): """Helper function to load metrics""" - run_cfg = eval_cfg.run_ids[run_id] metrics = list(eval_cfg.evaluation.get("metrics")) regions = list(eval_cfg.evaluation.get("regions")) reader = WeatherGenJSONReader(run_cfg, run_id, None, regions, metrics) - scores = {} for stream_name in streams: diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 933aa7ea1..4020758b2 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ 
-114,6 +114,12 @@ def plot_metric_region( ) selected_data, time_dim = _assign_time_coord(selected_data) + + if time_dim != "lead_time": + plotter._logger.warning( + "lead_time coordinate not found for all plotted data; " + "using forecast_step as x-axis." + ) if time_dim != "lead_time": plotter._logger.warning( @@ -290,7 +296,7 @@ def heat_maps_metric_region( "lead_time coordinate not found for all plotted data; " "using forecast_step as x-axis." ) - + plotter.heat_map( selected_data, labels, diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index ffa0393d2..824691961 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -4,6 +4,7 @@ import os import re from pathlib import Path +from tabnanny import verbose import cartopy import cartopy.crs as ccrs @@ -622,7 +623,7 @@ def get_map_output_dir(self, tag): class LinePlots: - def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose=True): + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose = True): """ Initialize the LinePlots class. @@ -982,10 +983,20 @@ def ratio_plot( data_list, label_list = self._check_lengths(data, labels) + baseline_name = self.baseline + baseline_idx = run_ids.index(self.baseline) if self.baseline in run_ids else None + if len(data_list) < 2: baseline = xr.full_like(data_list[0], 1.0) baseline_name = "ones" descr = "scores" + self._logger.warning("Ratio plot requires at least two datasets to compare. Skipping.") + return + + if baseline_idx is not None: + self._logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") + baseline = data_list[baseline_idx] + else: descr = "ratio_plot" baseline_name = self.baseline @@ -1706,7 +1717,7 @@ def calculate_average_over_dim( ] if non_zero_dims: - logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. 
Averaging...") + self._logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") baseline_score = baseline_var.mean( dim=[dim for dim in baseline_var.dims if dim != x_dim], skipna=True @@ -1758,7 +1769,6 @@ def channel_sort_key(name: str) -> tuple[int, str, int]: else: return (1, name, float("inf")) - def setup_logger(name: str, verbose: bool) -> logging.Logger: """ Set up a logger with the specified name and verbosity level. diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 3382ba24d..052ed3a89 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -352,7 +352,11 @@ def evaluate_from_config( with mp.Pool( processes=num_processes, initializer=setup_worker_logger, +<<<<<<< HEAD initargs=(log_queue, verbose), +======= + initargs=(log_queue,verbose), +>>>>>>> e62a22f2 (latent_space evaluation scripts + propagate verbose) ) as pool: results = pool.map( _process_stream_wrapper, From d0701eb3d388f83c5864b5fa1ee7c8fe5864048c Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 16 Jan 2026 18:13:30 +0100 Subject: [PATCH 28/44] lint --- packages/evaluate/src/weathergen/evaluate/io/csv_reader.py | 5 +---- .../evaluate/src/weathergen/evaluate/io/wegen_reader.py | 6 ++++-- .../evaluate/src/weathergen/evaluate/plotting/plot_utils.py | 4 ++-- .../evaluate/src/weathergen/evaluate/plotting/plotter.py | 6 +++--- packages/evaluate/src/weathergen/evaluate/run_evaluation.py | 4 ++++ 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 13c929954..8bed08597 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -28,13 +28,10 @@ class CsvReader(Reader): Reader class 
to read evaluation data from CSV files and convert to xarray DataArray. """ -<<<<<<< HEAD + def __init__( self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True ): -======= - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose = True): ->>>>>>> e62a22f2 (latent_space evaluation scripts + propagate verbose) """ Initialize the CsvReader. diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 097507ca2..431ef3651 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -46,8 +46,10 @@ def __init__( self.inference_cfg = self.get_inference_config() if not self.results_base_dir: - self.results_base_dir = get_path_run(self.inference_cfg) - self._logger.info(f"Results directory obtained from private config: {self.results_base_dir}") + self.results_base_dir = Path(get_shared_wg_path("results")) + self._logger.info( + f"Results directory obtained from private config: {self.results_base_dir}" + ) else: self._logger.info(f"Results directory parsed: {self.results_base_dir}") diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 4020758b2..ffd6c370c 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -114,7 +114,7 @@ def plot_metric_region( ) selected_data, time_dim = _assign_time_coord(selected_data) - + if time_dim != "lead_time": plotter._logger.warning( "lead_time coordinate not found for all plotted data; " @@ -296,7 +296,7 @@ def heat_maps_metric_region( "lead_time coordinate not found for all plotted data; " "using forecast_step as x-axis." 
) - + plotter.heat_map( selected_data, labels, diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 824691961..fa52b425f 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -4,7 +4,6 @@ import os import re from pathlib import Path -from tabnanny import verbose import cartopy import cartopy.crs as ccrs @@ -623,7 +622,7 @@ def get_map_output_dir(self, tag): class LinePlots: - def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose = True): + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, verbose=True): """ Initialize the LinePlots class. @@ -1717,7 +1716,7 @@ def calculate_average_over_dim( ] if non_zero_dims: - self._logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") + logger.info(f"Found multiple entries for dimensions: {non_zero_dims}. Averaging...") baseline_score = baseline_var.mean( dim=[dim for dim in baseline_var.dims if dim != x_dim], skipna=True @@ -1769,6 +1768,7 @@ def channel_sort_key(name: str) -> tuple[int, str, int]: else: return (1, name, float("inf")) + def setup_logger(name: str, verbose: bool) -> logging.Logger: """ Set up a logger with the specified name and verbosity level. 
diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 052ed3a89..5a2d55789 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -352,11 +352,15 @@ def evaluate_from_config( with mp.Pool( processes=num_processes, initializer=setup_worker_logger, +<<<<<<< HEAD <<<<<<< HEAD initargs=(log_queue, verbose), ======= initargs=(log_queue,verbose), >>>>>>> e62a22f2 (latent_space evaluation scripts + propagate verbose) +======= + initargs=(log_queue, verbose), +>>>>>>> acd0439c (lint) ) as pool: results = pool.map( _process_stream_wrapper, From 49601ff9c6773870fc4a4063ce3e29b6e58c662a Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Tue, 27 Jan 2026 15:06:32 +0100 Subject: [PATCH 29/44] update latent space eval --- .../src/weathergen/evaluate/latent_space/latent_space_eval.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py index 34770d565..c3ae7aec9 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -92,6 +92,8 @@ def get_evaluation_config(run_id, verbose=False): "summary_dir": f"./results/{run_id}/plots/summary/", "print_summary": False, "verbose": verbose, + "ratio_plots": True, + "bar_plots": True, }, "run_ids": { run_id: { From 9824a5435c16690aa9d1eab5070c490b45365ec6 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Tue, 27 Jan 2026 17:43:23 +0100 Subject: [PATCH 30/44] rebase to develop --- packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py 
b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 431ef3651..e6ba66264 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -46,7 +46,7 @@ def __init__( self.inference_cfg = self.get_inference_config() if not self.results_base_dir: - self.results_base_dir = Path(get_shared_wg_path("results")) + self.results_base_dir = get_path_run(self.inference_cfg) self._logger.info( f"Results directory obtained from private config: {self.results_base_dir}" ) From 9518cf5961e24afdce2e73da95cc4d1b98aab4fd Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Tue, 27 Jan 2026 20:39:05 +0100 Subject: [PATCH 31/44] final version --- .../src/weathergen/evaluate/io/io_reader.py | 2 +- .../weathergen/evaluate/io/wegen_reader.py | 15 +++++++------ .../latent_space/latent_space_eval.py | 6 +++--- .../weathergen/evaluate/plotting/plotter.py | 2 +- .../src/weathergen/evaluate/run_evaluation.py | 21 +++++++------------ 5 files changed, 22 insertions(+), 24 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index b218d2615..aab517916 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -90,7 +90,7 @@ def __init__( self._logger = logging.getLogger(__name__) - logger_level = logging.INFO if verbose else logging.WARNING + logger_level = logging.INFO if verbose else logging.CRITICAL + 1 self._logger.setLevel(logger_level) def get_stream(self, stream: str): diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index e6ba66264..0b5e96c9c 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -256,8 +256,9 @@ def __init__( private_paths: dict | None = 
None, regions: list[str] | None = None, metrics: list[str] | None = None, + verbose=True, ): - super().__init__(eval_cfg, run_id, private_paths) + super().__init__(eval_cfg, run_id, private_paths, verbose) # goes looking for the coordinates available for all streams, regions, metrics streams = list(self.eval_cfg.streams.keys()) coord_names = ["sample", "forecast_step", "ens"] @@ -300,16 +301,18 @@ def get_data(self, *args, **kwargs): raise ValueError(f"Missing JSON data for run {self.run_id}.") def get_recomputable_metrics(self, metrics): - _logger.info( - f"The following metrics have not yet been computed:{metrics}. Use type: zarr for that." - ) + if metrics: + self._logger.info( + f"The following metrics have not yet been computed:{metrics}. " + "Use type: zarr for that. Skipping them." + ) return {} class WeatherGenZarrReader(WeatherGenReader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True): """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" - super().__init__(eval_cfg, run_id, private_paths) + super().__init__(eval_cfg, run_id, private_paths, verbose) zarr_ext = self.inference_cfg.get("zarr_store", "zarr") # for backwards compatibility assume zarr store is local i.e. .zarr format diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py index c3ae7aec9..5e3226a4e 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -13,7 +13,7 @@ It performs training and inference with multiple data sources including gridded and obs data. 
Command: -uv run --offline packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py --run_id j2dkivn8 +uv run --offline packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py --run-id j2dkivn8 """ import argparse @@ -64,7 +64,7 @@ def infer_multi_stream(run_id): "forecast_offset=0", "zarr_store=zip", ] - + ["--from_run_id", run_id, "--run_id", new_run_id, "--streams_output"] + + ["--from-run-id", run_id, "--run-id", new_run_id, "--streams-output"] + streams + [ "--config", @@ -244,7 +244,7 @@ def load_scores(eval_cfg, run_id): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run multi-stream latent space evaluation") parser.add_argument( - "--run_id", type=str, required=True, help="Run identifier for the model to evaluate" + "--run-id", type=str, required=True, help="Run identifier for the model to evaluate" ) parser.add_argument( "--verbose", action="store_true", help="Enable verbose output", default=False diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index fa52b425f..65495ae6c 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -1553,7 +1553,7 @@ def plot( runs = [runs[baseline_idx]] + runs[:baseline_idx] + runs[baseline_idx + 1 :] data = [data[baseline_idx]] + data[:baseline_idx] + data[baseline_idx + 1 :] elif len(runs) < 2: - _logger.warning( + self._logger.warning( "BarPlots:: Less than two runs provided. Generating bar plot against ones." 
) ones_array = xr.full_like(data[0], 1.0) diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 5a2d55789..f055d8ab2 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -179,15 +179,16 @@ def get_reader( private_paths: dict[str, str], region: str | None = None, metric: str | None = None, + verbose: bool = True, ): if reader_type == "zarr": - reader = WeatherGenZarrReader(run, run_id, private_paths) + reader = WeatherGenZarrReader(run, run_id, private_paths, verbose) elif reader_type == "csv": - reader = CsvReader(run, run_id, private_paths) + reader = CsvReader(run, run_id, private_paths, verbose) elif reader_type == "json": - reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric) + reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric, verbose) elif reader_type == "merge": - reader = WeatherGenMergeReader(run, run_id, private_paths) + reader = WeatherGenMergeReader(run, run_id, private_paths, verbose) else: raise ValueError(f"Unknown reader type: {reader_type}") return reader @@ -295,6 +296,7 @@ def evaluate_from_config( plot_score_maps = cfg.evaluation.get("plot_score_maps", False) global_plotting_opts = cfg.get("global_plotting_options", {}) use_parallel = cfg.evaluation.get("num_processes", 0) + verbose = cfg.evaluation.get("verbose", True) default_streams = cfg.get("default_streams", {}) @@ -327,7 +329,7 @@ def evaluate_from_config( regions = cfg.evaluation.regions metrics = cfg.evaluation.metrics - reader = get_reader(type_, run, run_id, private_paths, regions, metrics) + reader = get_reader(type_, run, run_id, private_paths, regions, metrics, verbose) for stream in reader.streams: tasks.append( @@ -352,15 +354,8 @@ def evaluate_from_config( with mp.Pool( processes=num_processes, initializer=setup_worker_logger, -<<<<<<< HEAD -<<<<<<< HEAD - 
initargs=(log_queue, verbose), -======= - initargs=(log_queue,verbose), ->>>>>>> e62a22f2 (latent_space evaluation scripts + propagate verbose) -======= initargs=(log_queue, verbose), ->>>>>>> acd0439c (lint) + ) as pool: results = pool.map( _process_stream_wrapper, From 4e06d00ba82cc6e29173bbf93d77b687effcf01f Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 28 Jan 2026 09:42:34 +0100 Subject: [PATCH 32/44] add readme --- packages/evaluate/pyproject.toml | 1 + .../evaluate/latent_space/README.md | 170 ++++++++++++++++++ .../latent_space/latent_space_eval.py | 2 +- 3 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 packages/evaluate/src/weathergen/evaluate/latent_space/README.md diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index 3e3570bf1..83281cca0 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -29,6 +29,7 @@ dev = [ ] [project.scripts] +latent_space_eval= "weathergen.evaluate.latent_space.latent_space_eval:latent_space_analysis" evaluation = "weathergen.evaluate.run_evaluation:evaluate" export = "weathergen.evaluate.export.export_inference:export" diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/README.md b/packages/evaluate/src/weathergen/evaluate/latent_space/README.md new file mode 100644 index 000000000..6eae6ec39 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/README.md @@ -0,0 +1,170 @@ +# Multi-Stream Latent Space Inference & Evaluation + +This script runs **inference and evaluation for a multi-stream weather generation model** and produces quantitative metrics and diagnostic plots for each stream. + +The workflow is designed to: + +1. Run inference from an existing trained model run +2. Evaluate forecasts across multiple data streams +3. 
Print training, validation, and evaluation summaries + +--- + +## Supported Streams + +The script currently evaluates the following streams: + +```python +streams = ["ERA5", "SurfaceCombined", "NPPATMS"] +``` + +Each stream has its own channels, metrics, and plotting configuration. + +--- + +## Overview of the Pipeline + +### 1. Inference (`infer_multi_stream`) + +* Runs inference using a trained model (`run_id`) +* Creates a new run ID with `_inf` suffix +* Generates forecasts over a fixed date range +* Outputs results in Zarr format + +**Inference period** + +* Start: `2021-10-10` +* End: `2022-10-11` +* Samples per forecast: `10` + +--- + +### 2. Evaluation Configuration (`get_evaluation_config`) + +Defines: + +* Metrics: `rmse`, `froct` +* Regions: `global` +* Plotting options (maps, histograms, summary plots) +* Stream-specific channels and evaluation settings + +Each stream is evaluated independently but summarized jointly. + +--- + +### 3. Evaluation (`evaluate_multi_stream_results`) + +* Loads inference outputs +* Computes metrics per stream, region, and forecast step +* Saves plots and summaries under: + +``` +./results/_inf/plots/summary/ +``` + +--- + +### 4. Reporting Utilities + +#### Print Losses (`print_losses`) + +Prints training or validation losses for each stream: + +* Uses `LossPhysical..mse.avg` +* Supports `train` and `val` stages + +#### Print Evaluation Results (`print_evaluation_results`) + +Prints mean evaluation scores averaged over: + +* samples +* forecast steps +* ensemble members + +Results are grouped by: + +* stream +* metric +* region + +--- + +## File Structure Expectations + +The script assumes the following directory layout: + +``` +. 
+├── config/ +│ └── evaluate/ +│ └── latent_space_eval_config.yaml +├── models/ +│ └── +│ ├── _latest.chkpt +│ ├── model_.json +├── results/ +│ └── _inf/ +│ ├── metrics.json +│ └── plots/ + +``` + +--- + +## Running the Script + +### Command Line + +```bash +uv run --offline latent_space_analysis --run-id +``` + +Optional verbose mode: + +```bash +uv run --offline latent_space_analysis --run_id --run-id --verbose +``` + +--- + +## Execution Flow (Main) + +When executed, the script performs the following steps: + +1. **Inference** + + * Creates a new inference run ID + * Generates forecasts from the trained run + +2. **Evaluation** + + * Computes metrics and generates plots + +3. **Reporting** + + * Prints training losses from the original run + * Prints validation losses from the inference run + * Prints evaluation metrics for all streams + +--- + +## Metrics & Outputs + +### Metrics (for now) + +* RMSE +* FROCT + +### Outputs + +* Summary plots +* Per-stream maps and histograms +* Console summaries of losses and evaluation scores + +--- + +## Notes & Assumptions + +* The original training run must already exist under `./results/` +* The `_inf` suffix is currently used for inference run naming +* Forecast steps and samples default to `"all"` during evaluation diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py index 5e3226a4e..1dac69b1a 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -241,7 +241,7 @@ def load_scores(eval_cfg, run_id): ############## MAIN ################# -if __name__ == "__main__": +def latent_space_analysis(): parser = argparse.ArgumentParser(description="Run multi-stream latent space evaluation") parser.add_argument( "--run-id", type=str, required=True, help="Run identifier for the model to evaluate" From 
9a713a87f26939d483af3e121d2c7f9687878c90 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:05:03 +0100 Subject: [PATCH 33/44] Update default_config.yml --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index b4e0781cb..4bff1abfd 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -87,7 +87,7 @@ norm_type: "LayerNorm" ##################################### -streams_directory: "./config/streams/cerra_era5/" +streams_directory: "./config/streams/era5_1deg/" streams: ??? # type of zarr_store From a097b3c50968fafe93314add8a598218b57a09ea Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:05:30 +0100 Subject: [PATCH 34/44] Update default_forecast_config.yml --- config/default_forecast_config.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/default_forecast_config.yml b/config/default_forecast_config.yml index e74b9df23..66d8c3994 100644 --- a/config/default_forecast_config.yml +++ b/config/default_forecast_config.yml @@ -1,4 +1,4 @@ -streams_directory: "./config/streams/cerra_era5/" +streams_directory: "./config/streams/era5_1deg/" embed_orientation: "channels" embed_unembed_mode: "block" @@ -50,7 +50,7 @@ pred_mlp_adaln: True # one is training an auto-encoder forecast_offset : 1 forecast_delta: 00:00:00 -forecast_steps: 4 +forecast_steps: 4f forecast_policy: "fixed" forecast_freeze_model: False forecast_att_dense_rate: 1.0 @@ -182,4 +182,4 @@ wgtags: # Examples: "rollout_ablation_grid" exp: None # *** Experiment-specific tags *** - grid: None \ No newline at end of file + grid: None From 27447927f23d9552ddbe5f2a28823b465785c03e Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:06:01 +0100 Subject: [PATCH 35/44] Update default_forecast_config.yml --- 
config/default_forecast_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_forecast_config.yml b/config/default_forecast_config.yml index 66d8c3994..056f2a506 100644 --- a/config/default_forecast_config.yml +++ b/config/default_forecast_config.yml @@ -50,7 +50,7 @@ pred_mlp_adaln: True # one is training an auto-encoder forecast_offset : 1 forecast_delta: 00:00:00 -forecast_steps: 4f +forecast_steps: 4 forecast_policy: "fixed" forecast_freeze_model: False forecast_att_dense_rate: 1.0 From 74760c855f9552ba02c2ff7da20f828145eb6d51 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:06:39 +0100 Subject: [PATCH 36/44] Update csv_reader.py --- packages/evaluate/src/weathergen/evaluate/io/csv_reader.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 8bed08597..33224300e 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -19,10 +19,6 @@ # Local application / package from weathergen.evaluate.io.io_reader import Reader -# _logger = logging.getLogger(__name__) -# _logger.setLevel(logging.INFO) - - class CsvReader(Reader): """ Reader class to read evaluation data from CSV files and convert to xarray DataArray. 
From 668bd2c322de50154558c5320feafff69401a8cf Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:07:03 +0100 Subject: [PATCH 37/44] Update wegen_reader.py --- packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index f069ea1c0..645841b0a 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -29,10 +29,6 @@ from weathergen.evaluate.scores.score_utils import to_list from weathergen.evaluate.utils.derived_channels import DeriveChannels -# _logger = logging.getLogger(__name__) -# _logger.setLevel(logging.INFO) - - class WeatherGenReader(Reader): def __init__( self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True From 594efb11a81897568a72b4f5516102e99c3745a5 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:07:48 +0100 Subject: [PATCH 38/44] Update plot_utils.py --- .../evaluate/src/weathergen/evaluate/plotting/plot_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 9ce8cd65e..2488d931c 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -12,8 +12,6 @@ import numpy as np import xarray as xr -# _logger = logging.getLogger(__name__) - def collect_streams(runs: dict): """Get all unique streams across runs, sorted. 
From 725e1456f0f8d16e0b2b8c9ff6198284b4c5de28 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:08:19 +0100 Subject: [PATCH 39/44] Update utils.py --- packages/evaluate/src/weathergen/evaluate/utils/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index cab8b9e5f..10ae1c694 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -156,7 +156,6 @@ def calc_scores_per_stream( _logger.debug( f"Applying bounding box mask for region '{region}' to targets and predictions." ) - # breakpoint() tars, preds, tars_next, preds_next = [ bbox.apply_mask(x) if x is not None else None for x in (tars, preds, tars_next, preds_next) From 04a487294b43457b3294fabad1febbece6e8ba66 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:14:28 +0100 Subject: [PATCH 40/44] Update plotter.py --- packages/evaluate/src/weathergen/evaluate/plotting/plotter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 107290c53..1a58ec16b 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -991,7 +991,7 @@ def ratio_plot( baseline_name = self.baseline baseline_idx = run_ids.index(self.baseline) if self.baseline in run_ids else None if baseline_idx is not None: - _logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") + self._logger.info(f"Using baseline run ID '{self.baseline}' for ratio plot.") baseline = data_list[baseline_idx] else: From a98404537b15da0e5a3e1901b9e4b24782d5b7ec Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Thu, 29 Jan 
2026 15:34:21 +0100 Subject: [PATCH 41/44] lint --- .../evaluate/src/weathergen/evaluate/io/csv_reader.py | 2 +- .../evaluate/src/weathergen/evaluate/io/wegen_reader.py | 8 ++++---- .../weathergen/evaluate/latent_space/latent_space_eval.py | 7 ++++--- .../src/weathergen/evaluate/plotting/plot_utils.py | 1 + .../evaluate/src/weathergen/evaluate/run_evaluation.py | 1 - 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py index 33224300e..5bb3d4c01 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/csv_reader.py @@ -19,12 +19,12 @@ # Local application / package from weathergen.evaluate.io.io_reader import Reader + class CsvReader(Reader): """ Reader class to read evaluation data from CSV files and convert to xarray DataArray. """ - def __init__( self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True ): diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 645841b0a..6291b0b61 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -29,6 +29,7 @@ from weathergen.evaluate.scores.score_utils import to_list from weathergen.evaluate.utils.derived_channels import DeriveChannels + class WeatherGenReader(Reader): def __init__( self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True @@ -306,7 +307,9 @@ def get_recomputable_metrics(self, metrics): class WeatherGenZarrReader(WeatherGenReader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True): + def __init__( + self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, verbose=True + ): """Data reader class for WeatherGenerator model outputs 
stored in Zarr format.""" super().__init__(eval_cfg, run_id, private_paths, verbose) @@ -322,9 +325,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non ): self.fname_zarr = fname_zarr else: - self.fname_zarr = fname_zarr_old - - if not self.fname_zarr.exists(): self._logger.error(f"Zarr file {self.fname_zarr} does not exist.") raise FileNotFoundError(f"Zarr file {self.fname_zarr} does not exist") diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py index 1dac69b1a..15852117f 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -13,7 +13,7 @@ It performs training and inference with multiple data sources including gridded and obs data. Command: -uv run --offline packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py --run-id j2dkivn8 +uv run --offline latent_space_eval --run-id j2dkivn8 """ import argparse @@ -92,8 +92,8 @@ def get_evaluation_config(run_id, verbose=False): "summary_dir": f"./results/{run_id}/plots/summary/", "print_summary": False, "verbose": verbose, - "ratio_plots": True, - "bar_plots": True, + "ratio_plots": True, + "bar_plots": True, }, "run_ids": { run_id: { @@ -241,6 +241,7 @@ def load_scores(eval_cfg, run_id): ############## MAIN ################# + def latent_space_analysis(): parser = argparse.ArgumentParser(description="Run multi-stream latent space evaluation") parser.add_argument( diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 2488d931c..9a13702f6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -384,6 +384,7 @@ def bar_plot_metric_region( name = 
"_".join([metric, region, stream]) br_plotter.plot(selected_data, run_ids, metric, channels_set, name) + class DefaultMarkerSize: """ Utility class for managing default configuration values, such as marker sizes diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index fb13a865b..413815f3e 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -354,7 +354,6 @@ def evaluate_from_config( processes=num_processes, initializer=setup_worker_logger, initargs=(log_queue, verbose), - ) as pool: results = pool.map( _process_stream_wrapper, From 139a2e2d76c7c133b99b41d898ee655e8f102294 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Thu, 29 Jan 2026 16:30:58 +0100 Subject: [PATCH 42/44] fix verbose --- .../src/weathergen/evaluate/io/io_reader.py | 1 - .../evaluate/latent_space/latent_space_eval.py | 14 ++++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index aab517916..f5024e6fa 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -89,7 +89,6 @@ def __init__( ) # base directory where results will be stored self._logger = logging.getLogger(__name__) - logger_level = logging.INFO if verbose else logging.CRITICAL + 1 self._logger.setLevel(logger_level) diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py index 15852117f..58936b97e 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py @@ -61,15 +61,16 @@ def infer_multi_stream(run_id): "--samples", "10", 
"--options", - "forecast_offset=0", + # "training_config.forecast.offset=0", + # "training_config.forecast.num_steps=0" "zarr_store=zip", ] + ["--from-run-id", run_id, "--run-id", new_run_id, "--streams-output"] + streams - + [ - "--config", - "./config/evaluate/latent_space_eval_config.yaml", - ] + # + [ + # "--config", + # "./config/evaluate/latent_space_eval_config.yaml", + # ] ) return new_run_id @@ -223,8 +224,9 @@ def load_scores(eval_cfg, run_id): metrics = list(eval_cfg.evaluation.get("metrics")) regions = list(eval_cfg.evaluation.get("regions")) + verbose = eval_cfg.get("verbose", False) - reader = WeatherGenJSONReader(run_cfg, run_id, None, regions, metrics) + reader = WeatherGenJSONReader(run_cfg, run_id, None, regions, metrics, verbose=verbose) scores = {} for stream_name in streams: From 85bc1b28a29f59d82ff1f16790dbfdbae1c1052c Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 30 Jan 2026 10:43:43 +0100 Subject: [PATCH 43/44] rename to ssl analysis --- packages/evaluate/pyproject.toml | 2 +- .../evaluate/{latent_space => ssl}/README.md | 9 +++------ .../latent_space_eval.py => ssl/ssl_eval.py} | 12 ++++++------ .../latent_space_eval.yaml => ssl/ssl_eval.yaml} | 0 4 files changed, 10 insertions(+), 13 deletions(-) rename packages/evaluate/src/weathergen/evaluate/{latent_space => ssl}/README.md (91%) rename packages/evaluate/src/weathergen/evaluate/{latent_space/latent_space_eval.py => ssl/ssl_eval.py} (96%) rename packages/evaluate/src/weathergen/evaluate/{latent_space/latent_space_eval.yaml => ssl/ssl_eval.yaml} (100%) diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index 83281cca0..8d8c5b918 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -29,7 +29,7 @@ dev = [ ] [project.scripts] -latent_space_eval= "weathergen.evaluate.latent_space.latent_space_eval:latent_space_analysis" +ssl_analysis= "weathergen.evaluate.ssl.ssl_eval:ssl_analysis" evaluation = 
"weathergen.evaluate.run_evaluation:evaluate" export = "weathergen.evaluate.export.export_inference:export" diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/README.md b/packages/evaluate/src/weathergen/evaluate/ssl/README.md similarity index 91% rename from packages/evaluate/src/weathergen/evaluate/latent_space/README.md rename to packages/evaluate/src/weathergen/evaluate/ssl/README.md index 6eae6ec39..122e861fe 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/README.md +++ b/packages/evaluate/src/weathergen/evaluate/ssl/README.md @@ -1,4 +1,4 @@ -# Multi-Stream Latent Space Inference & Evaluation +# Multi-Stream Self-Supervised Learning Inference & Evaluation This script runs **inference and evaluation for a multi-stream weather generation model** and produces quantitative metrics and diagnostic plots for each stream. @@ -95,9 +95,6 @@ The script assumes the following directory layout: ``` . -├── config/ -│ └── evaluate/ -│ └── latent_space_eval_config.yaml ├── models/ │ └── │ ├── _latest.chkpt @@ -116,13 +113,13 @@ The script assumes the following directory layout: ### Command Line ```bash -uv run --offline latent_space_analysis --run-id +uv run --offline ssl_analysis --run-id RUN_ID ``` Optional verbose mode: ```bash -uv run --offline latent_space_analysis --run_id --run-id --verbose +uv run --offline ssl_analysis --run-id RUN_ID --verbose ``` --- diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py similarity index 96% rename from packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py rename to packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py index 58936b97e..b9cbf70ae 100644 --- a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py @@ -8,12 +8,12 @@ # nor does it submit to any jurisdiction. 
""" -Integration test for the Weather Generator with multiple streams and observations. -This test must run on a GPU machine. -It performs training and inference with multiple data sources including gridded and obs data. +Script to run Inference and evaluation for the Weather Generator with multiple streams and observations. +This script must run on a GPU machine. +It performs a standardised routine of inference and evaluation over multiple data sources including gridded and obs data. Command: -uv run --offline latent_space_eval --run-id j2dkivn8 +uv run --offline ssl_analysis --run-id j2dkivn8 """ import argparse @@ -244,8 +244,8 @@ def load_scores(eval_cfg, run_id): ############## MAIN ################# -def latent_space_analysis(): - parser = argparse.ArgumentParser(description="Run multi-stream latent space evaluation") +def ssl_analysis(): + parser = argparse.ArgumentParser(description="Run multi-stream SSL evaluation") parser.add_argument( "--run-id", type=str, required=True, help="Run identifier for the model to evaluate" ) diff --git a/packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.yaml b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.yaml similarity index 100% rename from packages/evaluate/src/weathergen/evaluate/latent_space/latent_space_eval.yaml rename to packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.yaml From 1c436bc845acec73dede3cc812549b5c133d5a93 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 30 Jan 2026 10:46:41 +0100 Subject: [PATCH 44/44] lint --- packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py index b9cbf70ae..18d848815 100644 --- a/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py +++ b/packages/evaluate/src/weathergen/evaluate/ssl/ssl_eval.py @@ -8,9 +8,11 @@ # nor does it submit to any 
jurisdiction. """ -Script to run Inference and evaluation for the Weather Generator with multiple streams and observations. +Script to run Inference and evaluation for the Weather Generator with +multiple streams and observations. This script must run on a GPU machine. -It performs a standardised routine of inference and evaluation over multiple data sources including gridded and obs data. +It performs a standardised routine of inference and evaluation +over multiple data sources including gridded and obs data. Command: uv run --offline ssl_analysis --run-id j2dkivn8