diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
index 062031273..e75f85597 100644
--- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
+++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
@@ -151,7 +151,7 @@ def get_channels(self, stream: str) -> list[str]:
         return all_channels
 
     def load_scores(
-        self, stream: str, regions: list[str], metrics: list[str]
+        self, stream: str, regions: list[str], metrics: list[str], metric_parameters: dict = {}
     ) -> xr.DataArray | None:
         """
         Load multiple pre-computed scores for a given run, stream and metric and epoch.
@@ -180,7 +180,8 @@ def load_scores(
         missing_metrics = {}
         for region in regions:
             for metric in metrics:
-                score = self.load_single_score(stream, region, metric)
+                parameters = metric_parameters.get(metric, {})
+                score = self.load_single_score(stream, region, metric, parameters)
                 if score is not None:
                     available_data = self.check_availability(stream, score, mode="evaluation")
                     if available_data.score_availability:
@@ -200,7 +201,8 @@ def load_scores(
         recomputable_missing_metrics = self.get_recomputable_metrics(missing_metrics)
         return local_scores, recomputable_missing_metrics
 
-    def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
+    def load_single_score(self, stream: str, region: str, metric: str, parameters: dict = {}
+    ) -> xr.DataArray | None:
         """
         Load a single pre-computed score for a given run, stream and metric
         """
@@ -209,12 +211,16 @@ def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArr
             / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json"
         )
         _logger.debug(f"Looking for: {score_path}")
+        score = None
         if score_path.exists():
             with open(score_path) as f:
                 data_dict = json.load(f)
-            score = xr.DataArray.from_dict(data_dict)
-        else:
-            score = None
+            if "scores" not in data_dict:
+                data_dict = {"scores": [data_dict]}
+            for score_version in data_dict["scores"]:
+                if score_version["attrs"] == parameters:
+                    score = xr.DataArray.from_dict(score_version)
+                    break
         return score
 
     def get_recomputable_metrics(self, metrics):
@@ -253,6 +259,7 @@ def __init__(
         private_paths: dict | None = None,
         regions: list[str] | None = None,
         metrics: list[str] | None = None,
+        metric_parameters: dict = {},
     ):
         super().__init__(eval_cfg, run_id, private_paths)
         # goes looking for the coordinates available for all streams, regions, metrics
@@ -265,7 +272,8 @@ def __init__(
         for stream in streams:
             for region in regions:
                 for metric in metrics:
-                    score = self.load_single_score(stream, region, metric)
+                    parameters = metric_parameters.get(metric, {})
+                    score = self.load_single_score(stream, region, metric, parameters)
                     if score is not None:
                         for name in coord_names:
                             vals = set(score[name].values)
diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
index 98326444f..a1f3819c6 100755
--- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
+++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -164,13 +164,14 @@ def get_reader(
     private_paths: dict[str, str],
     region: str | None = None,
     metric: str | None = None,
+    metric_parameters: dict = {},
 ):
     if reader_type == "zarr":
         reader = WeatherGenZarrReader(run, run_id, private_paths)
     elif reader_type == "csv":
         reader = CsvReader(run, run_id, private_paths)
     elif reader_type == "json":
-        reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric)
+        reader = WeatherGenJSONReader(run, run_id, private_paths, region, metric, metric_parameters)
     elif reader_type == "merge":
         reader = WeatherGenMergeReader(run, run_id, private_paths)
     else:
@@ -193,6 +194,7 @@ def _process_stream(
     regions: list[str],
     metrics: list[str],
     plot_score_maps: bool,
+    metric_parameters: dict[str, object],
 ) -> tuple[str, str, dict[str, dict[str, dict[str, float]]]]:
     """
     Worker function for a single stream of a single run.
@@ -219,7 +221,7 @@
     """
 
     type_ = run.get("type", "zarr")
-    reader = get_reader(type_, run, run_id, private_paths, regions, metrics)
+    reader = get_reader(type_, run, run_id, private_paths, regions, metrics, metric_parameters)
 
     stream_dict = reader.get_stream(stream)
     if not stream_dict:
@@ -237,6 +239,7 @@
             stream,
             regions,
             metrics,
+            metric_parameters,
         )
 
         scores_dict = stream_loaded_scores
@@ -247,9 +250,9 @@
         metrics_to_compute = recomputable_metrics if recomputable_metrics else metrics
 
         stream_computed_scores = calc_scores_per_stream(
-            reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps
+            reader, stream, regions_to_compute, metrics_to_compute, plot_score_maps, metric_parameters
         )
-        metric_list_to_json(reader, stream, stream_computed_scores, regions)
+        metric_list_to_json(reader, stream, stream_computed_scores, regions, metric_parameters)
         scores_dict = merge(stream_loaded_scores, stream_computed_scores)
 
     return run_id, stream, scores_dict
@@ -276,6 +279,7 @@ def evaluate_from_config(
     private_paths = cfg.get("private_paths")
     summary_dir = Path(cfg.evaluation.get("summary_dir", _DEFAULT_PLOT_DIR))
     metrics = cfg.evaluation.metrics
+    metric_parameters = cfg.evaluation.get("metric_parameters", {})
    regions = cfg.evaluation.get("regions", ["global"])
     plot_score_maps = cfg.evaluation.get("plot_score_maps", False)
     global_plotting_opts = cfg.get("global_plotting_options", {})
@@ -308,10 +312,7 @@
         if "streams" not in run:
             run["streams"] = default_streams
 
-        regions = cfg.evaluation.regions
-        metrics = cfg.evaluation.metrics
-
-        reader = get_reader(type_, run, run_id, private_paths, regions, metrics)
+        reader = get_reader(type_, run, run_id, private_paths, regions, metrics, metric_parameters)
 
         for stream in reader.streams:
             tasks.append(
@@ -324,6 +325,7 @@
                     "regions": regions,
                     "metrics": metrics,
                     "plot_score_maps": plot_score_maps,
+                    "metric_parameters": metric_parameters,
                 }
             )
 
diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py
index 3099b0bd4..770a445f2 100755
--- a/packages/evaluate/src/weathergen/evaluate/scores/score.py
+++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py
@@ -107,6 +107,7 @@ def get_score(
     group_by_coord: str | None = None,
     ens_dim: str = "ens",
     compute: bool = False,
+    parameters: dict = {},
     **kwargs,
 ) -> xr.DataArray:
     """
