From 6c26d0ff8f30c37a617d1b84938f925da098e63d Mon Sep 17 00:00:00 2001 From: Sorcha Owens Date: Thu, 29 Jan 2026 17:13:51 +0100 Subject: [PATCH] adding padding to end_date to avoid duplicate samples --- .../datasets/multi_stream_data_sampler.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 8449b053b..1cdd35c63 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -92,14 +92,6 @@ def __init__( self.batch_size = get_batch_size_from_config(mode_cfg) self.len_timedelta: np.timedelta64 = mode_cfg.time_window_len self.step_timedelta: np.timedelta64 = mode_cfg.time_window_step - self.time_window_handler = TimeWindowHandler( - mode_cfg.start_date, mode_cfg.end_date, self.len_timedelta, self.step_timedelta - ) - if is_root(): - logger.info(self.time_window_handler) - - index_range = self.time_window_handler.get_index_range() - perms_len = int(index_range.end - index_range.start) # Handle forecast_delta_hrs which might be int (hours) or string (timedelta) self.forecast_cfg = mode_cfg.get("forecast", {}) @@ -125,6 +117,28 @@ def __init__( fsm = self.list_num_forecast_steps[0] forecast_len = (self.time_step * (fsm + 1)) // self.step_timedelta + + # Handle too short time window causing duplication of samples + available_samples = (mode_cfg.end_date - mode_cfg.start_date) // self.time_step + if self.samples_per_mini_epoch >= available_samples: + # padding to widen time window + samples_diff = ( + self.samples_per_mini_epoch - available_samples + forecast_len + 1 + ) # extra step to be safe + new_end_date = mode_cfg.end_date + (samples_diff * self.time_step) + logger.warning(f"Using adjusted end date {new_end_date} instead of {mode_cfg.end_date}") + self.time_window_handler = TimeWindowHandler( + mode_cfg.start_date, new_end_date, self.len_timedelta, self.step_timedelta + ) + else: + self.time_window_handler = TimeWindowHandler( + mode_cfg.start_date, mode_cfg.end_date, self.len_timedelta, self.step_timedelta + ) + if is_root(): + logger.info(self.time_window_handler) + + index_range = self.time_window_handler.get_index_range() + perms_len = int(index_range.end - index_range.start) perms_len = perms_len - (forecast_len + self.output_offset) self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False)