21 changes: 20 additions & 1 deletion config/evaluate/eval_config.yml
@@ -138,6 +138,25 @@ run_ids :
sample: [0, 1, 2, 3]
ensemble: "all" #supported

# WeatherGenerator merge (JSON) example:
###############################

# Example of the syntax for stacking multiple runs along the ensemble dimension
merge_test: # <------ The new run_id name for the merged dataset. N.B.: you always need to specify one.
type: "jsonmerge" # <------- Required: selects the JSON merge reader.
merge_run_ids:
- so67dku4
- c9cg8ql3
metrics_dir: "./merge_test/metrics/" # <------- Required: directory where the merged metrics are written.
label: "Merged Results"
results_base_dir : "./results/"
streams:
ERA5:
channels: ["z_500", "t_850", "u_850", "v_850", "q_850"]
evaluation:
forecast_step: [2,4,6]
sample: [0, 1, 2, 3]
ensemble: "all"

# CSV Reader example:
###############################
@@ -165,4 +184,4 @@ run_ids :
forecast_step: "all"
sample: "all"



86 changes: 72 additions & 14 deletions packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
@@ -29,6 +29,7 @@
from weathergen.evaluate.io.io_reader import Reader, ReaderOutput
from weathergen.evaluate.scores.score_utils import to_list
from weathergen.evaluate.utils.derived_channels import DeriveChannels
from weathergen.evaluate.utils.utils import merge

_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
@@ -205,7 +206,8 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr
Load a single pre-computed score for a given run, stream and metric
"""
score_path = (
Path(self.metrics_dir)
Path(self.results_base_dir)
/ "evaluation"
/ f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")
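To illustrate the relocated score files, a standalone sketch of the filename the f-string above produces; all values here are placeholders.

```python
from pathlib import Path

# Placeholder values, purely to show the path pattern used by load_single_score.
results_base_dir = Path("./results/")
run_id, stream, region, metric, mini_epoch = "so67dku4", "ERA5", "global", "rmse", 0

score_path = (
    results_base_dir
    / "evaluation"
    / f"{run_id}_{stream}_{region}_{metric}_chkpt{mini_epoch:05d}.json"
)
print(score_path)  # results/evaluation/so67dku4_ERA5_global_rmse_chkpt00000.json
```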
@@ -691,20 +693,52 @@ def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray:


class WeatherGenMergeReader(Reader):
def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
"""Data reader class for WeatherGenerator model outputs stored in Zarr format."""
def __init__(
self,
eval_cfg: dict,
run_id: str,
private_paths: dict | None = None,
regions: list[str] | None = None,
metrics: list[str] | None = None,
reader_type: str = "zarr",
):
"""
Data reader class for merging WeatherGenerator model outputs stored in Zarr or JSON format.

Parameters
----------
eval_cfg: dict
config with plotting and evaluation options for that run id
run_id: str
run id of the model
private_paths: dict
dictionary of private paths for the supported HPC
regions: list[str]
names of the predefined bounding boxes (regions) to evaluate
metrics: list[str]
names of the metric scores to compute
reader_type: str
The type of the internal per-run readers. If "zarr", WeatherGenZarrReader is used;
otherwise WeatherGenJSONReader. Default: "zarr"
"""
super().__init__(eval_cfg, run_id, private_paths)
self.run_ids = eval_cfg.get("merge_run_ids", [])
self.metrics_dir = Path(eval_cfg.get("metrics_dir"))
self.mini_epoch = eval_cfg.get("mini_epoch", eval_cfg.get("epoch"))
self.mini_epoch = eval_cfg.get("mini_epoch", 0)

super().__init__(eval_cfg, run_id, private_paths)
self.readers = []

_logger.info(f"MERGE READERS: {self.run_ids} ...")



for run_id in self.run_ids:
reader = WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths)
if reader_type == "zarr":
reader = WeatherGenZarrReader(self.eval_cfg, run_id, self.private_paths)
else:
reader = WeatherGenJSONReader(
self.eval_cfg, run_id, self.private_paths, regions, metrics
)
self.readers.append(reader)
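A minimal usage sketch for the merge reader, assuming an `eval_cfg` shaped like the `merge_test` entry in `eval_config.yml`; the dictionary below is a stand-in and the base `Reader` may expect additional keys on a real config.

```python
from weathergen.evaluate.io.wegen_reader import WeatherGenMergeReader

# Stand-in config; in practice this comes from the run entry in eval_config.yml.
eval_cfg = {
    "type": "jsonmerge",
    "merge_run_ids": ["so67dku4", "c9cg8ql3"],
    "metrics_dir": "./merge_test/metrics/",
    "results_base_dir": "./results/",
}

reader = WeatherGenMergeReader(
    eval_cfg,
    run_id="merge_test",
    private_paths=None,
    regions=["global"],
    metrics=["rmse"],
    reader_type="json",  # anything other than "zarr" builds one WeatherGenJSONReader per run
)
# load_scores then deep-merges the per-run score dicts (requires the JSON score
# files of the source runs to exist on disk).
scores, missing = reader.load_scores("ERA5", regions=["global"], metrics=["rmse"])
```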

def get_data(
@@ -813,7 +847,9 @@ def _concat_over_ens(self, da_merge, fsteps_merge):

return da_ens

def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray | None:
def load_scores(
self, stream: str, regions: list[str], metrics: list[str]
) -> xr.DataArray | None:
"""
Load the pre-computed scores for a given run, stream, metric and epoch.

@@ -834,14 +870,36 @@ def load_scores(self, stream: str, regions: str, metrics: str) -> xr.DataArray |
missing_metrics:
dictionary of missing regions and metrics that need to be recomputed.
"""
# TODO: implement this properly. Not it is skipping loading scores

local_scores = {}
missing_metrics = {}
for region in regions:
for metric in metrics:
# all other cases: recompute scores
missing_metrics.setdefault(region, []).append(metric)

if isinstance(self.readers[0], WeatherGenZarrReader):
# TODO: implement this properly. For now, loading scores is skipped here and they are recomputed.
for region in regions:
for metric in metrics:
# all other cases: recompute scores
missing_metrics.setdefault(region, []).append(metric)
else:  # JSON reader
# deep merge dicts
for reader in self.readers:
scores, missing = reader.load_scores(stream, regions, metrics)
merge(local_scores, scores)
merge(missing_metrics, missing)

# merge runs into one with all scores concatenated
for metric in local_scores.keys():
for region in local_scores[metric].keys():
for stream in local_scores[metric][region].keys():
scores = (
local_scores[metric][region][stream].pop(run_id)
for run_id in self.run_ids
)
local_scores[metric][region][stream].setdefault(
self.run_id,
xr.concat(scores, dim="ens").assign_coords(
ens=range(len(self.readers))
),
)

return local_scores, missing_metrics
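To make the JSON branch above concrete, a self-contained sketch of the same two steps: deep-merge the per-run score dictionaries, then concatenate the per-run DataArrays along a new `ens` dimension under the merged run id. `deep_merge` below is a stand-in for the imported `merge` utility, whose exact semantics are assumed.

```python
import xarray as xr


def deep_merge(dst: dict, src: dict) -> dict:
    # Assumed behaviour of the imported `merge` utility: recursively fold src into dst.
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            deep_merge(dst[key], value)
        else:
            dst[key] = value
    return dst


# Two source runs, each contributing a single-member score per metric/region/stream.
run_scores = [
    {"rmse": {"global": {"ERA5": {"so67dku4": xr.DataArray([1.0, 1.2])}}}},
    {"rmse": {"global": {"ERA5": {"c9cg8ql3": xr.DataArray([0.9, 1.1])}}}},
]

merged: dict = {}
for scores in run_scores:
    deep_merge(merged, scores)

run_ids = ["so67dku4", "c9cg8ql3"]
for metric in merged:
    for region in merged[metric]:
        for stream_name in merged[metric][region]:
            members = [merged[metric][region][stream_name].pop(rid) for rid in run_ids]
            merged[metric][region][stream_name]["merge_test"] = xr.concat(
                members, dim="ens"
            ).assign_coords(ens=range(len(members)))
```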

@@ -929,7 +987,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
for reader in self.readers:
all_ensembles.append(reader.get_ensemble(stream))

if all(e == ["0"] or e == [0] for e in all_ensembles):
if all(e == ["0"] or e == [0] or e == {0} for e in all_ensembles):
return set(range(len(self.readers)))
else:
raise NotImplementedError(
4 changes: 4 additions & 0 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -173,6 +173,10 @@ def get_reader(
reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric)
elif reader_type == "merge":
reader = WeatherGenMergeReader(run, run_id, private_paths)
elif reader_type == "jsonmerge":
# This branch is selected via type: "jsonmerge" in the config.
reader = WeatherGenMergeReader(
run, run_id, private_paths, region, metric, reader_type="json"
)
else:
raise ValueError(f"Unknown reader type: {reader_type}")
return reader
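In short, the run's `type` field in the config selects the reader class; a compact summary of the dispatch (the "json" key here is inferred from the surrounding code, only "merge" and "jsonmerge" are visible in this hunk):

```python
# Informal summary of get_reader's dispatch on the config "type" field.
READER_FOR_TYPE = {
    "json": "WeatherGenJSONReader",
    "merge": "WeatherGenMergeReader with one WeatherGenZarrReader per merged run",
    "jsonmerge": "WeatherGenMergeReader with one WeatherGenJSONReader per merged run",
}
```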
20 changes: 6 additions & 14 deletions packages/evaluate/src/weathergen/evaluate/utils/utils.py
@@ -532,22 +532,14 @@ def metric_list_to_json(

Parameters
----------
reader:
reader: Reader
Reader object containing all info about the run_id.
metrics_list :
stream: str
Stream name.
metrics_dict: list
Metrics per stream.
npoints_sample_list :
Number of points per sample per stream.
streams :
Stream names.
region :
Region name.
metric_dir :
Output directory.
run_id :
Identifier of the inference run.
mini_epoch :
Mini_epoch number.
regions: list
Region names.
"""
# stream_loaded_scores['rmse']['nhem']['ERA5']['jjqce6x5']
reader.metrics_dir.mkdir(parents=True, exist_ok=True)
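The commented access path above documents the nesting of the in-memory scores (metric → region → stream → run_id). A minimal sketch of serialising one such entry to JSON follows; the filename pattern and the use of `DataArray.to_dict()` are assumptions for illustration, not necessarily what `metric_list_to_json` does.

```python
import json
from pathlib import Path

import xarray as xr

# Illustrative nesting: metric -> region -> stream -> run_id -> DataArray.
stream_loaded_scores = {
    "rmse": {"nhem": {"ERA5": {"jjqce6x5": xr.DataArray([1.0, 1.1], dims="forecast_step")}}}
}

metrics_dir = Path("./results/jjqce6x5/metrics")
metrics_dir.mkdir(parents=True, exist_ok=True)

for metric, per_region in stream_loaded_scores.items():
    for region, per_stream in per_region.items():
        for stream, per_run in per_stream.items():
            for run_id, da in per_run.items():
                out_file = metrics_dir / f"{run_id}_{stream}_{region}_{metric}.json"
                # DataArray.to_dict() yields a JSON-serialisable structure (dims, coords, data).
                out_file.write_text(json.dumps(da.to_dict()))
```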