@@ -136,8 +137,7 @@ def get_score(
         Calculated score as an xarray DataArray.
""" sc = Scores(agg_dims=agg_dims, ens_dim=ens_dim) - - score_data = sc.get_score(data, score_name, group_by_coord, **kwargs) + score_data = sc.get_score(data, score_name, group_by_coord, parameters=parameters, **kwargs) if compute: # If compute is True, compute the score immediately return score_data.compute() @@ -204,6 +204,7 @@ def get_score( score_name: str, group_by_coord: str | None = None, compute: bool = False, + parameters: dict = {}, **kwargs, ): """ @@ -309,14 +310,14 @@ def get_score( group_slice = { k: (v[name] if v is not None else v) for k, v in grouped_args.items() } - res = f(**group_slice) + res = f(**group_slice, **parameters) # Add coordinate for concatenation res = res.expand_dims({group_by_coord: [name]}) results.append(res) result = xr.concat(results, dim=group_by_coord) else: # No grouping: just call the function - result = f(**args) + result = f(**args, **parameters) if compute: return result.compute() diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index d911a370b..8a30065c7 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -60,6 +60,7 @@ def calc_scores_per_stream( regions: list[str], metrics_dict: dict, plot_score_maps: bool = False, + metric_parameters: dict[str, object] = {}, ): """ Calculate scores for a given run and stream using the specified metrics. @@ -170,8 +171,9 @@ def calc_scores_per_stream( valid_scores = [] for metric in metrics: + parameters = metric_parameters.get(metric, {}) score = get_score( - score_data, metric, agg_dims="ipoint", group_by_coord=group_by_coord + score_data, metric, agg_dims="ipoint", group_by_coord=group_by_coord, parameters=parameters ) if score is not None: valid_scores.append(score) @@ -228,9 +230,10 @@ def calc_scores_per_stream( # Build local dictionary for this region for metric in metrics: + metric_data = metric_stream.sel({"metric": metric}).assign_attrs( metric_parameters.get(metric,{}) ) local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault(stream, {})[ reader.run_id - ] = metric_stream.sel({"metric": metric}) + ] = metric_data return local_scores @@ -452,6 +455,7 @@ def metric_list_to_json( stream: str, metrics_dict: list[xr.DataArray], regions: list[str], + metric_parameters: dict, ): """ Write the evaluation results collected in a list of xarray DataArrays for the metrics @@ -480,6 +484,7 @@ def metric_list_to_json( reader.metrics_dir.mkdir(parents=True, exist_ok=True) for metric, metric_stream in metrics_dict.items(): + parameters = metric_parameters.get(metric, {}) for region in regions: metric_now = metric_stream[region][stream] for run_id in metric_now.keys(): @@ -490,10 +495,25 @@ def metric_list_to_json( reader.metrics_dir / f"{run_id}_{stream}_{region}_{metric}_chkpt{reader.mini_epoch:05d}.json" ) - - _logger.info(f"Saving results to {save_path}") + if save_path.exists(): + _logger.info(f"{save_path} already present") + with save_path.open("r") as f: + data_dict = json.load(f) + if not "scores" in data_dict: + data_dict = { "scores": [data_dict] } + for i,score_version in enumerate(data_dict["scores"]): + if score_version["attrs"] == parameters: + _logger.warning(f"metric with same parameters found, replacing") + data_dict["scores"][i] = metric_now.to_dict() + break + else: + data_dict["scores"].append(metric_now.to_dict()) + _logger.info(f"Appending results to {save_path}") + else: + data_dict = 
{"scores":[metric_now.to_dict()]} + _logger.info(f"Saving results to new file {save_path}") with open(save_path, "w") as f: - json.dump(metric_now.to_dict(), f, indent=4) + json.dump(data_dict, f, indent=4) _logger.info( f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} "