From cf80a6c50122088c1b9d49545b6bf969a31f4524 Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Wed, 28 Jan 2026 14:36:39 +0100 Subject: [PATCH 1/2] generate zero tensors rather than None for case with no target channels, eg ft jepa model for synop preds --- src/weathergen/utils/validation_io.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 94d1d0c08..5199861b9 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -60,6 +60,20 @@ def write_output( preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] targets_lens[-1] += [[]] + # for the case where no target channels are defined + # for a stream, e.g., when only using it as input + # we produce None preds and we need to handle this case + if preds is None or targets is None: + num_channels = len(stream_info.get("val_target_channels", [])) + ens_size = int(stream_info.get("pred_head", {}).get("ens_size", 1)) + + preds_all[-1] += [np.zeros((ens_size, 0, num_channels), dtype=np.float32)] + targets_all[-1] += [np.zeros((0, num_channels), dtype=np.float32)] + targets_coords_all[-1] += [np.zeros((0, 2), dtype=np.float32)] + targets_times_all[-1] += [np.array([], dtype="datetime64[ns]")] + targets_lens[-1] += [[0 for _ in range(batch_size)]] + continue + for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): pred, target = pred.to(fp32), target.to(fp32) From 4e161700808a9a07b0d540aba11ea9a9f8a39c46 Mon Sep 17 00:00:00 2001 From: Seb Hickman <56727418+shmh40@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:18:46 +0000 Subject: [PATCH 2/2] Fix targets_lens assignment for batch size --- src/weathergen/utils/validation_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 5199861b9..a1f4d3834 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -71,7 +71,7 @@ def write_output( targets_all[-1] += [np.zeros((0, num_channels), dtype=np.float32)] targets_coords_all[-1] += [np.zeros((0, 2), dtype=np.float32)] targets_times_all[-1] += [np.array([], dtype="datetime64[ns]")] - targets_lens[-1] += [[0 for _ in range(batch_size)]] + targets_lens[-1][-1] = [0 for _ in range(batch_size)] continue for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):