From 79aba955e17746dbaa5de35ee96b3ac2e6d34b0f Mon Sep 17 00:00:00 2001 From: Sebastian Buschow Date: Thu, 4 Dec 2025 08:55:25 +0100 Subject: [PATCH 01/13] split WeatherGenReader functionality to allow reading only JSON adding weathergen JSON reader to develop --- .../weathergen/evaluate/io/wegen_reader.py | 337 ++++++++++-------- .../src/weathergen/evaluate/run_evaluation.py | 8 +- 2 files changed, 188 insertions(+), 157 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 0141fa587..19e5bd256 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -35,8 +35,6 @@ class WeatherGenReader(Reader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): - """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" - super().__init__(eval_cfg, run_id, private_paths) # TODO: remove backwards compatibility to "epoch" in Feb. 2026 @@ -71,24 +69,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation") ) - fname_zarr_new = self.results_dir.joinpath( - f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" - ) - fname_zarr_old = self.results_dir.joinpath( - f"validation_epoch{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" - ) - - if fname_zarr_new.exists() or fname_zarr_new.is_dir(): - self.fname_zarr = fname_zarr_new - else: - self.fname_zarr = fname_zarr_old - - if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): - _logger.error(f"Zarr file {self.fname_zarr} does not exist.") - raise FileNotFoundError( - f"Zarr file {self.fname_zarr} does not exist or is not a directory." - ) - def get_inference_config(self): """ load the config associated to the inference run (different from the eval_cfg which @@ -116,6 +96,189 @@ def get_inference_config(self): return config + def get_climatology_filename(self, stream: str) -> str | None: + """ + Get the climatology filename for a given stream from the inference configuration. + Parameters + ---------- + stream : + Name of the data stream. + Returns + ------- + Climatology filename if specified, otherwise None. + """ + + stream_dict = self.get_stream(stream) + + clim_data_path = stream_dict.get("climatology_path", None) + if not clim_data_path: + clim_base_dir = self.inference_cfg.get("data_path_aux", None) + + clim_fn = next( + ( + item.get("climatology_filename") + for item in self.inference_cfg["streams"] + if item.get("name") == stream + ), + None, + ) + + if clim_base_dir and clim_fn: + clim_data_path = Path(clim_base_dir).join(clim_fn) + else: + _logger.warning( + f"No climatology path specified for stream {stream}. Setting climatology to " + "NaN. Add 'climatology_path' to evaluation config to use metrics like ACC." + ) + + return clim_data_path + + def get_channels(self, stream: str) -> list[str]: + """ + Get the list of channels for a given stream from the config. + + Parameters + ---------- + stream : + The name of the stream to get channels for. + + Returns + ------- + A list of channel names. 
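+            For example ["2t", "10u", "10v"] (illustrative ERA5-style names).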
+ """ + _logger.debug(f"Getting channels for stream {stream}...") + all_channels = self.get_inference_stream_attr(stream, "val_target_channels") + _logger.debug(f"Channels found in config: {all_channels}") + return all_channels + + def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray | None: + """ + Load the pre-computed scores for a given run, stream and metric and epoch. + + Parameters + ---------- + reader : + Reader object containing all info for a specific run_id + stream : + Stream name. + regions : + Region names. + metrics : + Metric names. + + Returns + ------- + xr.DataArray + The metric DataArray. + missing_metrics: + dictionary of missing regions and metrics that need to be recomputed. + """ + + local_scores = {} + missing_metrics = {} + for region in regions: + for metric in metrics: + score_path = ( + Path(self.metrics_dir) + / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" + ) + _logger.debug(f"Looking for: {score_path}") + + if score_path.exists(): + with open(score_path) as f: + data_dict = json.load(f) + score_dict = xr.DataArray.from_dict(data_dict) + + available_data = self.check_availability(stream, score_dict, mode="evaluation") + + if available_data.score_availability: + score_dict = score_dict.sel( + sample=available_data.samples, + channel=available_data.channels, + forecast_step=available_data.fsteps, + ) + local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault( + stream, {} + )[self.run_id] = score_dict + continue + + # all other cases: recompute scores + missing_metrics.setdefault(region, []).append(metric) + continue + + return local_scores, missing_metrics + + def get_inference_stream_attr(self, stream_name: str, key: str, default=None): + """ + Get the value of a key for a specific stream from the a model config. + + Parameters: + ------------ + config: + The full configuration dictionary. + stream_name: + The name of the stream (e.g. 'ERA5'). + key: + The key to look up (e.g. 'tokenize_spacetime'). + default: Optional + Value to return if not found (default: None). + + Returns: + The parameter value if found, otherwise the default. + """ + for stream in self.inference_cfg.get("streams", []): + if stream.get("name") == stream_name: + return stream.get(key, default) + return default + +class WeatherGenJSONReader(WeatherGenReader): + + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, actual_eval_cfg: dict = {}): + super().__init__(eval_cfg, run_id, private_paths) + # is this the best way to learn which steps and samples are available? 
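+        # The probe below loads a single pre-computed score file and reads its
+        # sample, forecast_step and ens coordinates; the first configured
+        # stream, region and metric are assumed to be representative.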
+ dummy = self.load_scores( + stream=list(self.eval_cfg.streams.keys())[0], + region=actual_eval_cfg.regions[0], + metric=actual_eval_cfg.metrics[0] + ) + self.samples = set(dummy.sample.values) + self.fsteps = set(dummy.forecast_step.values) + self.ens = list(dummy.ens.values) + + def get_samples(self) -> set[int]: + return self.samples + + def get_forecast_steps(self) -> set[int]: + return self.fsteps + + def get_ensemble(self, stream: str | None = None) -> list[str]: + return self.ens + + +class WeatherGenZarrReader(WeatherGenReader): + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" + super().__init__(eval_cfg, run_id, private_paths) + + fname_zarr_new = self.results_dir.joinpath( + f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" + ) + fname_zarr_old = self.results_dir.joinpath( + f"validation_epoch{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" + ) + + if fname_zarr_new.exists() or fname_zarr_new.is_dir(): + self.fname_zarr = fname_zarr_new + else: + self.fname_zarr = fname_zarr_old + + if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): + _logger.error(f"Zarr file {self.fname_zarr} does not exist.") + raise FileNotFoundError( + f"Zarr file {self.fname_zarr} does not exist or is not a directory." + ) + + def get_data( self, stream: str, @@ -335,43 +498,6 @@ def scale_z_channels(self, data: xr.DataArray, stream: str) -> xr.DataArray: return data_scaled - def get_climatology_filename(self, stream: str) -> str | None: - """ - Get the climatology filename for a given stream from the inference configuration. - Parameters - ---------- - stream : - Name of the data stream. - Returns - ------- - Climatology filename if specified, otherwise None. - """ - - stream_dict = self.get_stream(stream) - - clim_data_path = stream_dict.get("climatology_path", None) - if not clim_data_path: - clim_base_dir = self.inference_cfg.get("data_path_aux", None) - - clim_fn = next( - ( - item.get("climatology_filename") - for item in self.inference_cfg["streams"] - if item.get("name") == stream - ), - None, - ) - - if clim_base_dir and clim_fn: - clim_data_path = Path(clim_base_dir).join(clim_fn) - else: - _logger.warning( - f"No climatology path specified for stream {stream}. Setting climatology to " - "NaN. Add 'climatology_path' to evaluation config to use metrics like ACC." - ) - - return clim_data_path - def get_stream(self, stream: str): """ returns the dictionary associated to a particular stream. @@ -402,24 +528,6 @@ def get_forecast_steps(self) -> set[int]: with ZarrIO(self.fname_zarr) as zio: return set(int(f) for f in zio.forecast_steps) - def get_channels(self, stream: str) -> list[str]: - """ - Get the list of channels for a given stream from the config. - - Parameters - ---------- - stream : - The name of the stream to get channels for. - - Returns - ------- - A list of channel names. - """ - _logger.debug(f"Getting channels for stream {stream}...") - all_channels = self.get_inference_stream_attr(stream, "val_target_channels") - _logger.debug(f"Channels found in config: {all_channels}") - return all_channels - def get_ensemble(self, stream: str | None = None) -> list[str]: """Get the list of ensemble member names for a given stream from the config. 
Parameters @@ -478,85 +586,6 @@ def is_regular(self, stream: str) -> bool: _logger.debug("Latitude and longitude coordinates are regularly spaced.") return True - def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray | None: - """ - Load the pre-computed scores for a given run, stream and metric and epoch. - - Parameters - ---------- - reader : - Reader object containing all info for a specific run_id - stream : - Stream name. - regions : - Region names. - metrics : - Metric names. - - Returns - ------- - xr.DataArray - The metric DataArray. - missing_metrics: - dictionary of missing regions and metrics that need to be recomputed. - """ - - local_scores = {} - missing_metrics = {} - for region in regions: - for metric in metrics: - score_path = ( - Path(self.metrics_dir) - / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" - ) - _logger.debug(f"Looking for: {score_path}") - - if score_path.exists(): - with open(score_path) as f: - data_dict = json.load(f) - score_dict = xr.DataArray.from_dict(data_dict) - - available_data = self.check_availability(stream, score_dict, mode="evaluation") - - if available_data.score_availability: - score_dict = score_dict.sel( - sample=available_data.samples, - channel=available_data.channels, - forecast_step=available_data.fsteps, - ) - local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault( - stream, {} - )[self.run_id] = score_dict - continue - - # all other cases: recompute scores - missing_metrics.setdefault(region, []).append(metric) - continue - - return local_scores, missing_metrics - - def get_inference_stream_attr(self, stream_name: str, key: str, default=None): - """ - Get the value of a key for a specific stream from the a model config. - - Parameters: - ------------ - config: - The full configuration dictionary. - stream_name: - The name of the stream (e.g. 'ERA5'). - key: - The key to look up (e.g. 'tokenize_spacetime'). - default: Optional - Value to return if not found (default: None). - - Returns: - The parameter value if found, otherwise the default. 
- """ - for stream in self.inference_cfg.get("streams", []): - if stream.get("name") == stream_name: - return stream.get(key, default) - return default ################### Helper functions ######################## diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 67ce8e02c..ab2ddd95c 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -28,7 +28,7 @@ from weathergen.common.logger import init_loggers from weathergen.common.platform_env import get_platform_env from weathergen.evaluate.io.csv_reader import CsvReader -from weathergen.evaluate.io.wegen_reader import WeatherGenReader +from weathergen.evaluate.io.wegen_reader import WeatherGenZarrReader, WeatherGenJSONReader from weathergen.evaluate.plotting.plot_utils import collect_channels from weathergen.evaluate.utils.utils import ( calc_scores_per_stream, @@ -281,9 +281,11 @@ def evaluate_from_config( type_ = run.get("type", "zarr") if type_ == "zarr": - reader = WeatherGenReader(run, run_id, private_paths) + reader = WeatherGenZarrReader(run, run_id, private_paths) elif type_ == "csv": reader = CsvReader(run, run_id, private_paths) + elif type == "json": + reader = WeatherGenJSONReader(run, run_id, private_paths, cfg.evaluation) else: raise ValueError(f"Unknown run type: {type_}") @@ -335,7 +337,7 @@ def evaluate_from_config( channels_set = collect_channels(scores_dict, metric, region, runs) for run_id, run in runs.items(): - reader = WeatherGenReader(run, run_id, private_paths) + reader = WeatherGenZarrReader(run, run_id, private_paths) from_run_id = reader.inference_cfg["from_run_id"] parent_run = get_or_create_mlflow_parent_run(mlflow_client, from_run_id) _logger.info(f"MLFlow parent run: {parent_run}") From 71319837225d8cf1cf35143a8bffb1e1176e633e Mon Sep 17 00:00:00 2001 From: Sebastian Buschow Date: Thu, 4 Dec 2025 10:07:09 +0100 Subject: [PATCH 02/13] informative error when metrics are not there --- .../src/weathergen/evaluate/io/wegen_reader.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 19e5bd256..a089bc238 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -236,11 +236,13 @@ class WeatherGenJSONReader(WeatherGenReader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, actual_eval_cfg: dict = {}): super().__init__(eval_cfg, run_id, private_paths) # is this the best way to learn which steps and samples are available? - dummy = self.load_scores( - stream=list(self.eval_cfg.streams.keys())[0], - region=actual_eval_cfg.regions[0], - metric=actual_eval_cfg.metrics[0] - ) + stream=list(self.eval_cfg.streams.keys())[0] + region=actual_eval_cfg.regions[0] + metric=actual_eval_cfg.metrics[0] + dummy = self.load_scores( stream, region, metric ) + if dummy is None: + raise ValueError(f"JSONreader could not find {metric} for {run_id}, stream {stream}, region {region}. 
" + "use type: zarr instead if possible") self.samples = set(dummy.sample.values) self.fsteps = set(dummy.forecast_step.values) self.ens = list(dummy.ens.values) From 45dfaf2c65239f19c821ec9a5afeb6852f0689f1 Mon Sep 17 00:00:00 2001 From: Sebastian Buschow Date: Wed, 7 Jan 2026 14:33:29 +0100 Subject: [PATCH 03/13] restore JSONreader after rebase --- .../weathergen/evaluate/io/wegen_reader.py | 47 ++++++++++--------- .../src/weathergen/evaluate/run_evaluation.py | 4 +- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index a089bc238..51379fc00 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -151,9 +151,9 @@ def get_channels(self, stream: str) -> list[str]: _logger.debug(f"Channels found in config: {all_channels}") return all_channels - def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray | None: + def load_scores(self, stream: str, regions: list[str], metrics: list[str]) -> xr.DataArray | None: """ - Load the pre-computed scores for a given run, stream and metric and epoch. + Load multiple pre-computed scores for a given run, stream and metric and epoch. Parameters ---------- @@ -178,28 +178,18 @@ def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray | missing_metrics = {} for region in regions: for metric in metrics: - score_path = ( - Path(self.metrics_dir) - / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" - ) - _logger.debug(f"Looking for: {score_path}") - - if score_path.exists(): - with open(score_path) as f: - data_dict = json.load(f) - score_dict = xr.DataArray.from_dict(data_dict) - - available_data = self.check_availability(stream, score_dict, mode="evaluation") - + score = self.load_single_score(stream, region, metric) + if score is not None: + available_data = self.check_availability(stream, score, mode="evaluation") if available_data.score_availability: - score_dict = score_dict.sel( + score = score.sel( sample=available_data.samples, channel=available_data.channels, forecast_step=available_data.fsteps, ) local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault( stream, {} - )[self.run_id] = score_dict + )[self.run_id] = score continue # all other cases: recompute scores @@ -208,6 +198,23 @@ def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray | return local_scores, missing_metrics + def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None: + ''' + Load a single pre-computed score for a given run, stream and metric + ''' + score_path = ( + Path(self.metrics_dir) + / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" + ) + _logger.debug(f"Looking for: {score_path}") + if score_path.exists(): + with open(score_path) as f: + data_dict = json.load(f) + score = xr.DataArray.from_dict(data_dict) # not a dict though + else: + score = None + return score + def get_inference_stream_attr(self, stream_name: str, key: str, default=None): """ Get the value of a key for a specific stream from the a model config. 
@@ -233,13 +240,11 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None): class WeatherGenJSONReader(WeatherGenReader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, actual_eval_cfg: dict = {}): + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, region: str = "global", metric: str = "rmse"): super().__init__(eval_cfg, run_id, private_paths) # is this the best way to learn which steps and samples are available? stream=list(self.eval_cfg.streams.keys())[0] - region=actual_eval_cfg.regions[0] - metric=actual_eval_cfg.metrics[0] - dummy = self.load_scores( stream, region, metric ) + dummy = self.load_single_score( stream, region, metric ) if dummy is None: raise ValueError(f"JSONreader could not find {metric} for {run_id}, stream {stream}, region {region}. " "use type: zarr instead if possible") diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index ab2ddd95c..aef0601c5 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -284,8 +284,8 @@ def evaluate_from_config( reader = WeatherGenZarrReader(run, run_id, private_paths) elif type_ == "csv": reader = CsvReader(run, run_id, private_paths) - elif type == "json": - reader = WeatherGenJSONReader(run, run_id, private_paths, cfg.evaluation) + elif type_ == "json": + reader = WeatherGenJSONReader(run, run_id, private_paths, cfg.evaluation.regions[0], cfg.evaluation.metrics[0]) else: raise ValueError(f"Unknown run type: {type_}") From 28db65896335422e7efcb82e9196c09f002352de Mon Sep 17 00:00:00 2001 From: Sebastian Buschow Date: Wed, 7 Jan 2026 17:49:59 +0100 Subject: [PATCH 04/13] JSONreader mostly restored --- .../weathergen/evaluate/io/wegen_reader.py | 5 +++ .../src/weathergen/evaluate/run_evaluation.py | 32 +++++++++++-------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 51379fc00..49a8fb23b 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -261,6 +261,11 @@ def get_forecast_steps(self) -> set[int]: def get_ensemble(self, stream: str | None = None) -> list[str]: return self.ens + def get_data(self, *args, **kwargs): + # TODO this should not be needed, the reader should not even be created if this is the case + # it can still happen when a particular score was available for a different channel + raise ValueError(f"Missing JSON data for run {self.run_id}.") + class WeatherGenZarrReader(WeatherGenReader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index aef0601c5..85713bd4b 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -151,6 +151,22 @@ def evaluate_from_args(argl: list[str], log_queue: mp.Queue) -> None: assert isinstance(cf, DictConfig) evaluate_from_config(cf, mlflow_client, log_queue) +def get_reader( reader_type: str, + run: dict, + run_id: str, + private_paths: dict[str, str], + region: str | None = None, + metric: str | None = None + ): + if reader_type == "zarr": + 
reader = WeatherGenZarrReader(run, run_id, private_paths)
+    elif reader_type == "csv":
+        reader = CsvReader(run, run_id, private_paths)
+    elif reader_type == "json":
+        reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric)
+    else:
+        raise ValueError(f"Unknown reader type: {reader_type}")
+    return reader
 
 def _process_stream_wrapper(
     args: dict[str, object],
@@ -193,11 +209,7 @@ def _process_stream(
     """
     # try:
     type_ = run.get("type", "zarr")
-    reader = (
-        WeatherGenReader(run, run_id, private_paths)
-        if type_ == "zarr"
-        else CsvReader(run, run_id, private_paths)
-    )
+    reader = get_reader( type_, run, run_id, private_paths, regions[0], metrics[0] )
 
     stream_dict = reader.get_stream(stream)
     if not stream_dict:
@@ -279,15 +291,7 @@ def evaluate_from_config(
     # Build tasks per stream
     for run_id, run in runs.items():
         type_ = run.get("type", "zarr")
-
-        if type_ == "zarr":
-            reader = WeatherGenZarrReader(run, run_id, private_paths)
-        elif type_ == "csv":
-            reader = CsvReader(run, run_id, private_paths)
-        elif type_ == "json":
-            reader = WeatherGenJSONReader(run, run_id, private_paths, cfg.evaluation.regions[0], cfg.evaluation.metrics[0])
-        else:
-            raise ValueError(f"Unknown run type: {type_}")
+        reader = get_reader( type_, run, run_id, private_paths, cfg.evaluation.regions[0], cfg.evaluation.metrics[0] )
 
         for stream in reader.streams:
             tasks.append(
From 8fd16e960e7f95cb6d728186f960197d3ecd948f Mon Sep 17 00:00:00 2001
From: Sebastian Buschow
Date: Wed, 7 Jan 2026 17:58:23 +0100
Subject: [PATCH 05/13] MLFlow logging independent of JSON/zarr

---
 packages/evaluate/src/weathergen/evaluate/run_evaluation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
index 85713bd4b..900bd3f98 100755
--- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
+++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -28,7 +28,7 @@
 from weathergen.common.logger import init_loggers
 from weathergen.common.platform_env import get_platform_env
 from weathergen.evaluate.io.csv_reader import CsvReader
-from weathergen.evaluate.io.wegen_reader import WeatherGenZarrReader, WeatherGenJSONReader
+from weathergen.evaluate.io.wegen_reader import WeatherGenReader, WeatherGenZarrReader, WeatherGenJSONReader
 from weathergen.evaluate.plotting.plot_utils import collect_channels
 from weathergen.evaluate.utils.utils import (
     calc_scores_per_stream,
@@ -341,7 +341,7 @@ def evaluate_from_config(
     channels_set = collect_channels(scores_dict, metric, region, runs)
 
     for run_id, run in runs.items():
-        reader = WeatherGenZarrReader(run, run_id, private_paths)
+        reader = WeatherGenReader(run, run_id, private_paths)
         from_run_id = reader.inference_cfg["from_run_id"]
         parent_run = get_or_create_mlflow_parent_run(mlflow_client, from_run_id)
         _logger.info(f"MLFlow parent run: {parent_run}")
From 7cbcdbaf54281a28efcdd7e3696a0c34119c6ff9 Mon Sep 17 00:00:00 2001
From: Sebastian Buschow
Date: Fri, 9 Jan 2026 12:02:52 +0100
Subject: [PATCH 06/13] linting, properly checking fsteps, ens, samples in JSONreader

---
 .../src/weathergen/evaluate/io/io_reader.py   |  2 +-
 .../weathergen/evaluate/io/wegen_reader.py    | 78 +++++++++++++------
 .../src/weathergen/evaluate/run_evaluation.py | 29 ++++---
 3 files changed, 74 insertions(+), 35 deletions(-)

diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py
index 
2dd12c27a..92e6748fd 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -231,7 +231,7 @@ def check_availability( missing.remove("mean") if missing: _logger.info( - f"Requested {name}(s) {missing} do(es) not exist in Zarr. " + f"Requested {name}(s) {missing} is unavailable. " f"Removing missing {name}(s) for {mode}." ) requested[name] = requested[name] & reader_data[name] diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 49a8fb23b..f753e0f72 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -10,6 +10,7 @@ # Standard library import json import logging +from collections import defaultdict from pathlib import Path # Third-party @@ -151,7 +152,9 @@ def get_channels(self, stream: str) -> list[str]: _logger.debug(f"Channels found in config: {all_channels}") return all_channels - def load_scores(self, stream: str, regions: list[str], metrics: list[str]) -> xr.DataArray | None: + def load_scores( + self, stream: str, regions: list[str], metrics: list[str] + ) -> xr.DataArray | None: """ Load multiple pre-computed scores for a given run, stream and metric and epoch. @@ -199,18 +202,18 @@ def load_scores(self, stream: str, regions: list[str], metrics: list[str]) -> xr return local_scores, missing_metrics def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None: - ''' + """ Load a single pre-computed score for a given run, stream and metric - ''' + """ score_path = ( - Path(self.metrics_dir) - / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" - ) + Path(self.metrics_dir) + / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" + ) _logger.debug(f"Looking for: {score_path}") if score_path.exists(): with open(score_path) as f: data_dict = json.load(f) - score = xr.DataArray.from_dict(data_dict) # not a dict though + score = xr.DataArray.from_dict(data_dict) # not a dict though else: score = None return score @@ -238,28 +241,57 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None): return stream.get(key, default) return default -class WeatherGenJSONReader(WeatherGenReader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None, region: str = "global", metric: str = "rmse"): +class WeatherGenJSONReader(WeatherGenReader): + def __init__( + self, + eval_cfg: dict, + run_id: str, + private_paths: dict | None = None, + regions: list[str] | None = None, + metrics: list[str] | None = None, + ): super().__init__(eval_cfg, run_id, private_paths) - # is this the best way to learn which steps and samples are available? - stream=list(self.eval_cfg.streams.keys())[0] - dummy = self.load_single_score( stream, region, metric ) - if dummy is None: - raise ValueError(f"JSONreader could not find {metric} for {run_id}, stream {stream}, region {region}. 
" - "use type: zarr instead if possible") - self.samples = set(dummy.sample.values) - self.fsteps = set(dummy.forecast_step.values) - self.ens = list(dummy.ens.values) + # goes looking for the coordinates available for all streams, regions, metrics + streams = list(self.eval_cfg.streams.keys()) + coord_names = ["sample", "forecast_step", "ens"] + all_coords = {name: [] for name in coord_names} # collect all available coordinates + provenance = { + name: defaultdict(list) for name in coord_names + } # remember who had which coords, so we can warn about it later. + for stream in streams: + for region in regions: + for metric in metrics: + score = self.load_single_score(stream, region, metric) + if score is None: + raise ValueError( + f"JSONreader couldn't find {metric} for {run_id}, stream {stream}, " + f"region {region}. Use type: zarr instead if possible." + ) + else: + for name in coord_names: + vals = set(score[name].values) + all_coords[name].append(vals) + for val in vals: + provenance[name][val].append((stream, region, metric)) + self.common_coords = {name: set.intersection(*all_coords[name]) for name in coord_names} + # issue warnings for skipped coords + for name in coord_names: + skipped = set.union(*all_coords[name]) - self.common_coords[name] + if skipped: + message = [f"Some {name}(s) were not common among streams, regions and metrics:"] + for val in skipped: + message.append(f" {val} only in {provenance[name][val]}") + _logger.warning("\n".join(message)) def get_samples(self) -> set[int]: - return self.samples + return self.common_coords["sample"] def get_forecast_steps(self) -> set[int]: - return self.fsteps - + return self.common_coords["forecast_step"] + def get_ensemble(self, stream: str | None = None) -> list[str]: - return self.ens + return self.common_coords["ens"] def get_data(self, *args, **kwargs): # TODO this should not be needed, the reader should not even be created if this is the case @@ -290,7 +322,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non f"Zarr file {self.fname_zarr} does not exist or is not a directory." 
) - def get_data( self, stream: str, @@ -599,7 +630,6 @@ def is_regular(self, stream: str) -> bool: return True - ################### Helper functions ######################## diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 900bd3f98..050fbd40f 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -28,7 +28,11 @@ from weathergen.common.logger import init_loggers from weathergen.common.platform_env import get_platform_env from weathergen.evaluate.io.csv_reader import CsvReader -from weathergen.evaluate.io.wegen_reader import WeatherGenReader, WeatherGenZarrReader, WeatherGenJSONReader +from weathergen.evaluate.io.wegen_reader import ( + WeatherGenJSONReader, + WeatherGenReader, + WeatherGenZarrReader, +) from weathergen.evaluate.plotting.plot_utils import collect_channels from weathergen.evaluate.utils.utils import ( calc_scores_per_stream, @@ -151,13 +155,15 @@ def evaluate_from_args(argl: list[str], log_queue: mp.Queue) -> None: assert isinstance(cf, DictConfig) evaluate_from_config(cf, mlflow_client, log_queue) -def get_reader( reader_type: str, - run: dict, - run_id: str, - private_paths: dict[str, str], - region: str | None = None, - metric: str | None = None - ): + +def get_reader( + reader_type: str, + run: dict, + run_id: str, + private_paths: dict[str, str], + region: str | None = None, + metric: str | None = None, +): if reader_type == "zarr": reader = WeatherGenZarrReader(run, run_id, private_paths) elif reader_type == "csv": @@ -168,6 +174,7 @@ def get_reader( reader_type: str, raise ValueError(f"Unknown reader type: {reader_type}") return reader + def _process_stream_wrapper( args: dict[str, object], ) -> tuple[str, str, dict[str, dict[str, dict[str, float]]]]: @@ -209,7 +216,7 @@ def _process_stream( """ # try: type_ = run.get("type", "zarr") - reader = get_reader( type_, run, run_id, private_paths, regions[0], metrics[0] ) + reader = get_reader(type_, run, run_id, private_paths, regions, metrics) stream_dict = reader.get_stream(stream) if not stream_dict: @@ -291,7 +298,9 @@ def evaluate_from_config( # Build tasks per stream for run_id, run in runs.items(): type_ = run.get("type", "zarr") - reader = get_reader( type_, run, run_id, private_paths, cfg.evaluation.regions[0], cfg.evaluation.metrics[0] ) + reader = get_reader( + type_, run, run_id, private_paths, cfg.evaluation.regions, cfg.evaluation.metrics + ) for stream in reader.streams: tasks.append( From 90bfe4b7b6edae1f55554204f96490c835f8e586 Mon Sep 17 00:00:00 2001 From: Sebastian Buschow Date: Tue, 13 Jan 2026 16:54:46 +0100 Subject: [PATCH 07/13] tiny change to restore the MergeReader --- packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py | 2 +- packages/evaluate/src/weathergen/evaluate/run_evaluation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 14331b162..82c16121e 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -680,7 +680,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non _logger.info(f"MERGE READERS: {self.run_ids} ...") for run_id in self.run_ids: - reader = WeatherGenReader(self.eval_cfg, run_id, self.private_paths) + reader = 
WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths) self.readers.append(reader) def get_data( diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 92154ccaf..48069572f 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -171,7 +171,7 @@ def get_reader( reader = CsvReader(run, run_id, private_paths) elif reader_type == "json": reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric) - elif reader_type == "merge": + elif reader_type == "merge": reader = WeatherGenMergeReader(run, run_id, private_paths) else: raise ValueError(f"Unknown reader type: {reader_type}") From 816e8cd6d9a33b774033954900412b0b6ffba4ad Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 14 Jan 2026 14:10:44 +0100 Subject: [PATCH 08/13] lint --- packages/common/src/weathergen/common/config.py | 2 +- packages/evaluate/src/weathergen/evaluate/run_evaluation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 26d1709b2..95a711878 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -500,7 +500,7 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig: def _load_base_conf(base: Path | Config | None) -> Config: """Return the base configuration""" - match base : + match base: case Path(): _logger.info(f"Loading specified base config from file: {base}.") conf = OmegaConf.load(base) diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 48069572f..dce9a6725 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -30,9 +30,9 @@ from weathergen.evaluate.io.csv_reader import CsvReader from weathergen.evaluate.io.wegen_reader import ( WeatherGenJSONReader, + WeatherGenMergeReader, WeatherGenReader, WeatherGenZarrReader, - WeatherGenMergeReader, ) from weathergen.evaluate.plotting.plot_utils import collect_channels from weathergen.evaluate.utils.utils import ( From 6fe6d49358c66a6f64efa147344120d57663d983 Mon Sep 17 00:00:00 2001 From: Sebastian Buschow Date: Mon, 19 Jan 2026 16:35:41 +0100 Subject: [PATCH 09/13] enabling JSONreader to skip plots and missing scores gracefully --- .../src/weathergen/evaluate/io/io_reader.py | 5 ++-- .../weathergen/evaluate/io/wegen_reader.py | 28 ++--------------- .../src/weathergen/evaluate/run_evaluation.py | 30 ++++++++++++------- 3 files changed, 25 insertions(+), 38 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index 92e6748fd..bc65b9366 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -219,7 +219,7 @@ def check_availability( if available_data is not None and reader_data[name] != available[name]: _logger.info( f"Requested all {name}s for {mode}, but previous config was a " - "strict subset. Recomputing." + "strict subset. Recomputation required." 
) check_score = False @@ -241,7 +241,8 @@ def check_availability( if available_data is not None and not requested[name] <= available[name]: missing = requested[name] - available[name] _logger.info( - f"{name.capitalize()}(s) {missing} missing in previous evaluation. Recomputing." + f"{name.capitalize()}(s) {missing} missing in previous evaluation." + "Recomputation required." ) check_score = False diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 4bcc71d6f..1ed99c421 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -211,7 +211,7 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr if score_path.exists(): with open(score_path) as f: data_dict = json.load(f) - score = xr.DataArray.from_dict(data_dict) # not a dict though + score = xr.DataArray.from_dict(data_dict) else: score = None return score @@ -261,12 +261,7 @@ def __init__( for region in regions: for metric in metrics: score = self.load_single_score(stream, region, metric) - if score is None: - raise ValueError( - f"JSONreader couldn't find {metric} for {run_id}, stream {stream}, " - f"region {region}. Use type: zarr instead if possible." - ) - else: + if score is not None: for name in coord_names: vals = set(score[name].values) all_coords[name].append(vals) @@ -302,33 +297,15 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" super().__init__(eval_cfg, run_id, private_paths) -<<<<<<< HEAD - fname_zarr_new = self.results_dir.joinpath( - f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" -======= zarr_ext = self.inference_cfg.get("zarr_store", "zarr") # for backwards compatibility assume zarr store is local i.e. .zarr format fname_zarr_new = self.results_dir.joinpath( f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.{zarr_ext}" ->>>>>>> upstream/develop ) fname_zarr_old = self.results_dir.joinpath( f"validation_epoch{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" ) -<<<<<<< HEAD - - if fname_zarr_new.exists() or fname_zarr_new.is_dir(): - self.fname_zarr = fname_zarr_new - else: - self.fname_zarr = fname_zarr_old - - if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): - _logger.error(f"Zarr file {self.fname_zarr} does not exist.") - raise FileNotFoundError( - f"Zarr file {self.fname_zarr} does not exist or is not a directory." 
- ) -======= if fname_zarr_new.exists(): if (zarr_ext == "zarr" and fname_zarr_new.is_dir()) or ( zarr_ext == "zip" and fname_zarr_new.is_file() @@ -340,7 +317,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if not self.fname_zarr.exists(): _logger.error(f"Zarr file {self.fname_zarr} does not exist.") raise FileNotFoundError(f"Zarr file {self.fname_zarr} does not exist") ->>>>>>> upstream/develop def get_data( self, diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index dce9a6725..64e9c1d3a 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -227,8 +227,11 @@ def _process_stream( # Parallel plotting if stream_dict.get("plotting"): - plot_data(reader, stream, global_plotting_opts) - + if type_ == "zarr": + plot_data(reader, stream, global_plotting_opts) + else: + _logger.info("skipped plot_data, use type: zarr for that.") + # Scoring per stream if not stream_dict.get("evaluation"): return run_id, stream, {} @@ -241,15 +244,22 @@ def _process_stream( scores_dict = stream_loaded_scores if missing_metrics or plot_score_maps: - regions_to_compute = list(set(missing_metrics.keys())) if missing_metrics else regions - metrics_to_compute = missing_metrics if missing_metrics else metrics - - stream_computed_scores = calc_scores_per_stream( - reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps - ) + if type_ == "zarr": + regions_to_compute = list(set(missing_metrics.keys())) if missing_metrics else regions + metrics_to_compute = missing_metrics if missing_metrics else metrics - metric_list_to_json(reader, stream, stream_computed_scores, regions) - scores_dict = merge(stream_loaded_scores, stream_computed_scores) + stream_computed_scores = calc_scores_per_stream( + reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps + ) + metric_list_to_json(reader, stream, stream_computed_scores, regions) + scores_dict = merge(stream_loaded_scores, stream_computed_scores) + else: + if missing_metrics: + _logger.info(f"The following metrics have not yet been computed:" + f"{missing_metrics}. Use type: zarr for that.") + if plot_score_maps: + _logger.info(f"score maps skipped, use type:zarr for that.") + return run_id, stream, scores_dict From 969ac63d66c0f49364d15b81dc2ce4ae0b070689 Mon Sep 17 00:00:00 2001 From: Sebastian Buschow Date: Mon, 19 Jan 2026 16:54:26 +0100 Subject: [PATCH 10/13] required reformatting --- .../evaluate/src/weathergen/evaluate/io/io_reader.py | 4 ++-- .../src/weathergen/evaluate/io/wegen_reader.py | 2 +- .../src/weathergen/evaluate/run_evaluation.py | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py index bc65b9366..658339cac 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/io_reader.py @@ -241,8 +241,8 @@ def check_availability( if available_data is not None and not requested[name] <= available[name]: missing = requested[name] - available[name] _logger.info( - f"{name.capitalize()}(s) {missing} missing in previous evaluation." - "Recomputation required." + f"{name.capitalize()}(s) {missing} missing in previous evaluation." + "Recomputation required." 
) check_score = False diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 1ed99c421..5d6f8ab26 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -211,7 +211,7 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr if score_path.exists(): with open(score_path) as f: data_dict = json.load(f) - score = xr.DataArray.from_dict(data_dict) + score = xr.DataArray.from_dict(data_dict) else: score = None return score diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 64e9c1d3a..3cf77f39c 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -231,7 +231,7 @@ def _process_stream( plot_data(reader, stream, global_plotting_opts) else: _logger.info("skipped plot_data, use type: zarr for that.") - + # Scoring per stream if not stream_dict.get("evaluation"): return run_id, stream, {} @@ -255,11 +255,12 @@ def _process_stream( scores_dict = merge(stream_loaded_scores, stream_computed_scores) else: if missing_metrics: - _logger.info(f"The following metrics have not yet been computed:" - f"{missing_metrics}. Use type: zarr for that.") + _logger.info( + f"The following metrics have not yet been computed:" + f"{missing_metrics}. Use type: zarr for that." + ) if plot_score_maps: - _logger.info(f"score maps skipped, use type:zarr for that.") - + _logger.info("score maps skipped, use type:zarr for that.") return run_id, stream, scores_dict From 9e009cd90aadcfa1df19a46f7d1ac1019347a676 Mon Sep 17 00:00:00 2001 From: Sebastian Buschow Date: Tue, 20 Jan 2026 08:47:12 +0100 Subject: [PATCH 11/13] move skipping of metrics to the reader class --- .../weathergen/evaluate/io/wegen_reader.py | 19 +++++++-- .../src/weathergen/evaluate/run_evaluation.py | 40 ++++++++----------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 5d6f8ab26..c958f38b7 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -171,8 +171,9 @@ def load_scores( ------- xr.DataArray The metric DataArray. - missing_metrics: - dictionary of missing regions and metrics that need to be recomputed. + computable_metrics: + dictionary of regions and metrics that can be recomputed + (empty for JSONreader). """ local_scores = {} @@ -196,8 +197,8 @@ def load_scores( # all other cases: recompute scores missing_metrics.setdefault(region, []).append(metric) continue - - return local_scores, missing_metrics + computable_metrics = self.check_computability(missing_metrics) + return local_scores, computable_metrics def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None: """ @@ -216,6 +217,10 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr score = None return score + def check_computability(self, metrics): + """determine whether given metrics can be re-computed.""" + return metrics + def get_inference_stream_attr(self, stream_name: str, key: str, default=None): """ Get the value of a key for a specific stream from the a model config. 
@@ -291,6 +296,12 @@ def get_data(self, *args, **kwargs): # it can still happen when a particular score was available for a different channel raise ValueError(f"Missing JSON data for run {self.run_id}.") + def check_computability(self, metrics): + _logger.info( + f"The following metrics have not yet been computed:{metrics}. Use type: zarr for that." + ) + return {} + class WeatherGenZarrReader(WeatherGenReader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 3cf77f39c..3eeaae1df 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -217,50 +217,42 @@ def _process_stream( plot_score_maps: Bool to define if the score maps need to be plotted or not. """ - # try: + type_ = run.get("type", "zarr") reader = get_reader(type_, run, run_id, private_paths, regions, metrics) + if plot_score_maps and type_ != "zarr": + _logger.info("score maps skipped, use type:zarr for that.") + plot_score_maps = False + stream_dict = reader.get_stream(stream) if not stream_dict: return run_id, stream, {} # Parallel plotting - if stream_dict.get("plotting"): - if type_ == "zarr": - plot_data(reader, stream, global_plotting_opts) - else: - _logger.info("skipped plot_data, use type: zarr for that.") + if stream_dict.get("plotting") and type_ == "zarr": + plot_data(reader, stream, global_plotting_opts) # Scoring per stream if not stream_dict.get("evaluation"): return run_id, stream, {} - stream_loaded_scores, missing_metrics = reader.load_scores( + stream_loaded_scores, computable_metrics = reader.load_scores( stream, regions, metrics, ) scores_dict = stream_loaded_scores - if missing_metrics or plot_score_maps: - if type_ == "zarr": - regions_to_compute = list(set(missing_metrics.keys())) if missing_metrics else regions - metrics_to_compute = missing_metrics if missing_metrics else metrics + if computable_metrics or plot_score_maps: + regions_to_compute = list(set(computable_metrics.keys())) if computable_metrics else regions + metrics_to_compute = computable_metrics if computable_metrics else metrics - stream_computed_scores = calc_scores_per_stream( - reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps - ) - metric_list_to_json(reader, stream, stream_computed_scores, regions) - scores_dict = merge(stream_loaded_scores, stream_computed_scores) - else: - if missing_metrics: - _logger.info( - f"The following metrics have not yet been computed:" - f"{missing_metrics}. Use type: zarr for that." 
-                )
-            if plot_score_maps:
-                _logger.info("score maps skipped, use type:zarr for that.")
+        stream_computed_scores = calc_scores_per_stream(
+            reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps
+        )
+        metric_list_to_json(reader, stream, stream_computed_scores, regions)
+        scores_dict = merge(stream_loaded_scores, stream_computed_scores)
 
     return run_id, stream, scores_dict
 
From f7fd406b834ff230d25276bfb101f8742c66f4cc Mon Sep 17 00:00:00 2001
From: Sebastian Buschow
Date: Tue, 20 Jan 2026 14:33:45 +0100
Subject: [PATCH 12/13] slightly more explicit formulations

---
 .../src/weathergen/evaluate/io/wegen_reader.py | 8 ++++----
 .../src/weathergen/evaluate/run_evaluation.py  | 14 ++++++--------
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
index c958f38b7..04a9aefa5 100644
--- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
+++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
@@ -197,8 +197,8 @@ def load_scores(
             # all other cases: recompute scores
             missing_metrics.setdefault(region, []).append(metric)
             continue
-        computable_metrics = self.check_computability(missing_metrics)
-        return local_scores, computable_metrics
+        recomputable_missing_metrics = self.get_recomputable_metrics(missing_metrics)
+        return local_scores, recomputable_missing_metrics
 
     def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
         """
@@ -217,7 +217,7 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr
             score = None
         return score
 
-    def check_computability(self, metrics):
+    def get_recomputable_metrics(self, metrics):
         """determine whether given metrics can be re-computed."""
         return metrics
 
@@ -296,7 +296,7 @@ def get_data(self, *args, **kwargs):
         # it can still happen when a particular score was available for a different channel
         raise ValueError(f"Missing JSON data for run {self.run_id}.")
 
-    def check_computability(self, metrics):
+    def get_recomputable_metrics(self, metrics):
         _logger.info(
             f"The following metrics have not yet been computed:{metrics}. Use type: zarr for that." 
) diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 3eeaae1df..6b18ede61 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -221,10 +221,6 @@ def _process_stream( type_ = run.get("type", "zarr") reader = get_reader(type_, run, run_id, private_paths, regions, metrics) - if plot_score_maps and type_ != "zarr": - _logger.info("score maps skipped, use type:zarr for that.") - plot_score_maps = False - stream_dict = reader.get_stream(stream) if not stream_dict: return run_id, stream, {} @@ -237,16 +233,18 @@ def _process_stream( if not stream_dict.get("evaluation"): return run_id, stream, {} - stream_loaded_scores, computable_metrics = reader.load_scores( + stream_loaded_scores, recomputable_metrics = reader.load_scores( stream, regions, metrics, ) scores_dict = stream_loaded_scores - if computable_metrics or plot_score_maps: - regions_to_compute = list(set(computable_metrics.keys())) if computable_metrics else regions - metrics_to_compute = computable_metrics if computable_metrics else metrics + if recomputable_metrics or (plot_score_maps and type_ == "zarr"): + regions_to_compute = ( + list(set(recomputable_metrics.keys())) if recomputable_metrics else regions + ) + metrics_to_compute = recomputable_metrics if recomputable_metrics else metrics stream_computed_scores = calc_scores_per_stream( reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps From 19b51c62ba2deee7127b4f85b644a5bd384379cb Mon Sep 17 00:00:00 2001 From: buschow1 Date: Tue, 3 Feb 2026 14:34:04 +0100 Subject: [PATCH 13/13] first attempt passing parameters to individual metrics --- .../weathergen/evaluate/io/wegen_reader.py | 22 +++++++++----- .../src/weathergen/evaluate/run_evaluation.py | 18 ++++++----- .../src/weathergen/evaluate/scores/score.py | 9 +++--- .../src/weathergen/evaluate/utils/utils.py | 30 +++++++++++++++---- 4 files changed, 55 insertions(+), 24 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 062031273..e75f85597 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -151,7 +151,7 @@ def get_channels(self, stream: str) -> list[str]: return all_channels def load_scores( - self, stream: str, regions: list[str], metrics: list[str] + self, stream: str, regions: list[str], metrics: list[str], metric_parameters: dict = {} ) -> xr.DataArray | None: """ Load multiple pre-computed scores for a given run, stream and metric and epoch. 
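A note before the next hunk: `metric_parameters` maps a metric name to keyword arguments that are forwarded to the metric function and recorded as attrs on the stored score. A minimal sketch of a config carrying it, assuming an OmegaConf-style evaluation section — the `fss` entry and its `window`/`threshold` keys are illustrative, not real options:

    from omegaconf import OmegaConf

    # Hypothetical evaluation config; metric_parameters is the new field.
    cfg = OmegaConf.create(
        {
            "evaluation": {
                "metrics": ["rmse", "fss"],
                "metric_parameters": {"fss": {"window": 9, "threshold": 1.0}},
            }
        }
    )
    metric_parameters = cfg.evaluation.get("metric_parameters", {})  # {} when absent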
@@ -180,7 +180,8 @@ def load_scores(
         missing_metrics = {}
         for region in regions:
             for metric in metrics:
-                score = self.load_single_score(stream, region, metric)
+                parameters = metric_parameters.get(metric, {})
+                score = self.load_single_score(stream, region, metric, parameters)
                 if score is not None:
                     available_data = self.check_availability(stream, score, mode="evaluation")
                     if available_data.score_availability:
@@ -200,7 +201,8 @@ def load_scores(
         recomputable_missing_metrics = self.get_recomputable_metrics(missing_metrics)
         return local_scores, recomputable_missing_metrics
 
-    def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
+    def load_single_score(self, stream: str, region: str, metric: str, parameters: dict = {}
+    ) -> xr.DataArray | None:
         """
         Load a single pre-computed score for a given run, stream and metric
         """
@@ -209,12 +211,16 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr
         / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json"
         )
         _logger.debug(f"Looking for: {score_path}")
+        score = None
         if score_path.exists():
            with open(score_path) as f:
                data_dict = json.load(f)
-            score = xr.DataArray.from_dict(data_dict)
-        else:
-            score = None
+            if "scores" not in data_dict:
+                data_dict = {"scores": [data_dict]}
+            for score_version in data_dict["scores"]:
+                if score_version["attrs"] == parameters:
+                    score = xr.DataArray.from_dict(score_version)
+                    break
         return score
@@ -253,6 +259,7 @@ def __init__(
         private_paths: dict | None = None,
         regions: list[str] | None = None,
         metrics: list[str] | None = None,
+        metric_parameters: dict = {},
     ):
         super().__init__(eval_cfg, run_id, private_paths)
         # goes looking for the coordinates available for all streams, regions, metrics
@@ -265,7 +272,8 @@ def __init__(
         for stream in streams:
             for region in regions:
                 for metric in metrics:
-                    score = self.load_single_score(stream, region, metric)
+                    parameters = metric_parameters.get(metric, {})
+                    score = self.load_single_score(stream, region, metric, parameters)
                     if score is not None:
                         for name in coord_names:
                             vals = set(score[name].values)
diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
index 98326444f..a1f3819c6 100755
--- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
+++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -164,13 +164,14 @@ def get_reader(
     private_paths: dict[str, str],
     region: str | None = None,
     metric: str | None = None,
+    metric_parameters: dict = {},
 ):
     if reader_type == "zarr":
         reader = WeatherGenZarrReader(run, run_id, private_paths)
     elif reader_type == "csv":
         reader = CsvReader(run, run_id, private_paths)
     elif reader_type == "json":
-        reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric)
+        reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric, metric_parameters)
     elif reader_type == "merge":
         reader = WeatherGenMergeReader(run, run_id, private_paths)
     else:
         raise ValueError(f"Unknown reader type: {reader_type}")
@@ -193,6 +194,7 @@ def _process_stream(
     regions: list[str],
     metrics: list[str],
     plot_score_maps: bool,
+    metric_parameters: dict[str, object],
 ) -> tuple[str, str, dict[str, dict[str, dict[str, float]]]]:
     """
     Worker function for a single stream of a single run. 
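Downstream, the per-metric dict ends up splatted into the metric function in score.py (`f(**args, **parameters)`). A self-contained sketch of that forwarding pattern, with a toy metric that is an assumption rather than the project's implementation:

    import xarray as xr

    def exceedance_fraction(target: xr.DataArray, threshold: float = 0.0) -> xr.DataArray:
        # Toy metric: fraction of points where the target exceeds a threshold.
        return (target > threshold).mean("ipoint")

    def apply_metric(f, args: dict, parameters: dict) -> xr.DataArray:
        # Same call shape as in the patch: per-metric parameters become kwargs.
        return f(**args, **parameters)

    da = xr.DataArray([0.2, 1.5, 2.0], dims="ipoint")
    result = apply_metric(exceedance_fraction, {"target": da}, {"threshold": 1.0})
    print(float(result))  # 0.666...: two of three points exceed the threshold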
@@ -219,7 +221,7 @@ def _process_stream(
     """
 
     type_ = run.get("type", "zarr")
-    reader = get_reader(type_, run, run_id, private_paths, regions, metrics)
+    reader = get_reader(type_, run, run_id, private_paths, regions, metrics, metric_parameters)
 
     stream_dict = reader.get_stream(stream)
     if not stream_dict:
@@ -237,6 +239,7 @@ def _process_stream(
         stream,
         regions,
         metrics,
+        metric_parameters,
     )
     scores_dict = stream_loaded_scores
 
@@ -247,9 +250,9 @@ def _process_stream(
         metrics_to_compute = recomputable_metrics if recomputable_metrics else metrics
 
         stream_computed_scores = calc_scores_per_stream(
-            reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps
+            reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps, metric_parameters
         )
-        metric_list_to_json(reader, stream, stream_computed_scores, regions)
+        metric_list_to_json(reader, stream, stream_computed_scores, regions, metric_parameters)
         scores_dict = merge(stream_loaded_scores, stream_computed_scores)
 
     return run_id, stream, scores_dict
@@ -276,6 +279,7 @@ def evaluate_from_config(
     private_paths = cfg.get("private_paths")
     summary_dir = Path(cfg.evaluation.get("summary_dir", _DEFAULT_PLOT_DIR))
    metrics = cfg.evaluation.metrics
+    metric_parameters = cfg.evaluation.get("metric_parameters", {})
     regions = cfg.evaluation.get("regions", ["global"])
     plot_score_maps = cfg.evaluation.get("plot_score_maps", False)
     global_plotting_opts = cfg.get("global_plotting_options", {})
@@ -308,10 +312,7 @@ def evaluate_from_config(
         if "streams" not in run:
             run["streams"] = default_streams
 
-        regions = cfg.evaluation.regions
-        metrics = cfg.evaluation.metrics
-
-        reader = get_reader(type_, run, run_id, private_paths, regions, metrics)
+        reader = get_reader(type_, run, run_id, private_paths, regions, metrics, metric_parameters)
 
         for stream in reader.streams:
             tasks.append(
@@ -324,6 +325,7 @@ def evaluate_from_config(
                     "regions": regions,
                     "metrics": metrics,
                     "plot_score_maps": plot_score_maps,
+                    "metric_parameters": metric_parameters,
                 }
             )
 
diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py
index 3099b0bd4..770a445f2 100755
--- a/packages/evaluate/src/weathergen/evaluate/scores/score.py
+++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py
@@ -107,6 +107,7 @@ def get_score(
     group_by_coord: str | None = None,
     ens_dim: str = "ens",
     compute: bool = False,
+    parameters: dict = {},
     **kwargs,
 ) -> xr.DataArray:
     """
@@ -136,8 +137,7 @@ def get_score(
         Calculated score as an xarray DataArray.
""" sc = Scores(agg_dims=agg_dims, ens_dim=ens_dim) - - score_data = sc.get_score(data, score_name, group_by_coord, **kwargs) + score_data = sc.get_score(data, score_name, group_by_coord, parameters=parameters, **kwargs) if compute: # If compute is True, compute the score immediately return score_data.compute() @@ -204,6 +204,7 @@ def get_score( score_name: str, group_by_coord: str | None = None, compute: bool = False, + parameters: dict = {}, **kwargs, ): """ @@ -309,14 +310,14 @@ def get_score( group_slice = { k: (v[name] if v is not None else v) for k, v in grouped_args.items() } - res = f(**group_slice) + res = f(**group_slice, **parameters) # Add coordinate for concatenation res = res.expand_dims({group_by_coord: [name]}) results.append(res) result = xr.concat(results, dim=group_by_coord) else: # No grouping: just call the function - result = f(**args) + result = f(**args, **parameters) if compute: return result.compute() diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index d911a370b..8a30065c7 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -60,6 +60,7 @@ def calc_scores_per_stream( regions: list[str], metrics_dict: dict, plot_score_maps: bool = False, + metric_parameters: dict[str, object] = {}, ): """ Calculate scores for a given run and stream using the specified metrics. @@ -170,8 +171,9 @@ def calc_scores_per_stream( valid_scores = [] for metric in metrics: + parameters = metric_parameters.get(metric, {}) score = get_score( - score_data, metric, agg_dims="ipoint", group_by_coord=group_by_coord + score_data, metric, agg_dims="ipoint", group_by_coord=group_by_coord, parameters=parameters ) if score is not None: valid_scores.append(score) @@ -228,9 +230,10 @@ def calc_scores_per_stream( # Build local dictionary for this region for metric in metrics: + metric_data = metric_stream.sel({"metric": metric}).assign_attrs( metric_parameters.get(metric,{}) ) local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault(stream, {})[ reader.run_id - ] = metric_stream.sel({"metric": metric}) + ] = metric_data return local_scores @@ -452,6 +455,7 @@ def metric_list_to_json( stream: str, metrics_dict: list[xr.DataArray], regions: list[str], + metric_parameters: dict, ): """ Write the evaluation results collected in a list of xarray DataArrays for the metrics @@ -480,6 +484,7 @@ def metric_list_to_json( reader.metrics_dir.mkdir(parents=True, exist_ok=True) for metric, metric_stream in metrics_dict.items(): + parameters = metric_parameters.get(metric, {}) for region in regions: metric_now = metric_stream[region][stream] for run_id in metric_now.keys(): @@ -490,10 +495,25 @@ def metric_list_to_json( reader.metrics_dir / f"{run_id}_{stream}_{region}_{metric}_chkpt{reader.mini_epoch:05d}.json" ) - - _logger.info(f"Saving results to {save_path}") + if save_path.exists(): + _logger.info(f"{save_path} already present") + with save_path.open("r") as f: + data_dict = json.load(f) + if not "scores" in data_dict: + data_dict = { "scores": [data_dict] } + for i,score_version in enumerate(data_dict["scores"]): + if score_version["attrs"] == parameters: + _logger.warning(f"metric with same parameters found, replacing") + data_dict["scores"][i] = metric_now.to_dict() + break + else: + data_dict["scores"].append(metric_now.to_dict()) + _logger.info(f"Appending results to {save_path}") + else: + data_dict = 
{"scores":[metric_now.to_dict()]} + _logger.info(f"Saving results to new file {save_path}") with open(save_path, "w") as f: - json.dump(metric_now.to_dict(), f, indent=4) + json.dump(data_dict, f, indent=4) _logger.info( f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} "