From adf8c647f5683dc2f1c2e96d911c78b3d0191466 Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Thu, 16 Oct 2025 08:19:04 +0000 Subject: [PATCH 01/21] add optional dependency group with gribjump dependencies --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 490e424..287ab1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,10 @@ dependencies = [ "ruff", "pre-commit" ] + gribjump = [ + "earthkit-data[gribjump]", + "gribjumplib==0.10.3.dev20250908" + ] [project.scripts] hat-extract-timeseries = "hat.cli:extractor_cli" From 93c606a54decd51b6a164be1a36571111edb1e14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ois=C3=ADn=20Morrison?= Date: Thu, 16 Oct 2025 09:13:42 +0000 Subject: [PATCH 02/21] chore: min ekh version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 287ab1e..b590a4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ "ipyleaflet", "ipywidgets", "earthkit-data>=0.13.8", - "earthkit-hydro", + "earthkit-hydro>=1.0.0", "earthkit-meteo", "cfgrib", # check if necessary "netCDF4", # check if necessary From 8385bae054938a63b52fb414cf8a271fe5ec7c73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ois=C3=ADn=20Morrison?= Date: Thu, 16 Oct 2025 09:14:16 +0000 Subject: [PATCH 03/21] feat: add opt_1d_index to outputted station mapping df --- hat/station_mapping/mapper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hat/station_mapping/mapper.py b/hat/station_mapping/mapper.py index 6fab72f..140bc63 100644 --- a/hat/station_mapping/mapper.py +++ b/hat/station_mapping/mapper.py @@ -47,7 +47,7 @@ def apply_blacklist(blacklist_config, metric_grid, grid_area_coords1, grid_area_ return metric_grid, grid_area_coords1, grid_area_coords2 -def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_area_coords2, filename): +def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_area_coords2, shape, filename): df["opt_x_index"] = indx df["opt_y_index"] = indy df["near_x_index"] = cindx @@ -55,6 +55,7 @@ def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_ df["opt_error"] = errors df["opt_x_coord"] = grid_area_coords1[indx, 0] df["opt_y_coord"] = grid_area_coords2[0, indy] + df["opt_1d_index"] = indy + shape[1] * indx if filename is not None: df.to_csv(filename, index=False) return df @@ -109,6 +110,7 @@ def mapper(config): *mapping_outputs, grid_area_coords1, grid_area_coords2, + shape=grid_area_coords1.shape, filename=config["output"]["file"] if config.get("output", None) is not None else None, ) generate_summary_plots(df, config.get("plot", None)) From 89430503762c59bc131a913402f26d77f54646c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ois=C3=ADn=20Morrison?= Date: Thu, 16 Oct 2025 10:10:35 +0000 Subject: [PATCH 04/21] feat: first basic implementation using gribjump for timeseries extraction --- hat/extract_timeseries/extractor.py | 63 ++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index 6e557bb..30b99d2 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -8,13 +8,18 @@ from hat import _LOGGER as logger -def process_grid_inputs(grid_config): +def load_ekd_source(grid_config): src_name = list(grid_config["source"].keys())[0] 
logger.info(f"Processing grid inputs from source: {src_name}") logger.debug(f"Grid config: {grid_config['source'][src_name]}") ds = ekd.from_source(src_name, **grid_config["source"][src_name]).to_xarray( **grid_config.get("to_xarray_options", {}) ) + return ds + + +def process_grid_inputs(grid_config): + ds = load_ekd_source(grid_config) var_name = find_main_var(ds, 3) da = ds[var_name] logger.info(f"Xarray created from source:\n{da}\n") @@ -61,7 +66,7 @@ def create_mask_from_coords(coords_config, df, gridx, gridy, shape): return mask, duplication_indexes -def process_inputs(station_config, grid_config): +def parse_stations(station_config): logger.debug(f"Reading station file, {station_config}") df = pd.read_csv(station_config["file"]) filters = station_config.get("filter") @@ -72,23 +77,44 @@ def process_inputs(station_config, grid_config): index_config = station_config.get("index", None) coords_config = station_config.get("coords", None) + index_1d_config = station_config.get("index_1d", None) + return index_config, coords_config, index_1d_config, station_names, df + +def process_inputs(station_config, grid_config): + index_config, coords_config, index_1d_config, station_names, df = parse_stations(station_config) + + # TODO: better malformed config handling if index_config is not None and coords_config is not None: raise ValueError("Use either index or coords, not both.") - da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config) + if list(grid_config["source"].keys())[0] == "gribjump": + assert index_1d_config is not None + unique_indices, duplication_indexes = np.unique(df[index_1d_config].values, return_inverse=True) + grid_config["source"]["gribjump"]["indices"] = unique_indices + masked_da = load_ekd_source(grid_config) + # TODO: implement + da_varname = "placeholder_variable_name" - if index_config is not None: - mask, duplication_indexes = create_mask_from_index(index_config, df, shape) - elif coords_config is not None: - mask, duplication_indexes = create_mask_from_coords( - coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape - ) + var_name = find_main_var(masked_da, 2) + masked_da = masked_da[var_name] else: - # default to index approach - mask, duplication_indexes = create_mask_from_index(index_config, df, shape) + da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config) - return da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes + if index_config is not None: + mask, duplication_indexes = create_mask_from_index(index_config, df, shape) + elif coords_config is not None: + mask, duplication_indexes = create_mask_from_coords( + coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape + ) + else: + # default to index approach + mask, duplication_indexes = create_mask_from_index(index_config, df, shape) + + logger.info("Extracting timeseries at selected stations") + masked_da = apply_mask(da, mask, gridx_colname, gridy_colname) + + return da_varname, station_names, duplication_indexes, masked_da def mask_array_np(arr, mask): @@ -101,12 +127,12 @@ def apply_mask(da, mask, coordx, coordy): da, mask, input_core_dims=[(coordx, coordy), (coordx, coordy)], - output_core_dims=[["station"]], + output_core_dims=[["index"]], output_dtypes=[da.dtype], exclude_dims={coordx, coordy}, dask="parallelized", dask_gufunc_kwargs={ - "output_sizes": {"station": int(mask.sum())}, + "output_sizes": {"index": int(mask.sum())}, "allow_rechunk": True, }, ) @@ -115,13 
+141,10 @@ def apply_mask(da, mask, coordx, coordy): def extractor(config): - da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes = process_inputs( - config["station"], config["grid"] - ) - logger.info("Extracting timeseries at selected stations") - masked_da = apply_mask(da, mask, gridx_colname, gridy_colname) + da_varname, station_names, duplication_indexes, masked_da = process_inputs(config["station"], config["grid"]) + print(masked_da) ds = xr.Dataset({da_varname: masked_da}) - ds = ds.isel(station=duplication_indexes) + ds = ds.isel(index=duplication_indexes) ds["station"] = station_names if config.get("output", None) is not None: logger.info(f"Saving output to {config['output']['file']}") From fc04787d37926ea946029864bfae3e9c277adca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ois=C3=ADn=20Morrison?= Date: Thu, 16 Oct 2025 11:46:41 +0000 Subject: [PATCH 05/21] fix: update returned xr metadata --- hat/extract_timeseries/extractor.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index 30b99d2..e0bfd85 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -23,8 +23,9 @@ def process_grid_inputs(grid_config): var_name = find_main_var(ds, 3) da = ds[var_name] logger.info(f"Xarray created from source:\n{da}\n") - gridx_colname = grid_config.get("coord_x", "lat") - gridy_colname = grid_config.get("coord_y", "lon") + coord_config = grid_config.get("coords", {}) + gridx_colname = coord_config.get("x", "lat") + gridy_colname = coord_config.get("y", "lon") da = da.sortby([gridx_colname, gridy_colname]) shape = da[gridx_colname].shape[0], da[gridy_colname].shape[0] return da, var_name, gridx_colname, gridy_colname, shape @@ -93,11 +94,9 @@ def process_inputs(station_config, grid_config): unique_indices, duplication_indexes = np.unique(df[index_1d_config].values, return_inverse=True) grid_config["source"]["gribjump"]["indices"] = unique_indices masked_da = load_ekd_source(grid_config) - # TODO: implement - da_varname = "placeholder_variable_name" - var_name = find_main_var(masked_da, 2) masked_da = masked_da[var_name] + da_varname = var_name else: da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config) @@ -142,10 +141,9 @@ def apply_mask(da, mask, coordx, coordy): def extractor(config): da_varname, station_names, duplication_indexes, masked_da = process_inputs(config["station"], config["grid"]) - print(masked_da) ds = xr.Dataset({da_varname: masked_da}) ds = ds.isel(index=duplication_indexes) - ds["station"] = station_names + ds = ds.assign_coords({"station": ("index", station_names)}) if config.get("output", None) is not None: logger.info(f"Saving output to {config['output']['file']}") ds.to_netcdf(config["output"]["file"]) From 3b7f8d25304b734ce9d3618e52221e4133ba13da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ois=C3=ADn=20Morrison?= Date: Thu, 16 Oct 2025 12:19:36 +0000 Subject: [PATCH 06/21] fix: add station as dimension --- hat/extract_timeseries/extractor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index e0bfd85..37ec5af 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -143,7 +143,8 @@ def extractor(config): da_varname, station_names, duplication_indexes, masked_da = process_inputs(config["station"], config["grid"]) ds = xr.Dataset({da_varname: 
masked_da}) ds = ds.isel(index=duplication_indexes) - ds = ds.assign_coords({"station": ("index", station_names)}) + ds = ds.rename({"index": "station"}) + ds["station"] = station_names if config.get("output", None) is not None: logger.info(f"Saving output to {config['output']['file']}") ds.to_netcdf(config["output"]["file"]) From b481a3e4466e80e5ce12fea6cc4f9e650f283ebe Mon Sep 17 00:00:00 2001 From: oisin-m Date: Mon, 3 Nov 2025 14:10:19 +0100 Subject: [PATCH 07/21] refactor: same handling of ekd config for extract_timeseries and compute_hydrostats --- hat/compute_hydrostats/stat_calc.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/hat/compute_hydrostats/stat_calc.py b/hat/compute_hydrostats/stat_calc.py index fd147ff..8fdc47d 100644 --- a/hat/compute_hydrostats/stat_calc.py +++ b/hat/compute_hydrostats/stat_calc.py @@ -6,7 +6,13 @@ def load_da(ds_config): - ds = ekd.from_source(*ds_config["source"]).to_xarray() + src_name = list(ds_config["source"].keys())[0] + ds = ( + ekd + .from_source(*ds_config["source"]) + .from_source(src_name, **ds_config["source"][src_name]) + .to_xarray(**ds_config.get("to_xarray_options", {})) + ) var_name = find_main_var(ds, 2) da = ds[var_name] return da @@ -35,7 +41,7 @@ def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords): def stat_calc(config): sim_config = config["sim"] - sim_da = load_da(config["sim"]) + sim_da = load_da(sim_config) obs_config = config["obs"] obs_da = load_da(obs_config) new_coords = config["output"]["coords"] From f25c4aa04d467cee329e8a1fd2d4a4a29bc9e65d Mon Sep 17 00:00:00 2001 From: oisin-m Date: Mon, 3 Nov 2025 14:10:42 +0100 Subject: [PATCH 08/21] feat: update notebooks with new config parsing --- notebooks/workflow/hydrostats_computation.ipynb | 6 +++--- notebooks/workflow/timeseries_extraction.ipynb | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/workflow/hydrostats_computation.ipynb b/notebooks/workflow/hydrostats_computation.ipynb index 2292595..bb82958 100644 --- a/notebooks/workflow/hydrostats_computation.ipynb +++ b/notebooks/workflow/hydrostats_computation.ipynb @@ -19,14 +19,14 @@ "source": [ "config = {\n", " \"sim\": {\n", - " \"source\": [\"file\", \"extracted_timeseries.nc\"],\n", + " \"source\": {\"file\": \"extracted_timeseries.nc\"},\n", " \"coords\": {\n", " \"s\": \"station\",\n", " \"t\": \"time\"\n", " }\n", " },\n", " \"obs\": {\n", - " \"source\": [\"file\", \"observations.nc\"],\n", + " \"source\": {\"file\": \"observations.nc\"},\n", " \"coords\": {\n", " \"s\": \"station\",\n", " \"t\": \"time\"\n", @@ -49,7 +49,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "hat", "language": "python", "name": "python3" }, diff --git a/notebooks/workflow/timeseries_extraction.ipynb b/notebooks/workflow/timeseries_extraction.ipynb index 21136f8..d5de09e 100644 --- a/notebooks/workflow/timeseries_extraction.ipynb +++ b/notebooks/workflow/timeseries_extraction.ipynb @@ -35,7 +35,7 @@ " \"name\": \"station_id\"\n", " },\n", " \"grid\": {\n", - " \"source\": [\"file\", \"./sim.nc\"],\n", + " \"source\": {\"file\": \"./sim.nc\"},\n", " \"coords\": {\n", " \"x\": \"lat\",\n", " \"y\": \"lon\",\n", From c2b6081ebf4d1aed3ddc1a6ee667e7681e290f49 Mon Sep 17 00:00:00 2001 From: oisin-m Date: Mon, 3 Nov 2025 14:16:18 +0100 Subject: [PATCH 09/21] refactor: share load_da between extract_timeseries and compute_hydrostats --- hat/compute_hydrostats/stat_calc.py | 20 +++----------------- hat/core.py | 14 ++++++++++++++ 
hat/extract_timeseries/extractor.py | 22 +++------------------- 3 files changed, 20 insertions(+), 36 deletions(-) create mode 100644 hat/core.py diff --git a/hat/compute_hydrostats/stat_calc.py b/hat/compute_hydrostats/stat_calc.py index 8fdc47d..4da6a15 100644 --- a/hat/compute_hydrostats/stat_calc.py +++ b/hat/compute_hydrostats/stat_calc.py @@ -1,23 +1,9 @@ -import earthkit.data as ekd -from earthkit.hydro._readers import find_main_var +from hat.core import load_da import numpy as np import xarray as xr from hat.compute_hydrostats import stats -def load_da(ds_config): - src_name = list(ds_config["source"].keys())[0] - ds = ( - ekd - .from_source(*ds_config["source"]) - .from_source(src_name, **ds_config["source"][src_name]) - .to_xarray(**ds_config.get("to_xarray_options", {})) - ) - var_name = find_main_var(ds, 2) - da = ds[var_name] - return da - - def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords): sim_station_colname = sim_coords.get("s", "station") obs_station_colname = obs_coords.get("s", "station") @@ -41,9 +27,9 @@ def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords): def stat_calc(config): sim_config = config["sim"] - sim_da = load_da(sim_config) + sim_da, _ = load_da(sim_config, 2) obs_config = config["obs"] - obs_da = load_da(obs_config) + obs_da, _ = load_da(obs_config, 2) new_coords = config["output"]["coords"] sim_da, obs_da = find_valid_subset(sim_da, obs_da, sim_config["coords"], obs_config["coords"], new_coords) stat_dict = {} diff --git a/hat/core.py b/hat/core.py new file mode 100644 index 0000000..f414e25 --- /dev/null +++ b/hat/core.py @@ -0,0 +1,14 @@ +import earthkit.data as ekd +from earthkit.hydro._readers import find_main_var + + +def load_da(ds_config, n_dims): + src_name = list(ds_config["source"].keys())[0] + ds = ( + ekd.from_source(*ds_config["source"]) + .from_source(src_name, **ds_config["source"][src_name]) + .to_xarray(**ds_config.get("to_xarray_options", {})) + ) + var_name = find_main_var(ds, n_dims) + da = ds[var_name] + return da, var_name diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index 37ec5af..f7e0bc1 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -2,26 +2,13 @@ import pandas as pd import xarray as xr import numpy as np -import earthkit.data as ekd -from earthkit.hydro._readers import find_main_var +from hat.core import load_da from hat import _LOGGER as logger -def load_ekd_source(grid_config): - src_name = list(grid_config["source"].keys())[0] - logger.info(f"Processing grid inputs from source: {src_name}") - logger.debug(f"Grid config: {grid_config['source'][src_name]}") - ds = ekd.from_source(src_name, **grid_config["source"][src_name]).to_xarray( - **grid_config.get("to_xarray_options", {}) - ) - return ds - - def process_grid_inputs(grid_config): - ds = load_ekd_source(grid_config) - var_name = find_main_var(ds, 3) - da = ds[var_name] + da, var_name = load_da(grid_config, 3) logger.info(f"Xarray created from source:\n{da}\n") coord_config = grid_config.get("coords", {}) gridx_colname = coord_config.get("x", "lat") @@ -93,10 +80,7 @@ def process_inputs(station_config, grid_config): assert index_1d_config is not None unique_indices, duplication_indexes = np.unique(df[index_1d_config].values, return_inverse=True) grid_config["source"]["gribjump"]["indices"] = unique_indices - masked_da = load_ekd_source(grid_config) - var_name = find_main_var(masked_da, 2) - masked_da = masked_da[var_name] - da_varname = 
var_name + masked_da, da_varname = load_da(grid_config, 2) else: da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config) From b1dbb9bb3c635bd6efaf9e4605c6f0e2bf03a7fe Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Tue, 4 Nov 2025 13:52:23 +0000 Subject: [PATCH 10/21] fix: remove duplicate ekd.from_source call in load_da --- hat/core.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/hat/core.py b/hat/core.py index f414e25..ba82abd 100644 --- a/hat/core.py +++ b/hat/core.py @@ -4,11 +4,7 @@ def load_da(ds_config, n_dims): src_name = list(ds_config["source"].keys())[0] - ds = ( - ekd.from_source(*ds_config["source"]) - .from_source(src_name, **ds_config["source"][src_name]) - .to_xarray(**ds_config.get("to_xarray_options", {})) - ) + ds = ekd.from_source(src_name, **ds_config["source"][src_name]).to_xarray(**ds_config.get("to_xarray_options", {})) var_name = find_main_var(ds, n_dims) da = ds[var_name] return da, var_name From d5f561aee12776f493968d093c7afc925d03675d Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Wed, 5 Nov 2025 17:15:09 +0000 Subject: [PATCH 11/21] opt: pass ranges instead of indices to gribjump source --- hat/extract_timeseries/extractor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index f7e0bc1..741d652 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -79,7 +79,11 @@ def process_inputs(station_config, grid_config): if list(grid_config["source"].keys())[0] == "gribjump": assert index_1d_config is not None unique_indices, duplication_indexes = np.unique(df[index_1d_config].values, return_inverse=True) - grid_config["source"]["gribjump"]["indices"] = unique_indices + # TODO: Double-check this. Converting indices to ranges is currently + # faster than using indices directly, should be fixed in the gribjump + # source. + ranges = [(i, i + 1) for i in unique_indices] + grid_config["source"]["gribjump"]["ranges"] = ranges masked_da, da_varname = load_da(grid_config, 2) else: da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config) From 0ccd02eb60cb89dd5916fe755f8758eb2b0681ba Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Thu, 13 Nov 2025 19:19:09 +0000 Subject: [PATCH 12/21] update earthkit-data minimum version to 0.18.0 Version 0.18.0 introduced improved handling of request-based resources, allowing their usage solely via kwargs. This is a prerequisite to use the "fdb" source without complicating the user configuration or hat's earthkit-data usage. 
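For illustration, this is the pattern hat's `load_da` helper (hat/core.py) relies on: the inner source dict from the user config is expanded into a single kwargs-only `from_source` call. A sketch under the assumption that the fdb source accepts its request as a keyword argument, as described above; the request keys and values are placeholders, not part of this change:

    import earthkit.data as ekd

    # Config shape used throughout hat; the request values here are hypothetical.
    ds_config = {"source": {"fdb": {"request": {"class": "od", "stream": "oper", "param": "dis24"}}}}

    src_name = list(ds_config["source"].keys())[0]                        # -> "fdb"
    source = ekd.from_source(src_name, **ds_config["source"][src_name])   # kwargs only, hence >= 0.18.0
    ds = source.to_xarray(**ds_config.get("to_xarray_options", {}))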
--- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b590a4e..8243f0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "tqdm", "ipyleaflet", "ipywidgets", - "earthkit-data>=0.13.8", + "earthkit-data>=0.18.0", "earthkit-hydro>=1.0.0", "earthkit-meteo", "cfgrib", # check if necessary From 37fd63dfbf8e2f93454afe5f465ced3bc302ea4f Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Thu, 13 Nov 2025 21:03:09 +0000 Subject: [PATCH 13/21] test(extract_timeseries): add basic unit tests for extractor (non-gribjump) Add basic unit tests for extractor function covering happy paths: - Test index-based and coordinate-based station mapping - Test station filtering functionality - Test output file writing to netCDF - Test validation error when both index and coords provided Also fix bug where invalid config would call create_mask_from_index(None, ...) by adding early validation that raises clear errors. Coverage: 91% for non-gribjump paths. Does not test gribjump source or edge cases (empty inputs, malformed data, etc.). Note: Suppressed NumPy/netCDF4 ABI warning in pytest config (known harmless issue) --- hat/extract_timeseries/extractor.py | 17 ++- pyproject.toml | 9 +- tests/test_extractor.py | 229 ++++++++++++++++++++++++++++ 3 files changed, 248 insertions(+), 7 deletions(-) create mode 100644 tests/test_extractor.py diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index 741d652..26c06f5 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -72,9 +72,17 @@ def parse_stations(station_config): def process_inputs(station_config, grid_config): index_config, coords_config, index_1d_config, station_names, df = parse_stations(station_config) - # TODO: better malformed config handling + # Validate station mapping configuration if index_config is not None and coords_config is not None: - raise ValueError("Use either index or coords, not both.") + raise ValueError("Station config must use either 'index' or 'coords', not both.") + + if list(grid_config["source"].keys())[0] == "gribjump": + if index_1d_config is None: + raise ValueError("Gribjump source requires 'index_1d' in station config.") + else: + # For non-gribjump sources, require either index or coords + if index_config is None and coords_config is None: + raise ValueError("Station config must provide either 'index' or 'coords' for station mapping.") if list(grid_config["source"].keys())[0] == "gribjump": assert index_1d_config is not None @@ -90,13 +98,10 @@ def process_inputs(station_config, grid_config): if index_config is not None: mask, duplication_indexes = create_mask_from_index(index_config, df, shape) - elif coords_config is not None: + else: # coords_config is not None (validated above) mask, duplication_indexes = create_mask_from_coords( coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape ) - else: - # default to index approach - mask, duplication_indexes = create_mask_from_index(index_config, df, shape) logger.info("Extracting timeseries at selected stations") masked_da = apply_mask(da, mask, gridx_colname, gridy_colname) diff --git a/pyproject.toml b/pyproject.toml index 8243f0a..a16e3ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,10 +61,12 @@ dependencies = [ [project.optional-dependencies] test = [ - "pytest" + "pytest", + "pytest-cov" ] dev = [ "pytest", + 
"pytest-cov", "ruff", "pre-commit" ] @@ -101,6 +103,11 @@ addopts = "--pdbcls=IPython.terminal.debugger:Pdb" testpaths = [ "tests", ] +filterwarnings = [ + # Probably harmless numpy 2.x ABI warning from netCDF4's Cython extension + # See: https://github.com/Unidata/netcdf4-python/issues/1354 + "ignore:numpy.ndarray size changed:RuntimeWarning", +] # Packaging/setuptools options [tool.setuptools] diff --git a/tests/test_extractor.py b/tests/test_extractor.py new file mode 100644 index 0000000..775d318 --- /dev/null +++ b/tests/test_extractor.py @@ -0,0 +1,229 @@ +"""Unit tests for the extractor function.""" + +import pytest +import pandas as pd +import numpy as np +import xarray as xr + +from hat.extract_timeseries.extractor import extractor + + +@pytest.fixture +def dummy_grid_data(): + """4x5 grid, 2 timesteps, temperature variable.""" + lats = np.array([40.0, 41.0, 42.0, 43.0]) + lons = np.array([10.0, 11.0, 12.0, 13.0, 14.0]) + + temperature_values = np.array( + [ + # t=0: 2024-01-01 + [ + [10.0, 11.0, 12.0, 13.0, 14.0], # lat=40 + [15.0, 16.0, 17.0, 18.0, 19.0], # lat=41 + [20.0, 21.0, 22.0, 23.0, 24.0], # lat=42 + [25.0, 26.0, 27.0, 28.0, 29.0], # lat=43 + ], + # t=1: 2024-01-02 + [ + [30.0, 31.0, 32.0, 33.0, 34.0], # lat=40 + [35.0, 36.0, 37.0, 38.0, 39.0], # lat=41 + [40.0, 41.0, 42.0, 43.0, 44.0], # lat=42 + [45.0, 46.0, 47.0, 48.0, 49.0], # lat=43 + ], + ] + ) + + list_of_dicts = [] + + list_of_dicts.append( + { + "values": temperature_values[0].flatten(), + "param": "temperature", + "date": 20240101, + "time": 0, + "distinctLatitudes": lats.tolist(), + "distinctLongitudes": lons.tolist(), + } + ) + + list_of_dicts.append( + { + "values": temperature_values[1].flatten(), + "param": "temperature", + "date": 20240102, + "time": 0, + "distinctLatitudes": lats.tolist(), + "distinctLongitudes": lons.tolist(), + } + ) + + return list_of_dicts + + +@pytest.fixture +def station_dataframe(): + """2 stations with both index and coordinate columns.""" + return pd.DataFrame( + { + "station_id": ["STATION_A", "STATION_B"], + "opt_x_index": [1, 2], + "opt_y_index": [2, 3], + "opt_x_coord": [41.1, 41.9], # offset to test nearest-neighbor + "opt_y_coord": [12.2, 13.1], + } + ) + + +@pytest.fixture +def station_csv_file(station_dataframe, tmp_path): + """Write station DataFrame to temporary CSV.""" + csv_path = tmp_path / "stations.csv" + station_dataframe.to_csv(csv_path, index=False) + return str(csv_path) + + +@pytest.mark.parametrize( + "mapping_config", + [ + {"index": {"x": "opt_x_index", "y": "opt_y_index"}}, + {"coords": {"x": "opt_x_coord", "y": "opt_y_coord"}}, + ], + ids=["index", "coords"], +) +def test_extractor_with_temperature(dummy_grid_data, station_csv_file, mapping_config): + """Test extraction with both index and coords mapping.""" + config = { + "station": { + "file": station_csv_file, + "name": "station_id", + **mapping_config, + }, + "grid": { + "source": { + "list-of-dicts": { + "list_of_dicts": dummy_grid_data, + } + }, + "coords": { + "x": "latitude", + "y": "longitude", + }, + }, + } + + result_ds = extractor(config) + + assert isinstance(result_ds, xr.Dataset) + assert "temperature" in result_ds.data_vars + assert "station" in result_ds.dims + assert len(result_ds.station) == 2 + assert list(result_ds.station.values) == ["STATION_A", "STATION_B"] + + time_dim = "time" if "time" in result_ds.dims else "forecast_reference_time" + assert len(result_ds[time_dim]) == 2 + + # Station A: indices [1,2] -> values [17.0, 37.0] + # Station B: indices [2,3] -> values [23.0, 43.0] + 
np.testing.assert_allclose(result_ds["temperature"].sel(station="STATION_A").values, [17.0, 37.0]) + np.testing.assert_allclose(result_ds["temperature"].sel(station="STATION_B").values, [23.0, 43.0]) + + +def test_extractor_with_station_filter(dummy_grid_data, tmp_path): + """Test station filtering.""" + df = pd.DataFrame( + { + "station_id": ["S1", "S2", "S3"], + "opt_x_index": [1, 2, 1], + "opt_y_index": [2, 3, 3], + "network": ["primary", "secondary", "primary"], + } + ) + csv_path = tmp_path / "stations.csv" + df.to_csv(csv_path, index=False) + + config = { + "station": { + "file": str(csv_path), + "name": "station_id", + "filter": "network == 'primary'", + "index": {"x": "opt_x_index", "y": "opt_y_index"}, + }, + "grid": { + "source": { + "list-of-dicts": { + "list_of_dicts": dummy_grid_data, + } + }, + "coords": { + "x": "latitude", + "y": "longitude", + }, + }, + } + + result_ds = extractor(config) + + assert len(result_ds.station) == 2 + assert list(result_ds.station.values) == ["S1", "S3"] + assert result_ds["temperature"].sel(station="S1").values[0] == 17.0 + assert result_ds["temperature"].sel(station="S3").values[0] == 18.0 + + +def test_extractor_rejects_both_index_and_coords(dummy_grid_data, station_csv_file): + """Test that providing both index and coords raises ValueError.""" + config = { + "station": { + "file": station_csv_file, + "name": "station_id", + "index": {"x": "opt_x_index", "y": "opt_y_index"}, + "coords": {"x": "opt_x_coord", "y": "opt_y_coord"}, # Both provided + }, + "grid": { + "source": {"list-of-dicts": {"list_of_dicts": dummy_grid_data}}, + "coords": {"x": "latitude", "y": "longitude"}, + }, + } + + with pytest.raises(ValueError, match="must use either 'index' or 'coords', not both"): + extractor(config) + + +def test_extractor_with_output_file(dummy_grid_data, station_csv_file, tmp_path): + """Test output file writing.""" + output_file = tmp_path / "output.nc" + + config = { + "station": { + "file": station_csv_file, + "name": "station_id", + "index": {"x": "opt_x_index", "y": "opt_y_index"}, + }, + "grid": { + "source": { + "list-of-dicts": { + "list_of_dicts": dummy_grid_data, + } + }, + "coords": { + "x": "latitude", + "y": "longitude", + }, + }, + "output": { + "file": str(output_file), + }, + } + + result_ds = extractor(config) + + assert output_file.exists() + + loaded_ds = xr.open_dataset(output_file) + assert "temperature" in loaded_ds.data_vars + assert "station" in loaded_ds.dims + assert len(loaded_ds.station) == 2 + + xr.testing.assert_allclose(result_ds["temperature"], loaded_ds["temperature"]) + xr.testing.assert_equal(result_ds.station, loaded_ds.station) + + loaded_ds.close() From 8a6973c0c81bb839ee7192bd1a82902674cd700e Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Thu, 13 Nov 2025 22:27:38 +0000 Subject: [PATCH 14/21] refactor(extractor): simplify control flow and improve readability Separate gribjump and regular processing to reduce nested decision logic. Streamline functions to return single values instead of tuples where possible. Add early user-friendly validation for station indices and improve naming throughout. - Separate gribjump vs. 
regular processing paths - Streamline function return values (reduce tuple unpacking) - Add early user-friendly bounds-check for station indices - Rename variables for consistency and remove redundant comments --- hat/extract_timeseries/extractor.py | 191 ++++++++++++++++++---------- tests/test_extractor.py | 59 +++++++++ 2 files changed, 181 insertions(+), 69 deletions(-) diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index 26c06f5..1208374 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -1,4 +1,4 @@ -from dask.diagnostics import ProgressBar +from dask.diagnostics.progress import ProgressBar import pandas as pd import xarray as xr import numpy as np @@ -11,102 +11,159 @@ def process_grid_inputs(grid_config): da, var_name = load_da(grid_config, 3) logger.info(f"Xarray created from source:\n{da}\n") coord_config = grid_config.get("coords", {}) - gridx_colname = coord_config.get("x", "lat") - gridy_colname = coord_config.get("y", "lon") - da = da.sortby([gridx_colname, gridy_colname]) - shape = da[gridx_colname].shape[0], da[gridy_colname].shape[0] - return da, var_name, gridx_colname, gridy_colname, shape + x_dim = coord_config.get("x", "lat") + y_dim = coord_config.get("y", "lon") + da = da.sortby([x_dim, y_dim]) + shape = da[x_dim].shape[0], da[y_dim].shape[0] + return da, var_name, x_dim, y_dim, shape -def construct_mask(indx, indy, shape): +def construct_mask(x_indices, y_indices, shape): mask = np.zeros(shape, dtype=bool) - mask[indx, indy] = True + mask[x_indices, y_indices] = True - flat_indices = np.ravel_multi_index((indx, indy), shape) - _, inverse = np.unique(flat_indices, return_inverse=True) - return mask, inverse + flat_indices = np.ravel_multi_index((x_indices, y_indices), shape) + _, duplication_indexes = np.unique(flat_indices, return_inverse=True) + return mask, duplication_indexes -def create_mask_from_index(index_config, df, shape): - logger.info(f"Creating mask {shape} from index: {index_config}") +def create_mask_from_index(df, shape): + logger.info(f"Creating mask {shape} from index") logger.debug(f"DataFrame columns: {df.columns.tolist()}") - indx_colname = index_config.get("x", "opt_x_index") - indy_colname = index_config.get("y", "opt_y_index") - indx, indy = df[indx_colname].values, df[indy_colname].values - mask, duplication_indexes = construct_mask(indx, indy, shape) + x_indices, y_indices = df["x_index"].values, df["y_index"].values + if np.any(x_indices < 0) or np.any(x_indices >= shape[0]) or np.any(y_indices < 0) or np.any(y_indices >= shape[1]): + raise ValueError( + f"Station indices out of grid bounds. 
Grid shape={shape}, " + f"x_index range=({int(x_indices.min())},{int(x_indices.max())}), " + f"y_index range=({int(y_indices.min())},{int(y_indices.max())})" + ) + mask, duplication_indexes = construct_mask(x_indices, y_indices, shape) return mask, duplication_indexes -def create_mask_from_coords(coords_config, df, gridx, gridy, shape): - logger.info(f"Creating mask {shape} from coordinates: {coords_config}") +def create_mask_from_coords(df, gridx, gridy, shape): + logger.info(f"Creating mask {shape} from coordinates") logger.debug(f"DataFrame columns: {df.columns.tolist()}") - x_colname = coords_config.get("x", "opt_x_coord") - y_colname = coords_config.get("y", "opt_y_coord") - xs = df[x_colname].values - ys = df[y_colname].values + station_x = df["x_coord"].values + station_y = df["y_coord"].values - diffx = np.abs(xs[:, np.newaxis] - gridx) - indx = np.argmin(diffx, axis=1) - diffy = np.abs(ys[:, np.newaxis] - gridy) - indy = np.argmin(diffy, axis=1) + x_distances = np.abs(station_x[:, np.newaxis] - gridx) + x_indices = np.argmin(x_distances, axis=1) + y_distances = np.abs(station_y[:, np.newaxis] - gridy) + y_indices = np.argmin(y_distances, axis=1) - mask, duplication_indexes = construct_mask(indx, indy, shape) + mask, duplication_indexes = construct_mask(x_indices, y_indices, shape) return mask, duplication_indexes def parse_stations(station_config): + """Read, filter, and normalize station DataFrame to canonical column names.""" logger.debug(f"Reading station file, {station_config}") + if "name" not in station_config: + raise ValueError("Station config must include a 'name' key mapping to the station column") df = pd.read_csv(station_config["file"]) filters = station_config.get("filter") if filters is not None: logger.debug(f"Applying filters: {filters} to station DataFrame") df = df.query(filters) - station_names = df[station_config["name"]].values - index_config = station_config.get("index", None) - coords_config = station_config.get("coords", None) - index_1d_config = station_config.get("index_1d", None) - return index_config, coords_config, index_1d_config, station_names, df + if len(df) == 0: + raise ValueError("No stations found. 
Check station file or filter.") + has_index = "index" in station_config + has_coords = "coords" in station_config + has_index_1d = "index_1d" in station_config -def process_inputs(station_config, grid_config): - index_config, coords_config, index_1d_config, station_names, df = parse_stations(station_config) + if not has_index_1d: + if has_index and has_coords: + raise ValueError("Station config must use either 'index' or 'coords', not both.") + if not has_index and not has_coords: + raise ValueError("Station config must provide either 'index' or 'coords' for station mapping.") - # Validate station mapping configuration - if index_config is not None and coords_config is not None: - raise ValueError("Station config must use either 'index' or 'coords', not both.") + renames = {} + renames[station_config["name"]] = "station_name" - if list(grid_config["source"].keys())[0] == "gribjump": - if index_1d_config is None: - raise ValueError("Gribjump source requires 'index_1d' in station config.") - else: - # For non-gribjump sources, require either index or coords - if index_config is None and coords_config is None: - raise ValueError("Station config must provide either 'index' or 'coords' for station mapping.") + if has_index: + index_config = station_config["index"] + x_col = index_config.get("x", "opt_x_index") + y_col = index_config.get("y", "opt_y_index") + renames[x_col] = "x_index" + renames[y_col] = "y_index" + + if has_coords: + coords_config = station_config["coords"] + x_col = coords_config.get("x", "opt_x_coord") + y_col = coords_config.get("y", "opt_y_coord") + renames[x_col] = "x_coord" + renames[y_col] = "y_coord" + + if has_index_1d: + renames[station_config["index_1d"]] = "index_1d" + + df_renamed = df.rename(columns=renames) + + if has_index and ("x_index" not in df_renamed.columns or "y_index" not in df_renamed.columns): + raise ValueError( + "Station file missing required index columns. Expected columns to map to 'x_index' and 'y_index'." + ) + if has_coords and ("x_coord" not in df_renamed.columns or "y_coord" not in df_renamed.columns): + raise ValueError( + "Station file missing required coordinate columns. Expected columns to map to 'x_coord' and 'y_coord'." + ) + if has_index_1d and "index_1d" not in df_renamed.columns: + raise ValueError("Station file missing required 'index_1d' column.") + + return df_renamed + + +def _process_gribjump(grid_config, df): + if "index_1d" not in df.columns: + raise ValueError("Gribjump source requires 'index_1d' in station config.") + + station_names = df["station_name"].values + unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True) + # TODO: Double-check this. Converting indices to ranges is currently + # faster than using indices directly, should be fixed in the gribjump + # source. + ranges = [(i, i + 1) for i in unique_indices] + + gribjump_config = {"source": {"gribjump": {**grid_config["source"]["gribjump"], "ranges": ranges}}} + + masked_da, var_name = load_da(gribjump_config, 2) + + ds = xr.Dataset({var_name: masked_da}) + ds = ds.isel(index=duplication_indexes) + ds = ds.rename({"index": "station"}) + ds["station"] = station_names + return ds - if list(grid_config["source"].keys())[0] == "gribjump": - assert index_1d_config is not None - unique_indices, duplication_indexes = np.unique(df[index_1d_config].values, return_inverse=True) - # TODO: Double-check this. Converting indices to ranges is currently - # faster than using indices directly, should be fixed in the gribjump - # source. 
- ranges = [(i, i + 1) for i in unique_indices] - grid_config["source"]["gribjump"]["ranges"] = ranges - masked_da, da_varname = load_da(grid_config, 2) + +def _process_regular(grid_config, df): + station_names = df["station_name"].values + da, var_name, x_dim, y_dim, shape = process_grid_inputs(grid_config) + + use_index = "x_index" in df.columns and "y_index" in df.columns + + if use_index: + mask, duplication_indexes = create_mask_from_index(df, shape) else: - da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config) + mask, duplication_indexes = create_mask_from_coords(df, da[x_dim].values, da[y_dim].values, shape) + + logger.info("Extracting timeseries at selected stations") + masked_da = apply_mask(da, mask, x_dim, y_dim) - if index_config is not None: - mask, duplication_indexes = create_mask_from_index(index_config, df, shape) - else: # coords_config is not None (validated above) - mask, duplication_indexes = create_mask_from_coords( - coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape - ) + ds = xr.Dataset({var_name: masked_da}) + ds = ds.isel(index=duplication_indexes) + ds = ds.rename({"index": "station"}) + ds["station"] = station_names + return ds - logger.info("Extracting timeseries at selected stations") - masked_da = apply_mask(da, mask, gridx_colname, gridy_colname) - return da_varname, station_names, duplication_indexes, masked_da +def process_inputs(station_config, grid_config): + df = parse_stations(station_config) + if "gribjump" in grid_config.get("source", {}): + return _process_gribjump(grid_config, df) + return _process_regular(grid_config, df) def mask_array_np(arr, mask): @@ -133,11 +190,7 @@ def apply_mask(da, mask, coordx, coordy): def extractor(config): - da_varname, station_names, duplication_indexes, masked_da = process_inputs(config["station"], config["grid"]) - ds = xr.Dataset({da_varname: masked_da}) - ds = ds.isel(index=duplication_indexes) - ds = ds.rename({"index": "station"}) - ds["station"] = station_names + ds = process_inputs(config["station"], config["grid"]) if config.get("output", None) is not None: logger.info(f"Saving output to {config['output']['file']}") ds.to_netcdf(config["output"]["file"]) diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 775d318..41ed39d 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -4,6 +4,7 @@ import pandas as pd import numpy as np import xarray as xr +from unittest.mock import Mock, patch from hat.extract_timeseries.extractor import extractor @@ -227,3 +228,61 @@ def test_extractor_with_output_file(dummy_grid_data, station_csv_file, tmp_path) xr.testing.assert_equal(result_ds.station, loaded_ds.station) loaded_ds.close() + + +def test_extractor_with_empty_stations(dummy_grid_data, tmp_path): + """Test that extractor raises clear error for empty station list.""" + empty_csv = tmp_path / "empty_stations.csv" + pd.DataFrame(columns=["station_id", "opt_x_index", "opt_y_index"]).to_csv(empty_csv, index=False) + + config = { + "station": { + "file": str(empty_csv), + "name": "station_id", + "index": {"x": "opt_x_index", "y": "opt_y_index"}, + }, + "grid": { + "source": {"list-of-dicts": {"list_of_dicts": dummy_grid_data}}, + "coords": {"x": "latitude", "y": "longitude"}, + }, + } + + with pytest.raises(ValueError, match="No stations found"): + extractor(config) + + +@patch("earthkit.data.from_source") +def test_extractor_gribjump(mock_from_source, tmp_path): + """Test gribjump path: verifies ranges computation and earthkit 
call.""" + + # Mock returns object with to_xarray() that returns minimal dataset + mock_source = Mock() + mock_source.to_xarray.return_value = xr.Dataset( + {"temperature": xr.DataArray([[15.0, 25.0], [35.0, 45.0]], dims=["index", "time"])} + ) + mock_from_source.return_value = mock_source + + # Station CSV with index_1d (includes duplicate to test deduplication) + csv_file = tmp_path / "stations.csv" + pd.DataFrame( + { + "name": ["S1", "S2", "S3"], + "idx": [100, 200, 100], # S1 and S3 share index 100 + } + ).to_csv(csv_file, index=False) + + config = { + "station": {"file": str(csv_file), "name": "name", "index_1d": "idx"}, + "grid": {"source": {"gribjump": {"request": {"class": "od", "expver": "0001", "stream": "oper"}}}}, + } + + result = extractor(config) + + # Verify earthkit.data.from_source was called correctly + mock_from_source.assert_called_once_with( + "gribjump", request={"class": "od", "expver": "0001", "stream": "oper"}, ranges=[(100, 101), (200, 201)] + ) + + # Verify output + assert len(result.station) == 3 + assert list(result.station.values) == ["S1", "S2", "S3"] From f1c69530c2a7bc3005be3c74e895ebfcac8ca65b Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Thu, 13 Nov 2025 22:40:06 +0000 Subject: [PATCH 15/21] feat(extractor): add type hints to some functions --- hat/extract_timeseries/extractor.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index 1208374..bcf53e1 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -2,6 +2,7 @@ import pandas as pd import xarray as xr import numpy as np +from typing import Any from hat.core import load_da from hat import _LOGGER as logger @@ -30,7 +31,8 @@ def construct_mask(x_indices, y_indices, shape): def create_mask_from_index(df, shape): logger.info(f"Creating mask {shape} from index") logger.debug(f"DataFrame columns: {df.columns.tolist()}") - x_indices, y_indices = df["x_index"].values, df["y_index"].values + x_indices = df["x_index"].values + y_indices = df["y_index"].values if np.any(x_indices < 0) or np.any(x_indices >= shape[0]) or np.any(y_indices < 0) or np.any(y_indices >= shape[1]): raise ValueError( f"Station indices out of grid bounds. Grid shape={shape}, " @@ -56,7 +58,7 @@ def create_mask_from_coords(df, gridx, gridy, shape): return mask, duplication_indexes -def parse_stations(station_config): +def parse_stations(station_config: dict[str, Any]) -> pd.DataFrame: """Read, filter, and normalize station DataFrame to canonical column names.""" logger.debug(f"Reading station file, {station_config}") if "name" not in station_config: @@ -116,12 +118,12 @@ def parse_stations(station_config): return df_renamed -def _process_gribjump(grid_config, df): +def _process_gribjump(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset: if "index_1d" not in df.columns: raise ValueError("Gribjump source requires 'index_1d' in station config.") station_names = df["station_name"].values - unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True) + unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True) # type: ignore[call-overload] # TODO: Double-check this. Converting indices to ranges is currently # faster than using indices directly, should be fixed in the gribjump # source. 
@@ -138,7 +140,7 @@ def _process_gribjump(grid_config, df): return ds -def _process_regular(grid_config, df): +def _process_regular(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset: station_names = df["station_name"].values da, var_name, x_dim, y_dim, shape = process_grid_inputs(grid_config) @@ -159,7 +161,7 @@ def _process_regular(grid_config, df): return ds -def process_inputs(station_config, grid_config): +def process_inputs(station_config: dict[str, Any], grid_config: dict[str, Any]) -> xr.Dataset: df = parse_stations(station_config) if "gribjump" in grid_config.get("source", {}): return _process_gribjump(grid_config, df) @@ -189,7 +191,7 @@ def apply_mask(da, mask, coordx, coordy): return task.compute() -def extractor(config): +def extractor(config: dict[str, Any]) -> xr.Dataset: ds = process_inputs(config["station"], config["grid"]) if config.get("output", None) is not None: logger.info(f"Saving output to {config['output']['file']}") From 7db064f4253d9669d779d2ef5b430e5d19ecf58e Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Mon, 17 Nov 2025 12:41:58 +0000 Subject: [PATCH 16/21] chore: add disclaimer to gribjumplib in pyproject.toml --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a16e3ad..f7e25ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,10 @@ dependencies = [ ] gribjump = [ "earthkit-data[gribjump]", + # At this time, there is not an official release on PyPI which can be easily + # installed without potential ABI issues with other dependencies. As long as + # this is the case and using gribjump is an experimental feature in HAT, we + # point to a development release that should work for most users. "gribjumplib==0.10.3.dev20250908" ] From d5dda4d40083dbd4182e461d68ef567666398752 Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Mon, 17 Nov 2025 12:50:49 +0000 Subject: [PATCH 17/21] chore: add installation instructions for experimental gribjump extras --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index 7585587..077a3ab 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,16 @@ pip install -e .[dev] pre-commit install ``` +HAT provides **experimental** support for earthkit-data's [gribjump source](https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#gribjump). +To install the gribjump extras for testing and experimentation, run: +```bash +pip install hydro-analysis-toolkit[gribjump] +``` + +> [!NOTE] +> The gribjump feature is experimental and depends on a pre-release version of `gribjumplib`. This feature is not recommended for production use and may change or break in future releases. 
+ + ## Licence ``` From 83c861b6077df9890e47bca810ddfe23417c989b Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Wed, 19 Nov 2025 09:48:32 +0000 Subject: [PATCH 18/21] chore: minor cosmetic changes like comments --- hat/core.py | 3 ++- hat/extract_timeseries/extractor.py | 24 ++++++++++++++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/hat/core.py b/hat/core.py index ba82abd..12213a5 100644 --- a/hat/core.py +++ b/hat/core.py @@ -4,7 +4,8 @@ def load_da(ds_config, n_dims): src_name = list(ds_config["source"].keys())[0] - ds = ekd.from_source(src_name, **ds_config["source"][src_name]).to_xarray(**ds_config.get("to_xarray_options", {})) + source = ekd.from_source(src_name, **ds_config["source"][src_name]) + ds = source.to_xarray(**ds_config.get("to_xarray_options", {})) var_name = find_main_var(ds, n_dims) da = ds[var_name] return da, var_name diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index bcf53e1..1975213 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -124,12 +124,24 @@ def _process_gribjump(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Datas station_names = df["station_name"].values unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True) # type: ignore[call-overload] - # TODO: Double-check this. Converting indices to ranges is currently - # faster than using indices directly, should be fixed in the gribjump - # source. + + # Converting indices to ranges is currently faster than using indices + # directly. This is a problem in the earthkit-data gribjump source and will + # be fixed there. ranges = [(i, i + 1) for i in unique_indices] - gribjump_config = {"source": {"gribjump": {**grid_config["source"]["gribjump"], "ranges": ranges}}} + gribjump_config = { + "source": { + "gribjump": { + **grid_config["source"]["gribjump"], + "ranges": ranges, + # fetch_coords_from_fdb is currently very slow. Needs fix in + # earthkit-data gribjump source. + # "fetch_coords_from_fdb": True, + } + }, + "to_xarray_options": grid_config.get("to_xarray_options", {}), + } masked_da, var_name = load_da(gribjump_config, 2) @@ -168,11 +180,11 @@ def process_inputs(station_config: dict[str, Any], grid_config: dict[str, Any]) return _process_regular(grid_config, df) -def mask_array_np(arr, mask): +def mask_array_np(arr: np.ndarray, mask: np.ndarray) -> np.ndarray: return arr[..., mask] -def apply_mask(da, mask, coordx, coordy): +def apply_mask(da: xr.DataArray, mask: np.ndarray, coordx: str, coordy: str) -> xr.DataArray: task = xr.apply_ufunc( mask_array_np, da, From 11470b428f1a93f8eccc5502617695c255d54c07 Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Wed, 19 Nov 2025 09:51:34 +0000 Subject: [PATCH 19/21] chore: bump minimum earthkit-data version to 0.18.2 This new version includes a bugfix that affected the gribjump source and made time_dim_mode="valid_time" not work in the xarray dataset creation. 
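For reference, a sketch of where that option enters a hat config: the `to_xarray_options` block is forwarded unchanged to earthkit-data's `to_xarray()` by `load_da`, so requesting valid-time handling looks roughly like the following (the gribjump request values are placeholders):

    grid_config = {
        "source": {"gribjump": {"request": {"class": "od", "expver": "0001", "stream": "oper"}}},
        # Passed straight through to to_xarray(); with the gribjump source this
        # needs earthkit-data >= 0.18.2 because of the bugfix referenced above.
        "to_xarray_options": {"time_dim_mode": "valid_time"},
    }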
--- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f7e25ad..7e0fc76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "tqdm", "ipyleaflet", "ipywidgets", - "earthkit-data>=0.18.0", + "earthkit-data>=0.18.2", "earthkit-hydro>=1.0.0", "earthkit-meteo", "cfgrib", # check if necessary From d413ccb1293ba81dc3ae0a365d00a50dbd3bf7b4 Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Wed, 19 Nov 2025 13:19:01 +0100 Subject: [PATCH 20/21] Change ProgessBar import from dask --- hat/extract_timeseries/extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py index 1975213..7b76a2b 100644 --- a/hat/extract_timeseries/extractor.py +++ b/hat/extract_timeseries/extractor.py @@ -1,4 +1,4 @@ -from dask.diagnostics.progress import ProgressBar +from dask.diagnostics import ProgressBar import pandas as pd import xarray as xr import numpy as np From 4d171bb2c71eb5433666e8070e04bf38807af96c Mon Sep 17 00:00:00 2001 From: Andreas Grafberger <18516896+andreas-grafberger@users.noreply.github.com> Date: Tue, 25 Nov 2025 10:17:27 +0000 Subject: [PATCH 21/21] remove gribjumplib as a dependency --- README.md | 4 +++- pyproject.toml | 7 +------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 077a3ab..6a420fa 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,9 @@ pip install hydro-analysis-toolkit[gribjump] ``` > [!NOTE] -> The gribjump feature is experimental and depends on a pre-release version of `gribjumplib`. This feature is not recommended for production use and may change or break in future releases. +> The gribjump feature is experimental. It is not recommended for production use and may change or break in future releases. +> Information on how to build gribjump can be found in [GribJump's source code](https://github.com/ecmwf/gribjump/). Experimental +> wheels of `gribjumplib` can also be found [on PyPI](https://pypi.org/project/gribjumplib/). ## Licence diff --git a/pyproject.toml b/pyproject.toml index 7e0fc76..5094c94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,12 +71,7 @@ dependencies = [ "pre-commit" ] gribjump = [ - "earthkit-data[gribjump]", - # At this time, there is not an official release on PyPI which can be easily - # installed without potential ABI issues with other dependencies. As long as - # this is the case and using gribjump is an experimental feature in HAT, we - # point to a development release that should work for most users. - "gribjumplib==0.10.3.dev20250908" + "earthkit-data[gribjump]" ] [project.scripts]
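Taken together, the series enables the following end-to-end use of the extractor. A minimal sketch modelled on the gribjump unit test added in PATCH 14 and the notebook-style configs; the CSV column names, output path, and MARS request values are assumptions, not values mandated by the code:

    from hat.extract_timeseries.extractor import extractor

    config = {
        "station": {
            "file": "stations_mapped.csv",   # e.g. output of the station mapper
            "name": "station_id",
            "index_1d": "opt_1d_index",      # 1D grid index column; required for the gribjump source
        },
        "grid": {
            "source": {
                "gribjump": {"request": {"class": "od", "expver": "0001", "stream": "oper"}},
            },
        },
        "output": {"file": "extracted_timeseries.nc"},
    }

    ds = extractor(config)  # xarray.Dataset with a "station" dimension, also written to netCDF if "output" is set

For file-based sources the existing index/coords station mapping is unchanged; only configs whose source key is "gribjump" take the index_1d path.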