diff --git a/README.md b/README.md
index 7585587..6a420fa 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,18 @@ pip install -e .[dev]
 pre-commit install
 ```
+HAT provides **experimental** support for earthkit-data's [gribjump source](https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#gribjump).
+To install the gribjump extras for testing and experimentation, run:
+```bash
+pip install hydro-analysis-toolkit[gribjump]
+```
+
+> [!NOTE]
+> The gribjump feature is experimental. It is not recommended for production use and may change or break in future releases.
+> Information on how to build gribjump can be found in [GribJump's source code](https://github.com/ecmwf/gribjump/). Experimental
+> wheels of `gribjumplib` can also be found [on PyPI](https://pypi.org/project/gribjumplib/).
+
+
 
 ## Licence
 
 ```
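The gribjump path added to `hat/extract_timeseries/extractor.py` below reads the grid through a MARS-style request and expects the station file to carry a precomputed 1-D grid index (see `_process_gribjump` and `parse_stations`). A minimal sketch of an extraction config for that path; the request values and file/column names are placeholders, not a tested setup:

```python
# Illustrative only: extraction config for the experimental gribjump source.
# The MARS request keys/values and the file/column names are placeholders.
config = {
    "station": {
        "file": "stations.csv",        # must contain the 1-D index column
        "name": "station_id",
        "index_1d": "opt_1d_index",    # e.g. the column written by the station mapper below
    },
    "grid": {
        "source": {
            "gribjump": {
                "request": {"class": "od", "stream": "oper", "expver": "0001"},
            }
        }
    },
}
```

With a config of this shape, `extractor(config)` returns an `xarray.Dataset` indexed by `station`, just like the regular-grid path.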
diff --git a/hat/compute_hydrostats/stat_calc.py b/hat/compute_hydrostats/stat_calc.py
index fd147ff..4da6a15 100644
--- a/hat/compute_hydrostats/stat_calc.py
+++ b/hat/compute_hydrostats/stat_calc.py
@@ -1,17 +1,9 @@
-import earthkit.data as ekd
-from earthkit.hydro._readers import find_main_var
+from hat.core import load_da
 import numpy as np
 import xarray as xr
 from hat.compute_hydrostats import stats
 
 
-def load_da(ds_config):
-    ds = ekd.from_source(*ds_config["source"]).to_xarray()
-    var_name = find_main_var(ds, 2)
-    da = ds[var_name]
-    return da
-
-
 def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords):
     sim_station_colname = sim_coords.get("s", "station")
     obs_station_colname = obs_coords.get("s", "station")
@@ -35,9 +27,9 @@ def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords):
 
 def stat_calc(config):
     sim_config = config["sim"]
-    sim_da = load_da(config["sim"])
+    sim_da, _ = load_da(sim_config, 2)
     obs_config = config["obs"]
-    obs_da = load_da(obs_config)
+    obs_da, _ = load_da(obs_config, 2)
     new_coords = config["output"]["coords"]
     sim_da, obs_da = find_valid_subset(sim_da, obs_da, sim_config["coords"], obs_config["coords"], new_coords)
     stat_dict = {}
diff --git a/hat/core.py b/hat/core.py
new file mode 100644
index 0000000..12213a5
--- /dev/null
+++ b/hat/core.py
@@ -0,0 +1,17 @@
+import earthkit.data as ekd
+from earthkit.hydro._readers import find_main_var
+
+
+def load_da(ds_config, n_dims):
+    src_name = list(ds_config["source"].keys())[0]
+    src_args = ds_config["source"][src_name]
+    # Accept both config forms used in the workflows: a mapping of keyword
+    # arguments for the source, or a single positional value, e.g. {"file": "sim.nc"}.
+    if isinstance(src_args, dict):
+        source = ekd.from_source(src_name, **src_args)
+    else:
+        source = ekd.from_source(src_name, src_args)
+    ds = source.to_xarray(**ds_config.get("to_xarray_options", {}))
+    var_name = find_main_var(ds, n_dims)
+    da = ds[var_name]
+    return da, var_name
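Both `stat_calc` and the extractor now go through this shared helper. A minimal usage sketch, assuming a NetCDF file whose main variable has two dimensions (the path is a placeholder):

```python
from hat.core import load_da

ds_config = {
    # single-key mapping: earthkit-data source name -> its arguments
    "source": {"file": "extracted_timeseries.nc"},  # placeholder path
    "to_xarray_options": {},                        # forwarded to to_xarray()
}

# n_dims is forwarded to find_main_var to select the main data variable
da, var_name = load_da(ds_config, n_dims=2)
```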
diff --git a/hat/extract_timeseries/extractor.py b/hat/extract_timeseries/extractor.py
index 6e557bb..7b76a2b 100644
--- a/hat/extract_timeseries/extractor.py
+++ b/hat/extract_timeseries/extractor.py
@@ -2,111 +2,200 @@ import pandas as pd
 import xarray as xr
 import numpy as np
-import earthkit.data as ekd
-from earthkit.hydro._readers import find_main_var
+from typing import Any
+from hat.core import load_da
 from hat import _LOGGER as logger
 
 
 def process_grid_inputs(grid_config):
-    src_name = list(grid_config["source"].keys())[0]
-    logger.info(f"Processing grid inputs from source: {src_name}")
-    logger.debug(f"Grid config: {grid_config['source'][src_name]}")
-    ds = ekd.from_source(src_name, **grid_config["source"][src_name]).to_xarray(
-        **grid_config.get("to_xarray_options", {})
-    )
-    var_name = find_main_var(ds, 3)
-    da = ds[var_name]
+    da, var_name = load_da(grid_config, 3)
     logger.info(f"Xarray created from source:\n{da}\n")
-    gridx_colname = grid_config.get("coord_x", "lat")
-    gridy_colname = grid_config.get("coord_y", "lon")
-    da = da.sortby([gridx_colname, gridy_colname])
-    shape = da[gridx_colname].shape[0], da[gridy_colname].shape[0]
-    return da, var_name, gridx_colname, gridy_colname, shape
+    coord_config = grid_config.get("coords", {})
+    x_dim = coord_config.get("x", "lat")
+    y_dim = coord_config.get("y", "lon")
+    da = da.sortby([x_dim, y_dim])
+    shape = da[x_dim].shape[0], da[y_dim].shape[0]
+    return da, var_name, x_dim, y_dim, shape
 
 
-def construct_mask(indx, indy, shape):
+def construct_mask(x_indices, y_indices, shape):
     mask = np.zeros(shape, dtype=bool)
-    mask[indx, indy] = True
+    mask[x_indices, y_indices] = True
 
-    flat_indices = np.ravel_multi_index((indx, indy), shape)
-    _, inverse = np.unique(flat_indices, return_inverse=True)
-    return mask, inverse
+    flat_indices = np.ravel_multi_index((x_indices, y_indices), shape)
+    _, duplication_indexes = np.unique(flat_indices, return_inverse=True)
+    return mask, duplication_indexes
 
 
-def create_mask_from_index(index_config, df, shape):
-    logger.info(f"Creating mask {shape} from index: {index_config}")
+def create_mask_from_index(df, shape):
+    logger.info(f"Creating mask {shape} from index")
     logger.debug(f"DataFrame columns: {df.columns.tolist()}")
-    indx_colname = index_config.get("x", "opt_x_index")
-    indy_colname = index_config.get("y", "opt_y_index")
-    indx, indy = df[indx_colname].values, df[indy_colname].values
-    mask, duplication_indexes = construct_mask(indx, indy, shape)
+    x_indices = df["x_index"].values
+    y_indices = df["y_index"].values
+    if np.any(x_indices < 0) or np.any(x_indices >= shape[0]) or np.any(y_indices < 0) or np.any(y_indices >= shape[1]):
+        raise ValueError(
+            f"Station indices out of grid bounds. Grid shape={shape}, "
+            f"x_index range=({int(x_indices.min())},{int(x_indices.max())}), "
+            f"y_index range=({int(y_indices.min())},{int(y_indices.max())})"
+        )
+    mask, duplication_indexes = construct_mask(x_indices, y_indices, shape)
     return mask, duplication_indexes
 
 
-def create_mask_from_coords(coords_config, df, gridx, gridy, shape):
-    logger.info(f"Creating mask {shape} from coordinates: {coords_config}")
+def create_mask_from_coords(df, gridx, gridy, shape):
+    logger.info(f"Creating mask {shape} from coordinates")
     logger.debug(f"DataFrame columns: {df.columns.tolist()}")
-    x_colname = coords_config.get("x", "opt_x_coord")
-    y_colname = coords_config.get("y", "opt_y_coord")
-    xs = df[x_colname].values
-    ys = df[y_colname].values
+    station_x = df["x_coord"].values
+    station_y = df["y_coord"].values
 
-    diffx = np.abs(xs[:, np.newaxis] - gridx)
-    indx = np.argmin(diffx, axis=1)
-    diffy = np.abs(ys[:, np.newaxis] - gridy)
-    indy = np.argmin(diffy, axis=1)
+    x_distances = np.abs(station_x[:, np.newaxis] - gridx)
+    x_indices = np.argmin(x_distances, axis=1)
+    y_distances = np.abs(station_y[:, np.newaxis] - gridy)
+    y_indices = np.argmin(y_distances, axis=1)
 
-    mask, duplication_indexes = construct_mask(indx, indy, shape)
+    mask, duplication_indexes = construct_mask(x_indices, y_indices, shape)
     return mask, duplication_indexes
 
 
-def process_inputs(station_config, grid_config):
+def parse_stations(station_config: dict[str, Any]) -> pd.DataFrame:
+    """Read, filter, and normalize station DataFrame to canonical column names."""
     logger.debug(f"Reading station file, {station_config}")
+    if "name" not in station_config:
+        raise ValueError("Station config must include a 'name' key mapping to the station column")
     df = pd.read_csv(station_config["file"])
     filters = station_config.get("filter")
     if filters is not None:
         logger.debug(f"Applying filters: {filters} to station DataFrame")
         df = df.query(filters)
-    station_names = df[station_config["name"]].values
-    index_config = station_config.get("index", None)
-    coords_config = station_config.get("coords", None)
+    if len(df) == 0:
+        raise ValueError("No stations found. Check station file or filter.")
+
+    has_index = "index" in station_config
+    has_coords = "coords" in station_config
+    has_index_1d = "index_1d" in station_config
+
+    if not has_index_1d:
+        if has_index and has_coords:
+            raise ValueError("Station config must use either 'index' or 'coords', not both.")
+        if not has_index and not has_coords:
+            raise ValueError("Station config must provide either 'index' or 'coords' for station mapping.")
 
-    if index_config is not None and coords_config is not None:
-        raise ValueError("Use either index or coords, not both.")
+    renames = {}
+    renames[station_config["name"]] = "station_name"
 
-    da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config)
+    if has_index:
+        index_config = station_config["index"]
+        x_col = index_config.get("x", "opt_x_index")
+        y_col = index_config.get("y", "opt_y_index")
+        renames[x_col] = "x_index"
+        renames[y_col] = "y_index"
 
-    if index_config is not None:
-        mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
-    elif coords_config is not None:
-        mask, duplication_indexes = create_mask_from_coords(
-            coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape
+    if has_coords:
+        coords_config = station_config["coords"]
+        x_col = coords_config.get("x", "opt_x_coord")
+        y_col = coords_config.get("y", "opt_y_coord")
+        renames[x_col] = "x_coord"
+        renames[y_col] = "y_coord"
+
+    if has_index_1d:
+        renames[station_config["index_1d"]] = "index_1d"
+
+    df_renamed = df.rename(columns=renames)
+
+    if has_index and ("x_index" not in df_renamed.columns or "y_index" not in df_renamed.columns):
+        raise ValueError(
+            "Station file missing required index columns. Expected columns to map to 'x_index' and 'y_index'."
+        )
+    if has_coords and ("x_coord" not in df_renamed.columns or "y_coord" not in df_renamed.columns):
+        raise ValueError(
+            "Station file missing required coordinate columns. Expected columns to map to 'x_coord' and 'y_coord'."
         )
+    if has_index_1d and "index_1d" not in df_renamed.columns:
+        raise ValueError("Station file missing required 'index_1d' column.")
+
+    return df_renamed
+
+
+def _process_gribjump(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
+    if "index_1d" not in df.columns:
+        raise ValueError("Gribjump source requires 'index_1d' in station config.")
+
+    station_names = df["station_name"].values
+    unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True)  # type: ignore[call-overload]
+
+    # Converting indices to ranges is currently faster than using indices
+    # directly. This is a problem in the earthkit-data gribjump source and will
+    # be fixed there.
+    ranges = [(i, i + 1) for i in unique_indices]
+
+    gribjump_config = {
+        "source": {
+            "gribjump": {
+                **grid_config["source"]["gribjump"],
+                "ranges": ranges,
+                # fetch_coords_from_fdb is currently very slow. Needs fix in
+                # earthkit-data gribjump source.
+                # "fetch_coords_from_fdb": True,
+            }
+        },
+        "to_xarray_options": grid_config.get("to_xarray_options", {}),
+    }
+
+    masked_da, var_name = load_da(gribjump_config, 2)
+
+    ds = xr.Dataset({var_name: masked_da})
+    ds = ds.isel(index=duplication_indexes)
+    ds = ds.rename({"index": "station"})
+    ds["station"] = station_names
+    return ds
+
+
+def _process_regular(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
+    station_names = df["station_name"].values
+    da, var_name, x_dim, y_dim, shape = process_grid_inputs(grid_config)
+
+    use_index = "x_index" in df.columns and "y_index" in df.columns
+
+    if use_index:
+        mask, duplication_indexes = create_mask_from_index(df, shape)
     else:
-        # default to index approach
-        mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
+        mask, duplication_indexes = create_mask_from_coords(df, da[x_dim].values, da[y_dim].values, shape)
+
+    logger.info("Extracting timeseries at selected stations")
+    masked_da = apply_mask(da, mask, x_dim, y_dim)
+
+    ds = xr.Dataset({var_name: masked_da})
+    ds = ds.isel(index=duplication_indexes)
+    ds = ds.rename({"index": "station"})
+    ds["station"] = station_names
+    return ds
 
-    return da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes
 
+def process_inputs(station_config: dict[str, Any], grid_config: dict[str, Any]) -> xr.Dataset:
+    df = parse_stations(station_config)
+    if "gribjump" in grid_config.get("source", {}):
+        return _process_gribjump(grid_config, df)
+    return _process_regular(grid_config, df)
 
-def mask_array_np(arr, mask):
+
+def mask_array_np(arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
     return arr[..., mask]
 
 
-def apply_mask(da, mask, coordx, coordy):
+def apply_mask(da: xr.DataArray, mask: np.ndarray, coordx: str, coordy: str) -> xr.DataArray:
     task = xr.apply_ufunc(
         mask_array_np,
         da,
         mask,
         input_core_dims=[(coordx, coordy), (coordx, coordy)],
-        output_core_dims=[["station"]],
+        output_core_dims=[["index"]],
         output_dtypes=[da.dtype],
         exclude_dims={coordx, coordy},
         dask="parallelized",
         dask_gufunc_kwargs={
-            "output_sizes": {"station": int(mask.sum())},
+            "output_sizes": {"index": int(mask.sum())},
             "allow_rechunk": True,
         },
     )
@@ -114,15 +203,8 @@ def apply_mask(da, mask, coordx, coordy):
     return task.compute()
 
 
-def extractor(config):
-    da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes = process_inputs(
-        config["station"], config["grid"]
-    )
-    logger.info("Extracting timeseries at selected stations")
-    masked_da = apply_mask(da, mask, gridx_colname, gridy_colname)
-    ds = xr.Dataset({da_varname: masked_da})
-    ds = ds.isel(station=duplication_indexes)
-    ds["station"] = station_names
+def extractor(config: dict[str, Any]) -> xr.Dataset:
+    ds = process_inputs(config["station"], config["grid"])
     if config.get("output", None) is not None:
         logger.info(f"Saving output to {config['output']['file']}")
         ds.to_netcdf(config["output"]["file"])
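To make the mask/`duplication_indexes` mechanism used by `construct_mask`, `apply_mask`, and `_process_regular` above concrete, here is a small self-contained sketch with made-up indices, showing how two stations that fall on the same grid cell are extracted once and then re-expanded:

```python
import numpy as np

shape = (4, 5)                                   # (n_x, n_y) grid
x_indices = np.array([1, 2, 1])                  # three stations; first and last share a cell
y_indices = np.array([2, 3, 2])

mask = np.zeros(shape, dtype=bool)
mask[x_indices, y_indices] = True                # only two cells end up True

flat = np.ravel_multi_index((x_indices, y_indices), shape)      # [ 7 13  7]
_, duplication_indexes = np.unique(flat, return_inverse=True)   # [0 1 0]

grid = np.arange(20).reshape(shape)
per_cell = grid[mask]                            # unique cells, in sorted flat-index order
per_station = per_cell[duplication_indexes]      # one value per station again
print(per_station)                               # [ 7 13  7]
```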
df["near_x_index"] = cindx @@ -55,6 +55,7 @@ def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_ df["opt_error"] = errors df["opt_x_coord"] = grid_area_coords1[indx, 0] df["opt_y_coord"] = grid_area_coords2[0, indy] + df["opt_1d_index"] = indy + shape[1] * indx if filename is not None: df.to_csv(filename, index=False) return df @@ -109,6 +110,7 @@ def mapper(config): *mapping_outputs, grid_area_coords1, grid_area_coords2, + shape=grid_area_coords1.shape, filename=config["output"]["file"] if config.get("output", None) is not None else None, ) generate_summary_plots(df, config.get("plot", None)) diff --git a/notebooks/workflow/hydrostats_computation.ipynb b/notebooks/workflow/hydrostats_computation.ipynb index 2292595..bb82958 100644 --- a/notebooks/workflow/hydrostats_computation.ipynb +++ b/notebooks/workflow/hydrostats_computation.ipynb @@ -19,14 +19,14 @@ "source": [ "config = {\n", " \"sim\": {\n", - " \"source\": [\"file\", \"extracted_timeseries.nc\"],\n", + " \"source\": {\"file\": \"extracted_timeseries.nc\"},\n", " \"coords\": {\n", " \"s\": \"station\",\n", " \"t\": \"time\"\n", " }\n", " },\n", " \"obs\": {\n", - " \"source\": [\"file\", \"observations.nc\"],\n", + " \"source\": {\"file\": \"observations.nc\"},\n", " \"coords\": {\n", " \"s\": \"station\",\n", " \"t\": \"time\"\n", @@ -49,7 +49,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "hat", "language": "python", "name": "python3" }, diff --git a/notebooks/workflow/timeseries_extraction.ipynb b/notebooks/workflow/timeseries_extraction.ipynb index 21136f8..d5de09e 100644 --- a/notebooks/workflow/timeseries_extraction.ipynb +++ b/notebooks/workflow/timeseries_extraction.ipynb @@ -35,7 +35,7 @@ " \"name\": \"station_id\"\n", " },\n", " \"grid\": {\n", - " \"source\": [\"file\", \"./sim.nc\"],\n", + " \"source\": {\"file\": \"./sim.nc\"},\n", " \"coords\": {\n", " \"x\": \"lat\",\n", " \"y\": \"lon\",\n", diff --git a/pyproject.toml b/pyproject.toml index 490e424..5094c94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,8 +41,8 @@ dependencies = [ "tqdm", "ipyleaflet", "ipywidgets", - "earthkit-data>=0.13.8", - "earthkit-hydro", + "earthkit-data>=0.18.2", + "earthkit-hydro>=1.0.0", "earthkit-meteo", "cfgrib", # check if necessary "netCDF4", # check if necessary @@ -61,13 +61,18 @@ dependencies = [ [project.optional-dependencies] test = [ - "pytest" + "pytest", + "pytest-cov" ] dev = [ "pytest", + "pytest-cov", "ruff", "pre-commit" ] + gribjump = [ + "earthkit-data[gribjump]" + ] [project.scripts] hat-extract-timeseries = "hat.cli:extractor_cli" @@ -97,6 +102,11 @@ addopts = "--pdbcls=IPython.terminal.debugger:Pdb" testpaths = [ "tests", ] +filterwarnings = [ + # Probably harmless numpy 2.x ABI warning from netCDF4's Cython extension + # See: https://github.com/Unidata/netcdf4-python/issues/1354 + "ignore:numpy.ndarray size changed:RuntimeWarning", +] # Packaging/setuptools options [tool.setuptools] diff --git a/tests/test_extractor.py b/tests/test_extractor.py new file mode 100644 index 0000000..41ed39d --- /dev/null +++ b/tests/test_extractor.py @@ -0,0 +1,288 @@ +"""Unit tests for the extractor function.""" + +import pytest +import pandas as pd +import numpy as np +import xarray as xr +from unittest.mock import Mock, patch + +from hat.extract_timeseries.extractor import extractor + + +@pytest.fixture +def dummy_grid_data(): + """4x5 grid, 2 timesteps, temperature variable.""" + lats = np.array([40.0, 41.0, 42.0, 43.0]) + lons = np.array([10.0, 11.0, 
diff --git a/notebooks/workflow/hydrostats_computation.ipynb b/notebooks/workflow/hydrostats_computation.ipynb
index 2292595..bb82958 100644
--- a/notebooks/workflow/hydrostats_computation.ipynb
+++ b/notebooks/workflow/hydrostats_computation.ipynb
@@ -19,14 +19,14 @@
    "source": [
     "config = {\n",
     "    \"sim\": {\n",
-    "        \"source\": [\"file\", \"extracted_timeseries.nc\"],\n",
+    "        \"source\": {\"file\": \"extracted_timeseries.nc\"},\n",
     "        \"coords\": {\n",
     "            \"s\": \"station\",\n",
     "            \"t\": \"time\"\n",
     "        }\n",
     "    },\n",
     "    \"obs\": {\n",
-    "        \"source\": [\"file\", \"observations.nc\"],\n",
+    "        \"source\": {\"file\": \"observations.nc\"},\n",
     "        \"coords\": {\n",
     "            \"s\": \"station\",\n",
     "            \"t\": \"time\"\n",
@@ -49,7 +49,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "hat",
    "language": "python",
    "name": "python3"
   },
diff --git a/notebooks/workflow/timeseries_extraction.ipynb b/notebooks/workflow/timeseries_extraction.ipynb
index 21136f8..d5de09e 100644
--- a/notebooks/workflow/timeseries_extraction.ipynb
+++ b/notebooks/workflow/timeseries_extraction.ipynb
@@ -35,7 +35,7 @@
    "        \"name\": \"station_id\"\n",
    "    },\n",
    "    \"grid\": {\n",
-    "        \"source\": [\"file\", \"./sim.nc\"],\n",
+    "        \"source\": {\"file\": \"./sim.nc\"},\n",
    "        \"coords\": {\n",
    "            \"x\": \"lat\",\n",
    "            \"y\": \"lon\",\n",
diff --git a/pyproject.toml b/pyproject.toml
index 490e424..5094c94 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,8 +41,8 @@ dependencies = [
     "tqdm",
     "ipyleaflet",
     "ipywidgets",
-    "earthkit-data>=0.13.8",
-    "earthkit-hydro",
+    "earthkit-data>=0.18.2",
+    "earthkit-hydro>=1.0.0",
     "earthkit-meteo",
     "cfgrib",  # check if necessary
     "netCDF4",  # check if necessary
@@ -61,13 +61,18 @@ dependencies = [
 
 [project.optional-dependencies]
 test = [
-    "pytest"
+    "pytest",
+    "pytest-cov"
 ]
 dev = [
     "pytest",
+    "pytest-cov",
     "ruff",
     "pre-commit"
 ]
+gribjump = [
+    "earthkit-data[gribjump]"
+]
 
 [project.scripts]
 hat-extract-timeseries = "hat.cli:extractor_cli"
@@ -97,6 +102,11 @@ addopts = "--pdbcls=IPython.terminal.debugger:Pdb"
 testpaths = [
     "tests",
 ]
+filterwarnings = [
+    # Probably harmless numpy 2.x ABI warning from netCDF4's Cython extension
+    # See: https://github.com/Unidata/netcdf4-python/issues/1354
+    "ignore:numpy.ndarray size changed:RuntimeWarning",
+]
 
 # Packaging/setuptools options
 [tool.setuptools]
diff --git a/tests/test_extractor.py b/tests/test_extractor.py
new file mode 100644
index 0000000..41ed39d
--- /dev/null
+++ b/tests/test_extractor.py
@@ -0,0 +1,288 @@
+"""Unit tests for the extractor function."""
+
+import pytest
+import pandas as pd
+import numpy as np
+import xarray as xr
+from unittest.mock import Mock, patch
+
+from hat.extract_timeseries.extractor import extractor
+
+
+@pytest.fixture
+def dummy_grid_data():
+    """4x5 grid, 2 timesteps, temperature variable."""
+    lats = np.array([40.0, 41.0, 42.0, 43.0])
+    lons = np.array([10.0, 11.0, 12.0, 13.0, 14.0])
+
+    temperature_values = np.array(
+        [
+            # t=0: 2024-01-01
+            [
+                [10.0, 11.0, 12.0, 13.0, 14.0],  # lat=40
+                [15.0, 16.0, 17.0, 18.0, 19.0],  # lat=41
+                [20.0, 21.0, 22.0, 23.0, 24.0],  # lat=42
+                [25.0, 26.0, 27.0, 28.0, 29.0],  # lat=43
+            ],
+            # t=1: 2024-01-02
+            [
+                [30.0, 31.0, 32.0, 33.0, 34.0],  # lat=40
+                [35.0, 36.0, 37.0, 38.0, 39.0],  # lat=41
+                [40.0, 41.0, 42.0, 43.0, 44.0],  # lat=42
+                [45.0, 46.0, 47.0, 48.0, 49.0],  # lat=43
+            ],
+        ]
+    )
+
+    list_of_dicts = []
+
+    list_of_dicts.append(
+        {
+            "values": temperature_values[0].flatten(),
+            "param": "temperature",
+            "date": 20240101,
+            "time": 0,
+            "distinctLatitudes": lats.tolist(),
+            "distinctLongitudes": lons.tolist(),
+        }
+    )
+
+    list_of_dicts.append(
+        {
+            "values": temperature_values[1].flatten(),
+            "param": "temperature",
+            "date": 20240102,
+            "time": 0,
+            "distinctLatitudes": lats.tolist(),
+            "distinctLongitudes": lons.tolist(),
+        }
+    )
+
+    return list_of_dicts
+
+
+@pytest.fixture
+def station_dataframe():
+    """2 stations with both index and coordinate columns."""
+    return pd.DataFrame(
+        {
+            "station_id": ["STATION_A", "STATION_B"],
+            "opt_x_index": [1, 2],
+            "opt_y_index": [2, 3],
+            "opt_x_coord": [41.1, 41.9],  # offset to test nearest-neighbor
+            "opt_y_coord": [12.2, 13.1],
+        }
+    )
+
+
+@pytest.fixture
+def station_csv_file(station_dataframe, tmp_path):
+    """Write station DataFrame to temporary CSV."""
+    csv_path = tmp_path / "stations.csv"
+    station_dataframe.to_csv(csv_path, index=False)
+    return str(csv_path)
+
+
+@pytest.mark.parametrize(
+    "mapping_config",
+    [
+        {"index": {"x": "opt_x_index", "y": "opt_y_index"}},
+        {"coords": {"x": "opt_x_coord", "y": "opt_y_coord"}},
+    ],
+    ids=["index", "coords"],
+)
+def test_extractor_with_temperature(dummy_grid_data, station_csv_file, mapping_config):
+    """Test extraction with both index and coords mapping."""
+    config = {
+        "station": {
+            "file": station_csv_file,
+            "name": "station_id",
+            **mapping_config,
+        },
+        "grid": {
+            "source": {
+                "list-of-dicts": {
+                    "list_of_dicts": dummy_grid_data,
+                }
+            },
+            "coords": {
+                "x": "latitude",
+                "y": "longitude",
+            },
+        },
+    }
+
+    result_ds = extractor(config)
+
+    assert isinstance(result_ds, xr.Dataset)
+    assert "temperature" in result_ds.data_vars
+    assert "station" in result_ds.dims
+    assert len(result_ds.station) == 2
+    assert list(result_ds.station.values) == ["STATION_A", "STATION_B"]
+
+    time_dim = "time" if "time" in result_ds.dims else "forecast_reference_time"
+    assert len(result_ds[time_dim]) == 2
+
+    # Station A: indices [1,2] -> values [17.0, 37.0]
+    # Station B: indices [2,3] -> values [23.0, 43.0]
+    np.testing.assert_allclose(result_ds["temperature"].sel(station="STATION_A").values, [17.0, 37.0])
+    np.testing.assert_allclose(result_ds["temperature"].sel(station="STATION_B").values, [23.0, 43.0])
+
+
+def test_extractor_with_station_filter(dummy_grid_data, tmp_path):
+    """Test station filtering."""
+    df = pd.DataFrame(
+        {
+            "station_id": ["S1", "S2", "S3"],
+            "opt_x_index": [1, 2, 1],
+            "opt_y_index": [2, 3, 3],
+            "network": ["primary", "secondary", "primary"],
+        }
+    )
+    csv_path = tmp_path / "stations.csv"
+    df.to_csv(csv_path, index=False)
+
+    config = {
+        "station": {
+            "file": str(csv_path),
+            "name": "station_id",
+            "filter": "network == 'primary'",
+            "index": {"x": "opt_x_index", "y": "opt_y_index"},
+        },
+        "grid": {
+            "source": {
+                "list-of-dicts": {
+                    "list_of_dicts": dummy_grid_data,
+                }
+            },
+            "coords": {
+                "x": "latitude",
+                "y": "longitude",
+            },
+        },
+    }
+
+    result_ds = extractor(config)
+
+    assert len(result_ds.station) == 2
+    assert list(result_ds.station.values) == ["S1", "S3"]
+    assert result_ds["temperature"].sel(station="S1").values[0] == 17.0
+    assert result_ds["temperature"].sel(station="S3").values[0] == 18.0
+
+
+def test_extractor_rejects_both_index_and_coords(dummy_grid_data, station_csv_file):
+    """Test that providing both index and coords raises ValueError."""
+    config = {
+        "station": {
+            "file": station_csv_file,
+            "name": "station_id",
+            "index": {"x": "opt_x_index", "y": "opt_y_index"},
+            "coords": {"x": "opt_x_coord", "y": "opt_y_coord"},  # Both provided
+        },
+        "grid": {
+            "source": {"list-of-dicts": {"list_of_dicts": dummy_grid_data}},
+            "coords": {"x": "latitude", "y": "longitude"},
+        },
+    }
+
+    with pytest.raises(ValueError, match="must use either 'index' or 'coords', not both"):
+        extractor(config)
+
+
+def test_extractor_with_output_file(dummy_grid_data, station_csv_file, tmp_path):
+    """Test output file writing."""
+    output_file = tmp_path / "output.nc"
+
+    config = {
+        "station": {
+            "file": station_csv_file,
+            "name": "station_id",
+            "index": {"x": "opt_x_index", "y": "opt_y_index"},
+        },
+        "grid": {
+            "source": {
+                "list-of-dicts": {
+                    "list_of_dicts": dummy_grid_data,
+                }
+            },
+            "coords": {
+                "x": "latitude",
+                "y": "longitude",
+            },
+        },
+        "output": {
+            "file": str(output_file),
+        },
+    }
+
+    result_ds = extractor(config)
+
+    assert output_file.exists()
+
+    loaded_ds = xr.open_dataset(output_file)
+    assert "temperature" in loaded_ds.data_vars
+    assert "station" in loaded_ds.dims
+    assert len(loaded_ds.station) == 2
+
+    xr.testing.assert_allclose(result_ds["temperature"], loaded_ds["temperature"])
+    xr.testing.assert_equal(result_ds.station, loaded_ds.station)
+
+    loaded_ds.close()
+
+
+def test_extractor_with_empty_stations(dummy_grid_data, tmp_path):
+    """Test that extractor raises clear error for empty station list."""
+    empty_csv = tmp_path / "empty_stations.csv"
+    pd.DataFrame(columns=["station_id", "opt_x_index", "opt_y_index"]).to_csv(empty_csv, index=False)
+
+    config = {
+        "station": {
+            "file": str(empty_csv),
+            "name": "station_id",
+            "index": {"x": "opt_x_index", "y": "opt_y_index"},
+        },
+        "grid": {
+            "source": {"list-of-dicts": {"list_of_dicts": dummy_grid_data}},
+            "coords": {"x": "latitude", "y": "longitude"},
+        },
+    }
+
+    with pytest.raises(ValueError, match="No stations found"):
+        extractor(config)
+
+
+@patch("earthkit.data.from_source")
+def test_extractor_gribjump(mock_from_source, tmp_path):
+    """Test gribjump path: verifies ranges computation and earthkit call."""
+
+    # Mock returns object with to_xarray() that returns minimal dataset
+    mock_source = Mock()
+    mock_source.to_xarray.return_value = xr.Dataset(
+        {"temperature": xr.DataArray([[15.0, 25.0], [35.0, 45.0]], dims=["index", "time"])}
+    )
+    mock_from_source.return_value = mock_source
+
+    # Station CSV with index_1d (includes duplicate to test deduplication)
+    csv_file = tmp_path / "stations.csv"
+    pd.DataFrame(
+        {
+            "name": ["S1", "S2", "S3"],
+            "idx": [100, 200, 100],  # S1 and S3 share index 100
+        }
+    ).to_csv(csv_file, index=False)
+
+    config = {
+        "station": {"file": str(csv_file), "name": "name", "index_1d": "idx"},
+        "grid": {"source": {"gribjump": {"request": {"class": "od", "expver": "0001", "stream": "oper"}}}},
+    }
+
+    result = extractor(config)
+
+    # Verify earthkit.data.from_source was called correctly
+    mock_from_source.assert_called_once_with(
+        "gribjump", request={"class": "od", "expver": "0001", "stream": "oper"}, ranges=[(100, 101), (200, 201)]
+    )
+
+    # Verify output
+    assert len(result.station) == 3
+    assert list(result.station.values) == ["S1", "S2", "S3"]