Draft
Changes from all commits (23 commits)
79aba95 - split WeatherGenReader functionality to allow reading only JSON (Dec 4, 2025)
7131983 - informative error when metrics are not there (Dec 4, 2025)
4831b40 - Merge branch 'develop' into json_reader (s6sebusc, Jan 7, 2026)
45dfaf2 - restore JSONreader after rebase (Jan 7, 2026)
28db658 - JSONreader mostly restored (Jan 7, 2026)
0a79833 - Merge branch 'develop' into json_reader (s6sebusc, Jan 7, 2026)
8fd16e9 - MLFlow logging independent of JSON/zarr (Jan 7, 2026)
305c63c - Merge branch 'develop' into json_reader (iluise, Jan 8, 2026)
2863b2f - Merge branch 'develop' into json_reader (s6sebusc, Jan 8, 2026)
7cbcdba - linting, properly cheking fsteps, ens, samples in JSONreader (Jan 9, 2026)
c08f0cc - Merge branch 'develop' into json_reader (s6sebusc, Jan 12, 2026)
c06a653 - Merge branch 'develop' into json_reader (s6sebusc, Jan 13, 2026)
90bfe4b - tiny change to restore the MergeReader (Jan 13, 2026)
816e8cd - lint (iluise, Jan 14, 2026)
8645bd9 - adding upstream changes to zarr file handling (Jan 19, 2026)
6fe6d49 - enabling JSONreader to skip plots and missing scores gracefully (Jan 19, 2026)
969ac63 - required reformatting (Jan 19, 2026)
9e009cd - move skipping of metrics to the reader class (Jan 20, 2026)
f7fd406 - slighly more explicit formulations (Jan 20, 2026)
aca608a - Merge branch 'develop' into json_reader (s6sebusc, Jan 20, 2026)
2c88a5e - Merge remote-tracking branch 'upstream/develop' into json_reader (Feb 3, 2026)
19b51c6 - first attempt passing parameters to individual metrics (Feb 3, 2026)
5827a23 - Merge branch 'develop' into metric_parameters (s6sebusc, Feb 3, 2026)
22 changes: 15 additions & 7 deletions packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
@@ -151,7 +151,7 @@ def get_channels(self, stream: str) -> list[str]:
return all_channels

def load_scores(
self, stream: str, regions: list[str], metrics: list[str]
self, stream: str, regions: list[str], metrics: list[str], metric_parameters: dict = {}
) -> xr.DataArray | None:
"""
Load multiple pre-computed scores for a given run, stream and metric and epoch.
@@ -180,7 +180,8 @@ def load_scores(
missing_metrics = {}
for region in regions:
for metric in metrics:
score = self.load_single_score(stream, region, metric)
parameters = metric_parameters.get(metric,{})
score = self.load_single_score(stream, region, metric, parameters)
if score is not None:
available_data = self.check_availability(stream, score, mode="evaluation")
if available_data.score_availability:
@@ -200,7 +201,8 @@ def load_scores(
recomputable_missing_metrics = self.get_recomputable_metrics(missing_metrics)
return local_scores, recomputable_missing_metrics

def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
def load_single_score(self, stream: str, region: str, metric: str, parameters: dict = {}
) -> xr.DataArray | None:
"""
Load a single pre-computed score for a given run, stream and metric
"""
@@ -209,12 +211,16 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr
/ f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")
score = None
if score_path.exists():
with open(score_path) as f:
data_dict = json.load(f)
score = xr.DataArray.from_dict(data_dict)
else:
score = None
if not "scores" in data_dict:
data_dict = { "scores": [data_dict] }
for score_version in data_dict["scores"]:
if score_version["attrs"] == parameters:
score = xr.DataArray.from_dict(score_version)
break
return score

def get_recomputable_metrics(self, metrics):
@@ -253,6 +259,7 @@ def __init__(
private_paths: dict | None = None,
regions: list[str] | None = None,
metrics: list[str] | None = None,
metric_parameters: dict = {}
):
super().__init__(eval_cfg, run_id, private_paths)
# goes looking for the coordinates available for all streams, regions, metrics
@@ -265,7 +272,8 @@
for stream in streams:
for region in regions:
for metric in metrics:
score = self.load_single_score(stream, region, metric)
parameters = metric_parameters.get(metric, {})
score = self.load_single_score(stream, region, metric, parameters)
if score is not None:
for name in coord_names:
vals = set(score[name].values)
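For orientation, a minimal sketch of the versioned score JSON that `load_single_score` now reads: legacy files contain a single `xr.DataArray.to_dict()` payload, newer files wrap one or more such payloads in a `scores` list, and the entry whose `attrs` equal the requested parameters is returned. The `quantile` key and the `fstep` dimension below are illustrative assumptions, not part of the PR.

```python
import xarray as xr

# Assumed example content of <run>_<stream>_<region>_<metric>_chkpt00000.json:
# each entry in "scores" is xr.DataArray.to_dict() output whose "attrs" record
# the metric parameters it was computed with. Dims and values are illustrative.
data_dict = {
    "scores": [
        xr.DataArray([0.10, 0.12], dims="fstep", attrs={}).to_dict(),
        xr.DataArray([0.30, 0.35], dims="fstep", attrs={"quantile": 0.9}).to_dict(),
    ]
}

# Selection logic as in load_single_score: match attrs against the requested parameters.
parameters = {"quantile": 0.9}
score = None
for score_version in data_dict["scores"]:
    if score_version["attrs"] == parameters:
        score = xr.DataArray.from_dict(score_version)
        break
print(score)
```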
18 changes: 10 additions & 8 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -164,13 +164,14 @@ def get_reader(
private_paths: dict[str, str],
region: str | None = None,
metric: str | None = None,
metric_parameters: dict = {}
):
if reader_type == "zarr":
reader = WeatherGenZarrReader(run, run_id, private_paths)
elif reader_type == "csv":
reader = CsvReader(run, run_id, private_paths)
elif reader_type == "json":
reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric)
reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric, metric_parameters)
elif reader_type == "merge":
reader = WeatherGenMergeReader(run, run_id, private_paths)
else:
@@ -193,6 +194,7 @@ def _process_stream(
regions: list[str],
metrics: list[str],
plot_score_maps: bool,
metric_parameters: dict[str, object],
) -> tuple[str, str, dict[str, dict[str, dict[str, float]]]]:
"""
Worker function for a single stream of a single run.
@@ -219,7 +221,7 @@
"""

type_ = run.get("type", "zarr")
reader = get_reader(type_, run, run_id, private_paths, regions, metrics)
reader = get_reader(type_, run, run_id, private_paths, regions, metrics, metric_parameters)

stream_dict = reader.get_stream(stream)
if not stream_dict:
@@ -237,6 +239,7 @@
stream,
regions,
metrics,
metric_parameters
)
scores_dict = stream_loaded_scores

@@ -247,9 +250,9 @@
metrics_to_compute = recomputable_metrics if recomputable_metrics else metrics

stream_computed_scores = calc_scores_per_stream(
reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps
reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps, metric_parameters
)
metric_list_to_json(reader, stream, stream_computed_scores, regions)
metric_list_to_json(reader, stream, stream_computed_scores, regions, metric_parameters)
scores_dict = merge(stream_loaded_scores, stream_computed_scores)

return run_id, stream, scores_dict
@@ -276,6 +279,7 @@ def evaluate_from_config(
private_paths = cfg.get("private_paths")
summary_dir = Path(cfg.evaluation.get("summary_dir", _DEFAULT_PLOT_DIR))
metrics = cfg.evaluation.metrics
metric_parameters = cfg.evaluation.get("metric_parameters", {})
regions = cfg.evaluation.get("regions", ["global"])
plot_score_maps = cfg.evaluation.get("plot_score_maps", False)
global_plotting_opts = cfg.get("global_plotting_options", {})
@@ -308,10 +312,7 @@
if "streams" not in run:
run["streams"] = default_streams

regions = cfg.evaluation.regions
metrics = cfg.evaluation.metrics

reader = get_reader(type_, run, run_id, private_paths, regions, metrics)
reader = get_reader(type_, run, run_id, private_paths, regions, metrics, metric_parameters)

for stream in reader.streams:
tasks.append(
@@ -324,6 +325,7 @@
"regions": regions,
"metrics": metrics,
"plot_score_maps": plot_score_maps,
"metric_parameters": metric_parameters
}
)

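As a sketch of how the new `metric_parameters` option might be supplied, the fragment below uses a plain dict in place of the project's config object; the metric names and the `quantile` key are placeholders rather than the actual configuration schema.

```python
# Hypothetical evaluation config written as a plain dict; entries under
# "metric_parameters" are the per-metric keyword arguments forwarded to the readers.
evaluation = {
    "metrics": ["rmse", "quantile_score"],      # assumed metric names
    "regions": ["global"],
    "metric_parameters": {
        "quantile_score": {"quantile": 0.9},    # assumed parameter
    },
}

metric_parameters = evaluation.get("metric_parameters", {})
for metric in evaluation["metrics"]:
    # Mirrors parameters = metric_parameters.get(metric, {}) in the code above:
    # metrics without an entry fall back to an empty parameter dict.
    print(metric, metric_parameters.get(metric, {}))
```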
9 changes: 5 additions & 4 deletions packages/evaluate/src/weathergen/evaluate/scores/score.py
@@ -107,6 +107,7 @@ def get_score(
group_by_coord: str | None = None,
ens_dim: str = "ens",
compute: bool = False,
parameters: dict = {},
**kwargs,
) -> xr.DataArray:
"""
@@ -136,8 +137,7 @@
Calculated score as an xarray DataArray.
"""
sc = Scores(agg_dims=agg_dims, ens_dim=ens_dim)

score_data = sc.get_score(data, score_name, group_by_coord, **kwargs)
score_data = sc.get_score(data, score_name, group_by_coord, parameters=parameters, **kwargs)
if compute:
# If compute is True, compute the score immediately
return score_data.compute()
@@ -204,6 +204,7 @@ def get_score(
score_name: str,
group_by_coord: str | None = None,
compute: bool = False,
parameters: dict = {},
**kwargs,
):
"""
@@ -309,14 +310,14 @@
group_slice = {
k: (v[name] if v is not None else v) for k, v in grouped_args.items()
}
res = f(**group_slice)
res = f(**group_slice, **parameters)
# Add coordinate for concatenation
res = res.expand_dims({group_by_coord: [name]})
results.append(res)
result = xr.concat(results, dim=group_by_coord)
else:
# No grouping: just call the function
result = f(**args)
result = f(**args, **parameters)

if compute:
return result.compute()
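To make the `f(**args, **parameters)` forwarding concrete, a small self-contained sketch follows; `pinball_loss` and its `quantile` argument are invented for illustration and are not the module's real score functions.

```python
import numpy as np
import xarray as xr

def pinball_loss(prediction: xr.DataArray, target: xr.DataArray, quantile: float = 0.5):
    """Toy score used only for this example."""
    diff = target - prediction
    return xr.where(diff >= 0, quantile * diff, (quantile - 1.0) * diff).mean("ipoint")

args = {
    "prediction": xr.DataArray(np.zeros(4), dims="ipoint"),
    "target": xr.DataArray([1.0, -1.0, 2.0, 0.5], dims="ipoint"),
}
parameters = {"quantile": 0.9}  # taken from metric_parameters for this metric

# Same call pattern as result = f(**args, **parameters) in the diff above.
result = pinball_loss(**args, **parameters)
print(float(result))
```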
30 changes: 25 additions & 5 deletions packages/evaluate/src/weathergen/evaluate/utils/utils.py
@@ -60,6 +60,7 @@ def calc_scores_per_stream(
regions: list[str],
metrics_dict: dict,
plot_score_maps: bool = False,
metric_parameters: dict[str, object] = {},
):
"""
Calculate scores for a given run and stream using the specified metrics.
@@ -170,8 +171,9 @@
valid_scores = []

for metric in metrics:
parameters = metric_parameters.get(metric, {})
score = get_score(
score_data, metric, agg_dims="ipoint", group_by_coord=group_by_coord
score_data, metric, agg_dims="ipoint", group_by_coord=group_by_coord, parameters=parameters
)
if score is not None:
valid_scores.append(score)
@@ -228,9 +230,10 @@

# Build local dictionary for this region
for metric in metrics:
metric_data = metric_stream.sel({"metric": metric}).assign_attrs( metric_parameters.get(metric,{}) )
local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault(stream, {})[
reader.run_id
] = metric_stream.sel({"metric": metric})
] = metric_data

return local_scores

@@ -452,6 +455,7 @@ def metric_list_to_json(
stream: str,
metrics_dict: list[xr.DataArray],
regions: list[str],
metric_parameters: dict,
):
"""
Write the evaluation results collected in a list of xarray DataArrays for the metrics
@@ -480,6 +484,7 @@
reader.metrics_dir.mkdir(parents=True, exist_ok=True)

for metric, metric_stream in metrics_dict.items():
parameters = metric_parameters.get(metric, {})
for region in regions:
metric_now = metric_stream[region][stream]
for run_id in metric_now.keys():
@@ -490,10 +495,25 @@
reader.metrics_dir
/ f"{run_id}_{stream}_{region}_{metric}_chkpt{reader.mini_epoch:05d}.json"
)

_logger.info(f"Saving results to {save_path}")
if save_path.exists():
_logger.info(f"{save_path} already present")
with save_path.open("r") as f:
data_dict = json.load(f)
if not "scores" in data_dict:
data_dict = { "scores": [data_dict] }
for i,score_version in enumerate(data_dict["scores"]):
if score_version["attrs"] == parameters:
_logger.warning(f"metric with same parameters found, replacing")
data_dict["scores"][i] = metric_now.to_dict()
break
else:
data_dict["scores"].append(metric_now.to_dict())
_logger.info(f"Appending results to {save_path}")
else:
data_dict = {"scores":[metric_now.to_dict()]}
_logger.info(f"Saving results to new file {save_path}")
with open(save_path, "w") as f:
json.dump(metric_now.to_dict(), f, indent=4)
json.dump(data_dict, f, indent=4)

_logger.info(
f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} "
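The write path above follows a replace-or-append rule for score versions; the helper below restates that rule in isolation, with invented names: a stored version whose `attrs` match the current parameters is overwritten, otherwise a new version is appended.

```python
def upsert_score_version(data_dict: dict, new_entry: dict, parameters: dict) -> dict:
    """Replace the score version whose attrs match `parameters`, else append it."""
    # Legacy files contain a single DataArray dict; wrap it in the versioned layout.
    if "scores" not in data_dict:
        data_dict = {"scores": [data_dict]}
    for i, score_version in enumerate(data_dict["scores"]):
        if score_version["attrs"] == parameters:
            data_dict["scores"][i] = new_entry   # same parameters: overwrite in place
            break
    else:
        data_dict["scores"].append(new_entry)    # no match: add a new version
    return data_dict
```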