diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 94d1d0c08..a1f4d3834 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -60,6 +60,20 @@ def write_output( preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] targets_lens[-1] += [[]] + # for the case where no target channels are defined + # for a stream, e.g., when only using it as input + # we produce None preds and we need to handle this case + if preds is None or targets is None: + num_channels = len(stream_info.get("val_target_channels", [])) + ens_size = int(stream_info.get("pred_head", {}).get("ens_size", 1)) + + preds_all[-1] += [np.zeros((ens_size, 0, num_channels), dtype=np.float32)] + targets_all[-1] += [np.zeros((0, num_channels), dtype=np.float32)] + targets_coords_all[-1] += [np.zeros((0, 2), dtype=np.float32)] + targets_times_all[-1] += [np.array([], dtype="datetime64[ns]")] + targets_lens[-1][-1] = [0 for _ in range(batch_size)] + continue + for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): pred, target = pred.to(fp32), target.to(fp32)