Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,6 @@ def __init__(
self.batch_size = get_batch_size_from_config(mode_cfg)
self.len_timedelta: np.timedelta64 = mode_cfg.time_window_len
self.step_timedelta: np.timedelta64 = mode_cfg.time_window_step
self.time_window_handler = TimeWindowHandler(
mode_cfg.start_date, mode_cfg.end_date, self.len_timedelta, self.step_timedelta
)
if is_root():
logger.info(self.time_window_handler)

index_range = self.time_window_handler.get_index_range()
perms_len = int(index_range.end - index_range.start)

# Handle forecast_delta_hrs which might be int (hours) or string (timedelta)
self.forecast_cfg = mode_cfg.get("forecast", {})
Expand All @@ -125,6 +117,28 @@ def __init__(

fsm = self.list_num_forecast_steps[0]
forecast_len = (self.time_step * (fsm + 1)) // self.step_timedelta

# Handle too short time window causing duplication of samples
available_samples = (mode_cfg.end_date - mode_cfg.start_date) // self.time_step
if self.samples_per_mini_epoch >= available_samples:
# padding to widen time window
samples_diff = (
self.samples_per_mini_epoch - available_samples + forecast_len + 1
) # extra step to be safe
new_end_date = mode_cfg.end_date + (samples_diff * self.time_step)
logger.warning(f"Using adjusted end date {new_end_date} instead of {mode_cfg.end_date}")
self.time_window_handler = TimeWindowHandler(
mode_cfg.start_date, new_end_date, self.len_timedelta, self.step_timedelta
)
else:
self.time_window_handler = TimeWindowHandler(
mode_cfg.start_date, mode_cfg.end_date, self.len_timedelta, self.step_timedelta
)
if is_root():
logger.info(self.time_window_handler)

index_range = self.time_window_handler.get_index_range()
perms_len = int(index_range.end - index_range.start)
perms_len = perms_len - (forecast_len + self.output_offset)

self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False)
Expand Down
Loading