@@ -322,8 +322,8 @@ def process(
         if len(missing_features) > 0:
             msg = (
                 "The features requested in the feature_config are absent from "
-                "the forecast parquet file and the input cubes. The missing fields are: "
-                f"{missing_features}."
+                "the forecast parquet file and the input cubes. "
+                f"The missing fields are: {missing_features}."
             )
             raise ValueError(msg)
         return forecast_df, truth_df, cube_inputs
@@ -421,9 +421,12 @@ def _add_static_features_from_cubes_to_df(
             constr = iris.Constraint(name=feature_name)
             # Static features can be provided either as a cube or as a column
             # in the forecast DataFrame.
-            if not cube_inputs.extract(constr):
+            try:
+                feature_cube = cube_inputs.extract_cube(constr)
+            except iris.exceptions.ConstraintMismatchError:
+                feature_cube = None
+            if not feature_cube:
                 continue
-            feature_cube = cube_inputs.extract_cube(constr)
             feature_df = as_data_frame(feature_cube, add_aux_coords=True)
             forecast_df = forecast_df.merge(
                 feature_df[[*self.unique_site_id_keys, feature_name]],
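The rewritten branch relies on `CubeList.extract_cube` returning exactly one matching cube or raising `iris.exceptions.ConstraintMismatchError`, so a missing cube can be treated as "feature supplied as a DataFrame column instead" without extracting twice. A minimal standalone sketch of the same pattern; the cube list and feature names here are illustrative, not from the PR:

```python
import iris
import iris.exceptions
import numpy as np
from iris.cube import Cube, CubeList

# One static feature is available as a cube; the other is not.
cube_inputs = CubeList([Cube(np.zeros(3), long_name="surface_altitude")])

for feature_name in ["surface_altitude", "distance_to_coast"]:
    constr = iris.Constraint(name=feature_name)
    try:
        # Raises ConstraintMismatchError unless exactly one cube matches.
        feature_cube = cube_inputs.extract_cube(constr)
    except iris.exceptions.ConstraintMismatchError:
        feature_cube = None
    if not feature_cube:
        # Fall back to expecting the feature as a forecast DataFrame column.
        continue
    print(f"{feature_name} supplied as a cube")
```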
@@ -448,14 +451,18 @@ def filter_bad_sites(
             - DataFrame containing the forecast data with bad sites removed.
             - DataFrame containing the truth data with bad sites removed.
         """
-        truth_df.dropna(subset=["ob_value"], inplace=True)
+        truth_df.dropna(subset=["ob_value"] + [*self.unique_site_id_keys], inplace=True)
 
         if truth_df.empty:
             msg = "Empty truth DataFrame after removing NaNs."
             raise ValueError(msg)
 
-        forecast_index = forecast_df.set_index([*self.unique_site_id_keys]).index
-        truth_index = truth_df.set_index([*self.unique_site_id_keys]).index
+        # Include time in the index, so that forecasts will be dropped if they
+        # correspond to a site and time that is not in the truth data.
+        forecast_index = forecast_df.set_index(
+            [*self.unique_site_id_keys] + ["time"]
+        ).index
+        truth_index = truth_df.set_index([*self.unique_site_id_keys] + ["time"]).index
         forecast_df = forecast_df[forecast_index.isin(truth_index)]
         truth_df = truth_df[truth_index.isin(forecast_index)]
 
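Two behavioural changes sit in this hunk: rows with a NaN site identifier are now dropped alongside NaN observations, and the forecast/truth matching is per (site, time) pair rather than per site. A small self-contained illustration of the (site, time) index filtering; the data values are invented:

```python
import pandas as pd

unique_site_id_keys = ["wmo_id"]

forecast_df = pd.DataFrame(
    {
        "wmo_id": ["03772", "03772", "03917"],
        "time": ["20250704T0600Z", "20250704T1200Z", "20250704T0600Z"],
        "forecast": [15.2, 17.8, 14.1],
    }
)
truth_df = pd.DataFrame(
    {
        "wmo_id": ["03772", "03917"],
        "time": ["20250704T0600Z", "20250704T0600Z"],
        "ob_value": [15.0, 13.9],
    }
)

# Build (site, time) MultiIndexes and keep only rows present in both.
forecast_index = forecast_df.set_index(unique_site_id_keys + ["time"]).index
truth_index = truth_df.set_index(unique_site_id_keys + ["time"]).index
forecast_df = forecast_df[forecast_index.isin(truth_index)]
truth_df = truth_df[truth_index.isin(forecast_index)]

# The 1200Z forecast at site 03772 is dropped: no matching truth row exists.
print(forecast_df)
```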
12 changes: 3 additions & 9 deletions improver/calibration/quantile_regression_random_forest.py
@@ -96,14 +96,9 @@ def prep_feature(
 
         # For a subset of the input DataFrame compute the mean or standard deviation
         # over the representation column, grouped by the groupby columns.
-        if feature_name == "mean":
-            subset_df = df[subset_cols].groupby(groupby_cols).mean()
-        elif feature_name == "std":
-            subset_df = df[subset_cols].groupby(groupby_cols).std()
-        elif feature_name == "min":
-            subset_df = df[subset_cols].groupby(groupby_cols).min()
-        elif feature_name == "max":
-            subset_df = df[subset_cols].groupby(groupby_cols).max()
+        if feature_name in ["min", "max", "mean", "std"]:
+            subset_grouped = df[subset_cols].groupby(groupby_cols)
+            subset_df = getattr(subset_grouped, feature_name)()
         elif feature_name.startswith("members_below"):
             threshold = float(feature_name.split("_")[2])
             if transformation is not None:
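The collapsed branch works because pandas exposes `min`, `max`, `mean` and `std` as methods on the `GroupBy` object, so `getattr(grouped, feature_name)()` dispatches to the requested aggregation without repeating the groupby four times. A standalone sketch with an invented DataFrame:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "wmo_id": ["03772"] * 3 + ["03917"] * 3,
        "forecast": [14.0, 15.0, 16.0, 11.0, 12.0, 13.0],
    }
)
subset_cols = ["wmo_id", "forecast"]
groupby_cols = ["wmo_id"]

for feature_name in ["min", "max", "mean", "std"]:
    subset_grouped = df[subset_cols].groupby(groupby_cols)
    # Equivalent to subset_grouped.min(), subset_grouped.max(), and so on.
    subset_df = getattr(subset_grouped, feature_name)()
    print(feature_name, subset_df["forecast"].to_list())
```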
@@ -119,7 +114,6 @@
             )
             subset_df.rename(variable_name, inplace=True)
             subset_df = subset_df.astype(orig_dtype)
-            # subset_df[variable_name].astype(orig_dtype)
         elif feature_name.startswith("members_above"):
             threshold = float(feature_name.split("_")[2])
             if transformation is not None:
9 changes: 5 additions & 4 deletions improver/cli/apply_quantile_regression_random_forest.py
@@ -26,8 +26,11 @@ def process(
     Args:
         file_paths (cli.inputpaths):
            A list of input paths containing:
-            - The path to a QRF trained model in pickle file format to be used
-              for calibration.
+            - The path to the pickle file produced by training the QRF model.
+              The pickle file contains the QRF model and the transformation and
+              pre_transform_addition values if a transformation was applied. If no
+              transformation was applied then the transformation and
+              pre_transform_addition values will be None and 0, respectively.
            - The path to a NetCDF file containing the forecast to be calibrated.
            - Optionally, paths to NetCDF files containing additional predictors.
        feature_config (dict):
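For orientation, this is roughly what consuming that pickle looks like. The three-element layout below is inferred from the docstring, not read from the training code, so treat it as a sketch:

```python
import pickle

# Assumed layout, per the docstring: the QRF model, the transformation,
# and the pre_transform_addition value. transformation is None and
# pre_transform_addition is 0 when no transformation was applied.
with open("qrf_model.pickle", "rb") as f:
    qrf_model, transformation, pre_transform_addition = pickle.load(f)

if transformation is not None:
    print(f"Undo '{transformation}' with offset {pre_transform_addition}")
```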
@@ -50,8 +53,6 @@ def process(
             A string containing the CF name of the forecast to be
             calibrated e.g. air_temperature. This will be used to separate it from
             the rest of the feature cubes, if present.
-            The names of the coordinates that uniquely identify each site,
-            e.g. "wmo_id" or "latitude,longitude".
         unique_site_id_keys (str):
             The names of the coordinates that uniquely identify each site,
             e.g. "wmo_id" or "latitude,longitude".
2 changes: 2 additions & 0 deletions improver/cli/train_quantile_regression_random_forest.py
@@ -153,5 +153,7 @@ def process(
         unique_site_id_keys=unique_site_id_keys,
         **kwargs,
     )(forecast_df, truth_df, cube_inputs)
+    if result == (None, None, None):
+        return None
 
     return result
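Returning None here means the CLI's output step has nothing to write, which the new acceptance test below verifies via `output_path.exists()`. A purely hypothetical sketch of such an output step; `save_result` and its behaviour are illustrative, not improver's actual output handling:

```python
import pickle

def save_result(result, output_path):
    """Hypothetical output step: skip writing when the CLI returned None."""
    if result is None:
        return  # no model was trained, so no pickle is written
    with open(output_path, "wb") as f:
        pickle.dump(result, f)
```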
@@ -114,3 +114,51 @@ def test_missing_inputs(
     ]
 
     assert run_cli(compulsory_args + named_args) is None
+    # Check no file has been written to disk.
+    assert not output_path.exists()
+
+
+def test_invalid_cycletime(
+    tmp_path,
+):
+    """
+    Test train-quantile-regression-random-forest CLI when no training can occur
+    because there is no valid data in the parquet files for the cycletime
+    provided.
+    """
+    kgo_dir = acc.kgo_root() / CLI
+    config_path = kgo_dir / "config.json"
+    output_path = tmp_path / "output.pickle"
+    history_path = kgo_dir / "spot_calibration_tables"
+    truth_path = kgo_dir / "spot_observation_tables"
+    compulsory_args = [history_path, truth_path]
+    named_args = [
+        "--feature-config",
+        config_path,
+        "--parquet-diagnostic-names",
+        "temperature_at_screen_level",
+        "--target-cf-name",
+        "air_temperature",
+        "--forecast-periods",
+        "6:18:6",
+        "--cycletime",
+        "20250704T0000Z",
+        "--training-length",
+        "2",
+        "--experiment",
+        "mix-latestblend",
+        "--n-estimators",
+        "10",
+        "--max-depth",
+        "5",
+        "--random-state",
+        "42",
+        "--compression-level",
+        "5",
+        "--output",
+        output_path,
+    ]
+
+    assert run_cli(compulsory_args + named_args) is None
+    # Check no file has been written to disk.
+    assert not output_path.exists()