diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 8449b053b..1cdd35c63 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -92,14 +92,6 @@ def __init__( self.batch_size = get_batch_size_from_config(mode_cfg) self.len_timedelta: np.timedelta64 = mode_cfg.time_window_len self.step_timedelta: np.timedelta64 = mode_cfg.time_window_step - self.time_window_handler = TimeWindowHandler( - mode_cfg.start_date, mode_cfg.end_date, self.len_timedelta, self.step_timedelta - ) - if is_root(): - logger.info(self.time_window_handler) - - index_range = self.time_window_handler.get_index_range() - perms_len = int(index_range.end - index_range.start) # Handle forecast_delta_hrs which might be int (hours) or string (timedelta) self.forecast_cfg = mode_cfg.get("forecast", {}) @@ -125,6 +117,28 @@ def __init__( fsm = self.list_num_forecast_steps[0] forecast_len = (self.time_step * (fsm + 1)) // self.step_timedelta + + # Handle too short time window causing duplication of samples + available_samples = (mode_cfg.end_date - mode_cfg.start_date) // self.time_step + if self.samples_per_mini_epoch >= available_samples: + # padding to widen time window + samples_diff = ( + self.samples_per_mini_epoch - available_samples + forecast_len + 1 + ) # extra step to be safe + new_end_date = mode_cfg.end_date + (samples_diff * self.time_step) + logger.warning(f"Using adjusted end date {new_end_date} instead of {mode_cfg.end_date}") + self.time_window_handler = TimeWindowHandler( + mode_cfg.start_date, new_end_date, self.len_timedelta, self.step_timedelta + ) + else: + self.time_window_handler = TimeWindowHandler( + mode_cfg.start_date, mode_cfg.end_date, self.len_timedelta, self.step_timedelta + ) + if is_root(): + logger.info(self.time_window_handler) + + index_range = self.time_window_handler.get_index_range() + perms_len = int(index_range.end - index_range.start) perms_len = perms_len - (forecast_len + self.output_offset) self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False)