diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index 83e0a7b84..407a73608 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -138,6 +138,25 @@ run_ids : sample: [0, 1, 2, 3] ensemble: "all" #supported + #WeatherGeneratorMerge JSON example: + ############################### + + #Example of syntax to stack multiple runs over the ensemble dimension + merge_test: #<------ This is the new run_id name of the merged dataset. NB. you always need to specify one + type: "jsonmerge" # <------- VERY IMPORTANT + merge_run_ids: + - so67dku4 + - c9cg8ql3 + metrics_dir: "./merge_test/metrics/" #<------- VERY IMPORTANT + label: "Merged Results" + results_base_dir : "./results/" + streams: + ERA5: + channels: ["z_500", "t_850", "u_850", "v_850", "q_850"] + evaluation: + forecast_step: [2,4,6] + sample: [0, 1, 2, 3] + ensemble: "all" # CSV Reader example: ############################### @@ -165,4 +184,4 @@ run_ids : forecast_step: "all" sample: "all" - + \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 062031273..4da20fd0c 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -29,6 +29,7 @@ from weathergen.evaluate.io.io_reader import Reader, ReaderOutput from weathergen.evaluate.scores.score_utils import to_list from weathergen.evaluate.utils.derived_channels import DeriveChannels +from weathergen.evaluate.utils.utils import merge _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -205,7 +206,8 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr Load a single pre-computed score for a given run, stream and metric """ score_path = ( - Path(self.metrics_dir) + Path(self.results_base_dir) + / "evaluation" / 
f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" ) _logger.debug(f"Looking for: {score_path}") @@ -691,20 +693,52 @@ def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: class WeatherGenMergeReader(Reader): - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): - """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" + def __init__( + self, + eval_cfg: dict, + run_id: str, + private_paths: dict | None = None, + regions: list[str] | None = None, + metrics: list[str] | None = None, + reader_type: str = "zarr", + ): + """ + Data reader class for merging WeatherGenerator model outputs stored in Zarr or JSON format. + Parameters + ---------- + eval_cfg: dict + config with plotting and evaluation options for that run id + run_id: str + run id of the model + private_paths: dict + dictionary of private paths for the supported HPC + regions: list[str] + names of predefined bounding box for a region + metrics: list[str] + names of the metric scores to compute + reader_type: str + The type of the internal reader. If zarr, WeatherGenZarrReader is used, + WeatherGenJSONReader otherwise. 
Default: zarr + """ + super().__init__(eval_cfg, run_id, private_paths) self.run_ids = eval_cfg.get("merge_run_ids", []) self.metrics_dir = Path(eval_cfg.get("metrics_dir")) - self.mini_epoch = eval_cfg.get("mini_epoch", eval_cfg.get("epoch")) + self.mini_epoch = eval_cfg.get("mini_epoch", 0) - super().__init__(eval_cfg, run_id, private_paths) self.readers = [] _logger.info(f"MERGE READERS: {self.run_ids} ...") + + for run_id in self.run_ids: - reader = WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths) + if reader_type == "zarr": + reader = WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths) + else: + reader = WeatherGenJSONReader( + self.eval_cfg, run_id, self.private_paths, regions, metrics + ) self.readers.append(reader) def get_data( @@ -813,7 +847,9 @@ def _concat_over_ens(self, da_merge, fsteps_merge): return da_ens - def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray | None: + def load_scores( + self, stream: str, regions: list[str], metrics: list[str] + ) -> xr.DataArray | None: """ Load the pre-computed scores for a given run, stream and metric and epoch. @@ -834,14 +870,36 @@ def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray | missing_metrics: dictionary of missing regions and metrics that need to be recomputed. """ - # TODO: implement this properly. Not it is skipping loading scores - local_scores = {} missing_metrics = {} - for region in regions: - for metric in metrics: - # all other cases: recompute scores - missing_metrics.setdefault(region, []).append(metric) + + if isinstance(self.readers[0], WeatherGenZarrReader): + # TODO: implement this properly. 
Now it is skipping loading scores + for region in regions: + for metric in metrics: + # all other cases: recompute scores + missing_metrics.setdefault(region, []).append(metric) + else: #JsonReader + # deep merge dicts + for reader in self.readers: + scores, missing = reader.load_scores(stream, regions, metrics) + merge(local_scores, scores) + merge(missing_metrics, missing) + + # merge runs into one with all scores concatenated + for metric in local_scores.keys(): + for region in local_scores[metric].keys(): + for stream in local_scores[metric][region].keys(): + scores = ( + local_scores[metric][region][stream].pop(run_id) + for run_id in self.run_ids + ) + local_scores[metric][region][stream].setdefault( + self.run_id, + xr.concat(scores, dim="ens").assign_coords( + ens=range(len(self.readers)) + ), + ) return local_scores, missing_metrics @@ -929,7 +987,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: for reader in self.readers: all_ensembles.append(reader.get_ensemble(stream)) - if all(e == ["0"] or e == [0] for e in all_ensembles): + if all(e == ["0"] or e == [0] or e == {0} for e in all_ensembles): return set(range(len(self.readers))) else: raise NotImplementedError( diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 98326444f..70dd18dd4 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -173,6 +173,10 @@ def get_reader( reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric) elif reader_type == "merge": reader = WeatherGenMergeReader(run, run_id, private_paths) + elif reader_type == "jsonmerge": + reader = WeatherGenMergeReader( + run, run_id, private_paths, region, metric, reader_type="json" + ) else: raise ValueError(f"Unknown reader type: {reader_type}") return reader diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py
b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index 0ef8f88ac..6ee00b1c0 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -532,22 +532,14 @@ def metric_list_to_json( Parameters ---------- - reader: + reader: Reader Reader object containing all info about the run_id. - metrics_list : + stream: str + Stream name. + metrics_dict: list Metrics per stream. - npoints_sample_list : - Number of points per sample per stream. - streams : - Stream names. - region : - Region name. - metric_dir : - Output directory. - run_id : - Identifier of the inference run. - mini_epoch : - Mini_epoch number. + regions: list + Region names. """ # stream_loaded_scores['rmse']['nhem']['ERA5']['jjqce6x5'] reader.metrics_dir.mkdir(parents=True, exist_ok=True)