diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 44ec60a3b..e3cdfefa6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -607,7 +607,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: image_paths += names if image_paths: - image_paths=sorted(image_paths) + image_paths = sorted(image_paths) images = [Image.open(path) for path in image_paths] images[0].save( f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{region}_{var}.gif", diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index ab322ab28..c5d4cf8d0 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -190,6 +190,7 @@ def __init__( "grad_amplitude": self.calc_spatial_variability, "psnr": self.calc_psnr, "seeps": self.calc_seeps, + "nse": self.calc_nse, } self.prob_metrics_dict = { "ssr": self.calc_ssr, @@ -1199,6 +1200,34 @@ def seeps(ground_truth, prediction, thr_light, thr_heavy, seeps_weights): return seeps_values + def calc_nse(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: + """ + Calculate Nash–Sutcliffe_model_efficiency_coefficient (NSE) + of forecast data vs reference data + Metrics broadly used in hydrology + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + Returns + ------- + xr.DataArray + Nash–Sutcliffe_model_efficiency_coefficient (NSE) + + """ + + obs_mean = gt.mean(dim=self._agg_dims) + + num = ((gt - p) ** 2).sum(dim=self._agg_dims) + + den = ((gt - obs_mean) ** 2).sum(dim=self._agg_dims) + + nse = 1 - num / den + + return nse + ### Probablistic scores def calc_spread(self, p: xr.DataArray, **kwargs) -> xr.DataArray: diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py index 1c7cba6a8..2bd9fe2a2 100644 --- a/src/weathergen/utils/cli.py +++ b/src/weathergen/utils/cli.py @@ -14,7 +14,7 @@ class Stage(enum.StrEnum): def get_main_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(allow_abbrev=False) subparsers = parser.add_subparsers(dest="stage") - + train_parser = subparsers.add_parser( Stage.train, help="Train a WeatherGenerator configuration from the ground up.",