Merged
Changes from all commits
21 commits
adf8c64
add optional dependency group with gribjump dependencies
andreas-grafberger Oct 16, 2025
93c606a
chore: min ekh version
Oisin-M Oct 16, 2025
8385bae
feat: add opt_1d_index to outputted station mapping df
Oisin-M Oct 16, 2025
8943050
feat: first basic implementation using gribjump for timeseries extrac…
Oisin-M Oct 16, 2025
fc04787
fix: update returned xr metadata
Oisin-M Oct 16, 2025
3b7f8d2
fix: add station as dimension
Oisin-M Oct 16, 2025
b481a3e
refactor: same handling of ekd config for extract_timeseries and comp…
Oisin-M Nov 3, 2025
f25c4aa
feat: update notebooks with new config parsing
Oisin-M Nov 3, 2025
c2b6081
refactor: share load_da between extract_timeseries and compute_hydros…
Oisin-M Nov 3, 2025
b1dbb9b
fix: remove duplicate ekd.from_source call in load_da
andreas-grafberger Nov 4, 2025
d5f561a
opt: pass ranges instead of indices to gribjump source
andreas-grafberger Nov 5, 2025
0ccd02e
update earthkit-data minimum version to 0.18.0
andreas-grafberger Nov 13, 2025
37fd63d
test(extract_timeseries): add basic unit tests for extractor (non-gri…
andreas-grafberger Nov 13, 2025
8a6973c
refactor(extractor): simplify control flow and improve readability
andreas-grafberger Nov 13, 2025
f1c6953
feat(extractor): add type hints to some functions
andreas-grafberger Nov 13, 2025
7db064f
chore: add disclaimer to gribjumplib in pyproject.toml
andreas-grafberger Nov 17, 2025
d5dda4d
chore: add installation instructions for experimental gribjump extras
andreas-grafberger Nov 17, 2025
83c861b
chore: minor cosmetic changes like comments
andreas-grafberger Nov 19, 2025
11470b4
chore: bump minimum earthkit-data version to 0.18.2
andreas-grafberger Nov 19, 2025
d413ccb
Change ProgressBar import from dask
andreas-grafberger Nov 19, 2025
4d171bb
remove gribjumplib as a dependency
andreas-grafberger Nov 25, 2025
12 changes: 12 additions & 0 deletions README.md
@@ -50,6 +50,18 @@ pip install -e .[dev]
pre-commit install
```

HAT provides **experimental** support for earthkit-data's [gribjump source](https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#gribjump).
To install the gribjump extras for testing and experimentation, run:
```bash
pip install hydro-analysis-toolkit[gribjump]
```

> [!NOTE]
> The gribjump feature is experimental. It is not recommended for production use and may change or break in future releases.
> Information on how to build gribjump can be found in [GribJump's source code](https://github.com/ecmwf/gribjump/). Experimental
> wheels of `gribjumplib` can also be found [on PyPI](https://pypi.org/project/gribjumplib/).


## Licence

```
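For orientation, a sketch of the config shape the new gribjump extraction path expects, pieced together from `parse_stations` and `_process_gribjump` in the `hat/extract_timeseries/extractor.py` diff below. Everything inside the inner `gribjump` mapping is a placeholder assumption: it is forwarded as keyword arguments to earthkit-data's gribjump source, and `ranges` is filled in automatically from the stations' `index_1d` column.

```python
# Hypothetical extraction config for the experimental gribjump path.
config = {
    "station": {
        "file": "stations.csv",          # CSV containing the columns named below
        "name": "station_id",            # column holding station names
        "index_1d": "opt_1d_index",      # flat grid index from the station mapper
    },
    "grid": {
        "source": {
            # Presence of the "gribjump" key routes extraction through
            # _process_gribjump; the kwargs here are placeholders for
            # whatever your FDB/gribjump setup requires.
            "gribjump": {"request": {"param": "dis24", "date": "20250101"}},
        },
    },
    "output": {"file": "extracted_timeseries.nc"},
}
```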
14 changes: 3 additions & 11 deletions hat/compute_hydrostats/stat_calc.py
@@ -1,17 +1,9 @@
import earthkit.data as ekd
from earthkit.hydro._readers import find_main_var
from hat.core import load_da
import numpy as np
import xarray as xr
from hat.compute_hydrostats import stats


def load_da(ds_config):
ds = ekd.from_source(*ds_config["source"]).to_xarray()
var_name = find_main_var(ds, 2)
da = ds[var_name]
return da


def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords):
sim_station_colname = sim_coords.get("s", "station")
obs_station_colname = obs_coords.get("s", "station")
@@ -35,9 +27,9 @@ def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords)

def stat_calc(config):
sim_config = config["sim"]
sim_da = load_da(config["sim"])
sim_da, _ = load_da(sim_config, 2)
obs_config = config["obs"]
obs_da = load_da(obs_config)
obs_da, _ = load_da(obs_config, 2)
new_coords = config["output"]["coords"]
sim_da, obs_da = find_valid_subset(sim_da, obs_da, sim_config["coords"], obs_config["coords"], new_coords)
stat_dict = {}
11 changes: 11 additions & 0 deletions hat/core.py
@@ -0,0 +1,11 @@
import earthkit.data as ekd
from earthkit.hydro._readers import find_main_var


def load_da(ds_config, n_dims):
src_name = list(ds_config["source"].keys())[0]
source = ekd.from_source(src_name, **ds_config["source"][src_name])
ds = source.to_xarray(**ds_config.get("to_xarray_options", {}))
var_name = find_main_var(ds, n_dims)
da = ds[var_name]
return da, var_name
212 changes: 147 additions & 65 deletions hat/extract_timeseries/extractor.py
@@ -2,127 +2,209 @@
import pandas as pd
import xarray as xr
import numpy as np
import earthkit.data as ekd
from earthkit.hydro._readers import find_main_var
from typing import Any
from hat.core import load_da

from hat import _LOGGER as logger


def process_grid_inputs(grid_config):
src_name = list(grid_config["source"].keys())[0]
logger.info(f"Processing grid inputs from source: {src_name}")
logger.debug(f"Grid config: {grid_config['source'][src_name]}")
ds = ekd.from_source(src_name, **grid_config["source"][src_name]).to_xarray(
**grid_config.get("to_xarray_options", {})
)
var_name = find_main_var(ds, 3)
da = ds[var_name]
da, var_name = load_da(grid_config, 3)
logger.info(f"Xarray created from source:\n{da}\n")
gridx_colname = grid_config.get("coord_x", "lat")
gridy_colname = grid_config.get("coord_y", "lon")
da = da.sortby([gridx_colname, gridy_colname])
shape = da[gridx_colname].shape[0], da[gridy_colname].shape[0]
return da, var_name, gridx_colname, gridy_colname, shape
coord_config = grid_config.get("coords", {})
x_dim = coord_config.get("x", "lat")
y_dim = coord_config.get("y", "lon")
da = da.sortby([x_dim, y_dim])
shape = da[x_dim].shape[0], da[y_dim].shape[0]
return da, var_name, x_dim, y_dim, shape


def construct_mask(indx, indy, shape):
def construct_mask(x_indices, y_indices, shape):
mask = np.zeros(shape, dtype=bool)
mask[indx, indy] = True
mask[x_indices, y_indices] = True

flat_indices = np.ravel_multi_index((indx, indy), shape)
_, inverse = np.unique(flat_indices, return_inverse=True)
return mask, inverse
flat_indices = np.ravel_multi_index((x_indices, y_indices), shape)
_, duplication_indexes = np.unique(flat_indices, return_inverse=True)
return mask, duplication_indexes


def create_mask_from_index(index_config, df, shape):
logger.info(f"Creating mask {shape} from index: {index_config}")
def create_mask_from_index(df, shape):
logger.info(f"Creating mask {shape} from index")
logger.debug(f"DataFrame columns: {df.columns.tolist()}")
indx_colname = index_config.get("x", "opt_x_index")
indy_colname = index_config.get("y", "opt_y_index")
indx, indy = df[indx_colname].values, df[indy_colname].values
mask, duplication_indexes = construct_mask(indx, indy, shape)
x_indices = df["x_index"].values
y_indices = df["y_index"].values
if np.any(x_indices < 0) or np.any(x_indices >= shape[0]) or np.any(y_indices < 0) or np.any(y_indices >= shape[1]):
raise ValueError(
f"Station indices out of grid bounds. Grid shape={shape}, "
f"x_index range=({int(x_indices.min())},{int(x_indices.max())}), "
f"y_index range=({int(y_indices.min())},{int(y_indices.max())})"
)
mask, duplication_indexes = construct_mask(x_indices, y_indices, shape)
return mask, duplication_indexes


def create_mask_from_coords(coords_config, df, gridx, gridy, shape):
logger.info(f"Creating mask {shape} from coordinates: {coords_config}")
def create_mask_from_coords(df, gridx, gridy, shape):
logger.info(f"Creating mask {shape} from coordinates")
logger.debug(f"DataFrame columns: {df.columns.tolist()}")
x_colname = coords_config.get("x", "opt_x_coord")
y_colname = coords_config.get("y", "opt_y_coord")
xs = df[x_colname].values
ys = df[y_colname].values
station_x = df["x_coord"].values
station_y = df["y_coord"].values

diffx = np.abs(xs[:, np.newaxis] - gridx)
indx = np.argmin(diffx, axis=1)
diffy = np.abs(ys[:, np.newaxis] - gridy)
indy = np.argmin(diffy, axis=1)
x_distances = np.abs(station_x[:, np.newaxis] - gridx)
x_indices = np.argmin(x_distances, axis=1)
y_distances = np.abs(station_y[:, np.newaxis] - gridy)
y_indices = np.argmin(y_distances, axis=1)

mask, duplication_indexes = construct_mask(indx, indy, shape)
mask, duplication_indexes = construct_mask(x_indices, y_indices, shape)
return mask, duplication_indexes


def process_inputs(station_config, grid_config):
def parse_stations(station_config: dict[str, Any]) -> pd.DataFrame:
"""Read, filter, and normalize station DataFrame to canonical column names."""
logger.debug(f"Reading station file, {station_config}")
if "name" not in station_config:
raise ValueError("Station config must include a 'name' key mapping to the station column")
df = pd.read_csv(station_config["file"])
filters = station_config.get("filter")
if filters is not None:
logger.debug(f"Applying filters: {filters} to station DataFrame")
df = df.query(filters)
station_names = df[station_config["name"]].values

index_config = station_config.get("index", None)
coords_config = station_config.get("coords", None)
if len(df) == 0:
raise ValueError("No stations found. Check station file or filter.")

has_index = "index" in station_config
has_coords = "coords" in station_config
has_index_1d = "index_1d" in station_config

if not has_index_1d:
if has_index and has_coords:
raise ValueError("Station config must use either 'index' or 'coords', not both.")
if not has_index and not has_coords:
raise ValueError("Station config must provide either 'index' or 'coords' for station mapping.")

if index_config is not None and coords_config is not None:
raise ValueError("Use either index or coords, not both.")
renames = {}
renames[station_config["name"]] = "station_name"

da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config)
if has_index:
index_config = station_config["index"]
x_col = index_config.get("x", "opt_x_index")
y_col = index_config.get("y", "opt_y_index")
renames[x_col] = "x_index"
renames[y_col] = "y_index"

if index_config is not None:
mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
elif coords_config is not None:
mask, duplication_indexes = create_mask_from_coords(
coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape
if has_coords:
coords_config = station_config["coords"]
x_col = coords_config.get("x", "opt_x_coord")
y_col = coords_config.get("y", "opt_y_coord")
renames[x_col] = "x_coord"
renames[y_col] = "y_coord"

if has_index_1d:
renames[station_config["index_1d"]] = "index_1d"

df_renamed = df.rename(columns=renames)

if has_index and ("x_index" not in df_renamed.columns or "y_index" not in df_renamed.columns):
raise ValueError(
"Station file missing required index columns. Expected columns to map to 'x_index' and 'y_index'."
)
if has_coords and ("x_coord" not in df_renamed.columns or "y_coord" not in df_renamed.columns):
raise ValueError(
"Station file missing required coordinate columns. Expected columns to map to 'x_coord' and 'y_coord'."
)
if has_index_1d and "index_1d" not in df_renamed.columns:
raise ValueError("Station file missing required 'index_1d' column.")

return df_renamed


def _process_gribjump(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
if "index_1d" not in df.columns:
raise ValueError("Gribjump source requires 'index_1d' in station config.")

station_names = df["station_name"].values
unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True) # type: ignore[call-overload]

# Converting indices to ranges is currently faster than using indices
# directly. This is a problem in the earthkit-data gribjump source and will
# be fixed there.
ranges = [(i, i + 1) for i in unique_indices]

gribjump_config = {
"source": {
"gribjump": {
**grid_config["source"]["gribjump"],
"ranges": ranges,
# fetch_coords_from_fdb is currently very slow. Needs fix in
# earthkit-data gribjump source.
# "fetch_coords_from_fdb": True,
}
},
"to_xarray_options": grid_config.get("to_xarray_options", {}),
}

masked_da, var_name = load_da(gribjump_config, 2)

ds = xr.Dataset({var_name: masked_da})
ds = ds.isel(index=duplication_indexes)
ds = ds.rename({"index": "station"})
ds["station"] = station_names
return ds


def _process_regular(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
station_names = df["station_name"].values
da, var_name, x_dim, y_dim, shape = process_grid_inputs(grid_config)

use_index = "x_index" in df.columns and "y_index" in df.columns

if use_index:
mask, duplication_indexes = create_mask_from_index(df, shape)
else:
# default to index approach
mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
mask, duplication_indexes = create_mask_from_coords(df, da[x_dim].values, da[y_dim].values, shape)

logger.info("Extracting timeseries at selected stations")
masked_da = apply_mask(da, mask, x_dim, y_dim)

ds = xr.Dataset({var_name: masked_da})
ds = ds.isel(index=duplication_indexes)
ds = ds.rename({"index": "station"})
ds["station"] = station_names
return ds

return da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes

def process_inputs(station_config: dict[str, Any], grid_config: dict[str, Any]) -> xr.Dataset:
df = parse_stations(station_config)
if "gribjump" in grid_config.get("source", {}):
return _process_gribjump(grid_config, df)
return _process_regular(grid_config, df)

def mask_array_np(arr, mask):

def mask_array_np(arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
return arr[..., mask]


def apply_mask(da, mask, coordx, coordy):
def apply_mask(da: xr.DataArray, mask: np.ndarray, coordx: str, coordy: str) -> xr.DataArray:
task = xr.apply_ufunc(
mask_array_np,
da,
mask,
input_core_dims=[(coordx, coordy), (coordx, coordy)],
output_core_dims=[["station"]],
output_core_dims=[["index"]],
output_dtypes=[da.dtype],
exclude_dims={coordx, coordy},
dask="parallelized",
dask_gufunc_kwargs={
"output_sizes": {"station": int(mask.sum())},
"output_sizes": {"index": int(mask.sum())},
"allow_rechunk": True,
},
)
with ProgressBar(dt=15):
return task.compute()


def extractor(config):
da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes = process_inputs(
config["station"], config["grid"]
)
logger.info("Extracting timeseries at selected stations")
masked_da = apply_mask(da, mask, gridx_colname, gridy_colname)
ds = xr.Dataset({da_varname: masked_da})
ds = ds.isel(station=duplication_indexes)
ds["station"] = station_names
def extractor(config: dict[str, Any]) -> xr.Dataset:
ds = process_inputs(config["station"], config["grid"])
if config.get("output", None) is not None:
logger.info(f"Saving output to {config['output']['file']}")
ds.to_netcdf(config["output"]["file"])
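The regular (non-gribjump) path above hinges on the boolean-mask plus inverse-index pattern in `construct_mask` and `apply_mask`: stations that map to the same grid cell collapse to one value during masking and are fanned back out afterwards via `ds.isel(index=duplication_indexes)`. A self-contained numpy sketch of that technique, with a toy grid and hypothetical station indices:

```python
import numpy as np

# Toy 3x4 grid; three stations, two of which fall on the same cell.
shape = (3, 4)                               # (n_x, n_y)
grid = np.arange(12, dtype=float).reshape(shape)
x_indices = np.array([0, 2, 0])
y_indices = np.array([1, 3, 1])              # stations 0 and 2 collide

# Mark every selected cell once, as construct_mask does.
mask = np.zeros(shape, dtype=bool)
mask[x_indices, y_indices] = True

# Boolean masking drops duplicates, so keep the inverse mapping that
# fans per-cell values back out to per-station values.
flat = np.ravel_multi_index((x_indices, y_indices), shape)
_, duplication_indexes = np.unique(flat, return_inverse=True)

values_per_cell = grid[mask]                 # row-major, matches sorted flat indices
values_per_station = values_per_cell[duplication_indexes]
print(values_per_station)                    # [ 1. 11.  1.]
```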
4 changes: 3 additions & 1 deletion hat/station_mapping/mapper.py
@@ -47,14 +47,15 @@ def apply_blacklist(blacklist_config, metric_grid, grid_area_coords1, grid_area_
return metric_grid, grid_area_coords1, grid_area_coords2


def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_area_coords2, filename):
def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_area_coords2, shape, filename):
df["opt_x_index"] = indx
df["opt_y_index"] = indy
df["near_x_index"] = cindx
df["near_y_index"] = cindy
df["opt_error"] = errors
df["opt_x_coord"] = grid_area_coords1[indx, 0]
df["opt_y_coord"] = grid_area_coords2[0, indy]
df["opt_1d_index"] = indy + shape[1] * indx
if filename is not None:
df.to_csv(filename, index=False)
return df
@@ -109,6 +110,7 @@ def mapper(config):
*mapping_outputs,
grid_area_coords1,
grid_area_coords2,
shape=grid_area_coords1.shape,
filename=config["output"]["file"] if config.get("output", None) is not None else None,
)
generate_summary_plots(df, config.get("plot", None))
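A quick check, with made-up shape and indices, that the new `opt_1d_index` column written by `outputs_to_df` is the row-major flat index later consumed as `index_1d` by the gribjump extraction path:

```python
import numpy as np

# Hypothetical grid shape and optimum indices from the mapper.
shape = (360, 720)                      # (n_x, n_y), e.g. lat x lon
indx = np.array([10, 42])
indy = np.array([100, 3])

opt_1d_index = indy + shape[1] * indx   # same formula as in outputs_to_df
assert np.array_equal(opt_1d_index, np.ravel_multi_index((indx, indy), shape))
print(opt_1d_index)                     # [ 7300 30243]
```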
6 changes: 3 additions & 3 deletions notebooks/workflow/hydrostats_computation.ipynb
@@ -19,14 +19,14 @@
"source": [
"config = {\n",
" \"sim\": {\n",
" \"source\": [\"file\", \"extracted_timeseries.nc\"],\n",
" \"source\": {\"file\": \"extracted_timeseries.nc\"},\n",
" \"coords\": {\n",
" \"s\": \"station\",\n",
" \"t\": \"time\"\n",
" }\n",
" },\n",
" \"obs\": {\n",
" \"source\": [\"file\", \"observations.nc\"],\n",
" \"source\": {\"file\": \"observations.nc\"},\n",
" \"coords\": {\n",
" \"s\": \"station\",\n",
" \"t\": \"time\"\n",
@@ -49,7 +49,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "hat",
"language": "python",
"name": "python3"
},
2 changes: 1 addition & 1 deletion notebooks/workflow/timeseries_extraction.ipynb
@@ -35,7 +35,7 @@
" \"name\": \"station_id\"\n",
" },\n",
" \"grid\": {\n",
" \"source\": [\"file\", \"./sim.nc\"],\n",
" \"source\": {\"file\": \"./sim.nc\"},\n",
" \"coords\": {\n",
" \"x\": \"lat\",\n",
" \"y\": \"lon\",\n",