From 6c161b6578b69244d1cf4fd550a76ed91edfb8e3 Mon Sep 17 00:00:00 2001
From: MWesselkamp
Date: Fri, 13 Feb 2026 08:28:26 +0100
Subject: [PATCH 1/2] transfer LST datareader and configs to develop

---
 config/evaluate/eval_config_lst.yml           |  29 ++
 config/lst_config.yml                         | 194 ++++++++++
 config/streams/seviri_lst/era5.yml            |  30 ++
 config/streams/seviri_lst/seviri_lst.yml      |  36 ++
 .../src/weathergen/evaluate/utils/regions.py  |   1 +
 src/weathergen/datasets/data_reader_seviri.py | 344 ++++++++++++++++++
 .../datasets/multi_stream_data_sampler.py     |   3 +
 7 files changed, 637 insertions(+)
 create mode 100644 config/evaluate/eval_config_lst.yml
 create mode 100644 config/lst_config.yml
 create mode 100644 config/streams/seviri_lst/era5.yml
 create mode 100644 config/streams/seviri_lst/seviri_lst.yml
 create mode 100644 src/weathergen/datasets/data_reader_seviri.py

diff --git a/config/evaluate/eval_config_lst.yml b/config/evaluate/eval_config_lst.yml
new file mode 100644
index 000000000..d8790073e
--- /dev/null
+++ b/config/evaluate/eval_config_lst.yml
@@ -0,0 +1,29 @@
+image_format : "png"  # options: "png", "pdf", "svg", "eps", "jpg", ...
+dpi_val : 300
+summary_plots : true
+print_summary: true
+
+evaluation:
+  metrics : ["rmse", "mae"]
+  regions: ["madagaskar"]
+  summary_dir: "./plots/"
+  plot_score_maps: false  # plot scores on 2D maps; slows down the score computation
+  print_summary: false  # print score values on screen; can be verbose
+
+run_ids :
+
+  ndl2qget :  # inference run id
+    label: "One-shot LST prediction"
+    mini_epoch: 0
+    rank: 0
+    streams:
+      SEVIRI_LST:
+        channels: ["LST"]  # e.g. ["2t", "q_850"] for ERA5
+        evaluation:
+          sample: "all"
+          forecast_step: "all"
+        plotting:
+          sample: [0, 1]
+          forecast_step: [1, 2, 3, 4, 5, 6]
+          plot_maps: true
+          plot_histograms: true
\ No newline at end of file
diff --git a/config/lst_config.yml b/config/lst_config.yml
new file mode 100644
index 000000000..4aa16ab0a
--- /dev/null
+++ b/config/lst_config.yml
@@ -0,0 +1,194 @@
+streams_directory: "./config/streams/seviri_lst/"
+
+embed_orientation: "channels"
+embed_local_coords: True
+embed_centroids_local_coords: False
+embed_size_centroids: 0
+embed_unembed_mode: "block"
+embed_dropout_rate: 0.1
+
+target_cell_local_prediction: True
+
+ae_local_dim_embed: 1024
+ae_local_num_blocks: 2
+ae_local_num_heads: 16
+ae_local_dropout_rate: 0.1
+ae_local_with_qk_lnorm: True
+
+ae_local_num_queries: 1
+ae_local_queries_per_cell: False
+ae_adapter_num_heads: 16
+ae_adapter_embed: 128
+ae_adapter_with_qk_lnorm: True
+ae_adapter_with_residual: True
+ae_adapter_dropout_rate: 0.1
+
+ae_global_dim_embed: 2048
+ae_global_num_blocks: 8
+ae_global_num_heads: 32
+ae_global_dropout_rate: 0.1
+ae_global_with_qk_lnorm: True
+# TODO: switching to < 1 triggers triton-related issues.
+# See https://github.com/ecmwf/WeatherGenerator/issues/1050
+ae_global_att_dense_rate: 1.0
+ae_global_block_factor: 64
+ae_global_mlp_hidden_factor: 2
+
+ae_aggregation_num_blocks: 2
+ae_aggregation_num_heads: 32
+ae_aggregation_dropout_rate: 0.1
+ae_aggregation_with_qk_lnorm: True
+ae_aggregation_att_dense_rate: 1.0
+ae_aggregation_block_factor: 64
+ae_aggregation_mlp_hidden_factor: 2
+
+decoder_type: PerceiverIOCoordConditioning  # alternative: CrossAttentionAdaNormConditioning
+pred_adapter_kv: False
+pred_self_attention: True
+pred_dyadic_dims: False
+pred_mlp_adaln: True
+
+# number of steps offset applied to the first target window; if set to zero and
+# forecast_steps=0, then one is training an auto-encoder
+forecast_offset : 0
+forecast_delta_hrs: 0
+forecast_steps: 0
+forecast_policy: null
+forecast_att_dense_rate: 1.0
+fe_num_blocks: 0
+fe_num_heads: 16
+fe_dropout_rate: 0.1
+fe_with_qk_lnorm: True
+impute_latent_noise_std: 0.0  # 1e-4
+
+healpix_level: 5
+
+with_mixed_precision: True
+with_flash_attention: True
+compile_model: False
+with_fsdp: True
+attention_dtype: bf16
+mixed_precision_dtype: bf16
+mlp_norm_eps: 1e-5
+norm_eps: 1e-4
+
+latent_noise_kl_weight: 0.0  # 1e-5
+latent_noise_gamma: 2.0
+latent_noise_saturate_encodings: 5
+latent_noise_use_additive_noise: False
+latent_noise_deterministic_latents: True
+
+batch_size_per_gpu: 1
+batch_size_validation_per_gpu: 1
+
+# a regex that needs to fully match the name of the modules you want to freeze,
+# e.g. ".*ERA5" will match any module whose name ends in ERA5;
+# encoders and decoders that exist per stream have the stream name attached at the end
+freeze_modules: ""
+
+# whether to track the exponential moving average of weights for validation
+validate_with_ema: True
+ema_ramp_up_ratio: 0.09
+ema_halflife_in_thousands: 1e-3
+
+# training mode: "forecast" or "masking" (masked token modeling);
+# for "masking" to train in auto-encoder mode, forecast_offset should be 0
+training_mode: "masking"
+training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]},}
+                      }
+# training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]],
+#                                 LossLatent: [['mse', 0.3]],
+#                                 LossStudentTeacher: [{'iBOT': {}, 'JEPA': {options}}],}
+#                      }
+validation_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]},}
+                        }
+# masking rate when training mode is "masking"; ignored in forecast mode
+masking_rate: 0.6
+# sample the masking rate (from a normal distribution centered at masking_rate);
+# note that a sampled masking rate leads to varying memory requirements
+masking_rate_sampling: True
+# sample a subset of all target points, useful e.g. to reduce memory requirements (can also be specified per stream)
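+# (for instance, the SEVIRI_LST stream overrides this with its own sampling_rate_target
+# in config/streams/seviri_lst/seviri_lst.yml)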
+sampling_rate_target: 1.0
+# masking strategy; currently only "random", "block", "healpix", "channel", "causal" and "combination" are supported
+masking_strategy: "random"
+# masking_strategy_config is a dictionary of additional parameters for the masking strategy,
+# required for the "healpix" and "channel" masking strategies:
+# "healpix": requires the healpix mask level to be specified with `hl_mask`
+# "channel": requires "mode" to be specified, either "per_cell" or "global"
+masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
+                          "probabilities": [0.34, 0.33, 0.33],
+                          "hl_mask": 3, "mode": "per_cell",
+                          "same_strategy_per_batch": false
+                          }
+
+num_mini_epochs: 32
+samples_per_mini_epoch: 4096
+samples_per_validation: 512
+
+shuffle: True
+
+lr_scaling_policy: "sqrt"
+lr_start: 1e-6
+lr_max: 5e-5
+lr_final_decay: 1e-6
+lr_final: 0.0
+lr_steps_warmup: 512
+lr_steps_cooldown: 512
+lr_policy_warmup: "cosine"
+lr_policy_decay: "constant"
+lr_policy_cooldown: "linear"
+
+grad_clip: 1.0
+weight_decay: 0.1
+norm_type: "LayerNorm"
+nn_module: "te"
+log_grad_norms: False
+
+start_date: 197901010000
+end_date: 202012310000
+start_date_val: 201705010000  # 202101010000
+end_date_val: 201706300000  # 202201010000
+len_hrs: 6
+step_hrs: 6
+input_window_steps: 1
+
+val_initial: False
+
+loader_num_workers: 8
+log_validation: 0
+streams_output: ["ERA5"]
+
+istep: 0
+run_history: []
+
+desc: ""
+data_loader_rng_seed: ???
+run_id: ???
+
+# The period to log in the training loop (in number of batch steps)
+train_log_freq:
+  terminal: 10
+  metrics: 20
+  checkpoint: 250
+
+
+# Tags for experiment tracking.
+# These tags will be logged in MLFlow along with completed runs for train, eval, val.
+# The tags are free-form, with the following rules:
+# - tags should be primitive types (strings, numbers, booleans); NO lists or dictionaries
+# - tags should not duplicate existing config entries
+# - try to reuse existing tags where possible; MLFlow does not like having too many unique tags
+# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future)
+wgtags:
+  # The name of the organization of the person running the experiment.
+  # This may be autofilled in the future. Expected values are lowercase strings of
+  # the organizations' codenames in https://confluence.ecmwf.int/display/MAEL/Staff+Contact+List
+  # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
+  org: mpg
+  # The name of the experiment. This is a distinctive codename for the experiment campaign being run.
+  # This is expected to be the primary tag for comparing experiments in MLFlow.
+  # Expected values are lowercase strings with no spaces, just underscores.
+  # Examples: "rollout_ablation_grid"
+  exp: lst_finetune
+  # *** Experiment-specific tags ***
+  grid: v0
\ No newline at end of file
diff --git a/config/streams/seviri_lst/era5.yml b/config/streams/seviri_lst/era5.yml
new file mode 100644
index 000000000..e6bed15ee
--- /dev/null
+++ b/config/streams/seviri_lst/era5.yml
@@ -0,0 +1,30 @@
+ERA5 :
+  type : anemoi
+  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
+  stream_id : 0
+  source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
+  target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
+  loss_weight : 1.
+  location_weight : cosine_latitude
+  masking_rate : 0.6
+  masking_rate_none : 0.05
+  token_size : 8
+  tokenize_spacetime : True
+  max_num_targets: -1
+  forcing: True
+  embed :
+    net : transformer
+    num_tokens : 1
+    num_heads : 8
+    dim_embed : 512
+    num_blocks : 2
+  embed_target_coords :
+    net : linear
+    dim_embed : 512
+  target_readout :
+    num_layers : 2
+    num_heads : 4
+    # sampling_rate : 0.2
+  pred_head :
+    ens_size : 1
+    num_layers : 1
\ No newline at end of file
diff --git a/config/streams/seviri_lst/seviri_lst.yml b/config/streams/seviri_lst/seviri_lst.yml
new file mode 100644
index 000000000..dd70ae2e9
--- /dev/null
+++ b/config/streams/seviri_lst/seviri_lst.yml
@@ -0,0 +1,36 @@
+SEVIRI_LST :
+  type : msg_lst
+  stream_id: 1
+  filenames : ['mpg_seviri_l2_2017-18_v0/lst_test.zarr']  # use ['mpg_seviri_l2_2017-18_v0/seviri.zarr'] after zarr3 is enabled
+  data_start_time : "2017-02-01 00:00"
+  data_end_time : "2017-06-30 00:00"
+  target: ["LST"]
+  source: []
+  geoinfos: []  # e.g. ["DEM", "LANDCOV"]
+  metadata: "/leonardo_work/AIFAC_5C0_154/weathergen/data/mpg_seviri_l2_2017-18_v1/metadata"  # uses one scene over South Africa for finetuning
+  scene: "scenes_train_scene_001.npz"
+  spatial_stride: 24
+  temporal_stride: 6
+  sampling_rate_target: 0.005  # use 0.5% of the spatial points
+  loss_weight : 1.0
+  masking_rate : 0.6
+  masking_rate_none : 0.05
+  token_size : 64
+  tokenize_spacetime : True
+  max_num_targets: -1
+  embed :
+    net : transformer
+    num_tokens : 1
+    num_heads : 2
+    dim_embed : 16
+    num_blocks : 2
+  embed_target_coords :
+    net : linear
+    dim_embed : 16
+  target_readout :
+    type : 'obs_value'
+    num_layers : 2
+    num_heads : 4
+  pred_head :
+    ens_size : 1
+    num_layers : 1
\ No newline at end of file
diff --git a/packages/evaluate/src/weathergen/evaluate/utils/regions.py b/packages/evaluate/src/weathergen/evaluate/utils/regions.py
index db631a6af..3e05127a4 100644
--- a/packages/evaluate/src/weathergen/evaluate/utils/regions.py
+++ b/packages/evaluate/src/weathergen/evaluate/utils/regions.py
@@ -28,6 +28,7 @@ class RegionLibrary:
         "shem": (-90.0, 0.0, -180.0, 180.0),
         "tropics": (-30.0, 30.0, -180.0, 180.0),
         "belgium": (49, 52, 2, 7),
+        "madagaskar": (-25, -10, 43, 50),
     }
 
 
diff --git a/src/weathergen/datasets/data_reader_seviri.py b/src/weathergen/datasets/data_reader_seviri.py
new file mode 100644
index 000000000..73f8205c1
--- /dev/null
+++ b/src/weathergen/datasets/data_reader_seviri.py
@@ -0,0 +1,344 @@
+# (C) Copyright 2024 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+import os
+from pathlib import Path
+from typing import override
+
+import numpy as np
+import xarray as xr
+import zarr
+from numpy.typing import NDArray
+
+os.environ["ZARR_V3_EXPERIMENTAL_API"] = "1"  # NOTE: appears to have no effect, zarr is already imported above
+
+from weathergen.datasets.data_reader_base import (
+    DataReaderTimestep,
+    ReaderData,
+    TimeWindowHandler,
+    TIndex,
+    check_reader_data,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class DataReaderSeviri(DataReaderTimestep):
+    """Data reader for SEVIRI satellite data."""
+
+    def __init__(
+        self,
+        tw_handler: TimeWindowHandler,
+        filename: Path,
+        stream_info: dict,
+    ) -> None:
+        """Initialize the SEVIRI data reader."""
+
+        self.fillvalue = np.nan
+        np32 = np.float32
+
+        # set sampling parameters
+        self.stride_temporal = stream_info["temporal_stride"] # downsample to six hourly timesteps
+        self.stride_spatial = stream_info["spatial_stride"] # use every 8th point to reduce memory usage on workers
+
+        index_path = Path(stream_info["metadata"]) / stream_info["scene"]
+        self.spatial_indices = np.load(index_path)["seviri_indices"]
+
+        self._zarr_path = filename
+        self._ds = None  # opened lazily
+
+        # Open temporarily with xarray just for init metadata (time handling is easier)
+        ds_xr = xr.open_zarr(filename, group="seviri")
+        ds_xr["time"] = ds_xr["time"].astype("datetime64[ns]")
+        ds_xr = ds_xr.sel(time=slice(stream_info["data_start_time"], stream_info["data_end_time"]))
+
+        col_extent = ds_xr["longitude"].shape[0]
+        lat_idx = self.spatial_indices // col_extent
+        lon_idx = self.spatial_indices % col_extent
+
+        # Cache spatial indices for zarr access
+        self._lat_idx = np.array(lat_idx[:: self.stride_spatial])
+        self._lon_idx = np.array(lon_idx[:: self.stride_spatial])
+
+        # Apply spatial subset
+        ds_xr = ds_xr.isel(latitude=self._lat_idx, longitude=self._lon_idx)
+
+        # Cache time values as numpy (avoid zarr access for time later)
+        self._time_values = np.array(ds_xr.time.values)
+
+        # Find time indices in the full zarr that correspond to our time selection
+        ds_full = xr.open_zarr(filename, group="seviri")
+        ds_full["time"] = ds_full["time"].astype("datetime64[ns]")
+        full_times = ds_full.time.values
+        start_time = ds_xr.time.min().values
+        self._time_offset = int(np.searchsorted(full_times, start_time))
+
+        # cache latitudes and longitudes
+        lat_name = stream_info.get("latitude_name", "latitude")
+        self.latitudes = _clip_lat(np.array(ds_xr[lat_name], dtype=np32))
+        lon_name = stream_info.get("longitude_name", "longitude")
+        self.longitudes = _clip_lon(np.array(ds_xr[lon_name], dtype=np32))
+
+        # check that the data overlaps with the time window; otherwise initialise as an empty data reader
+        if tw_handler.t_start >= ds_xr.time.max() or tw_handler.t_end <= ds_xr.time.min():
+            name = stream_info["name"]
+            _logger.warning(f"{name} has no overlap with the data loader time window. Stream is skipped.")
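+            # no usable time overlap: initialise the base class, then mark this reader as empty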
+            super().__init__(tw_handler, stream_info)
+            self.init_empty()
+            return
+
+        if "frequency" in stream_info:
+            assert False, "Frequency sub-sampling currently not supported"
+
+        period = np.timedelta64(self.stride_temporal, "h")
+
+        data_start_time = ds_xr.time[0].values
+        data_end_time = ds_xr.time[-1].values
+
+        assert data_start_time is not None and data_end_time is not None, (
+            data_start_time,
+            data_end_time,
+        )
+
+        # sets the time window handler and stream info in the base class
+        super().__init__(
+            tw_handler,
+            stream_info,
+            data_start_time,
+            data_end_time,
+            period,
+        )
+
+        # If there is no overlap with the time range, no need to keep the dataset.
+        if tw_handler.t_start >= data_end_time or tw_handler.t_end <= data_start_time:
+            self.init_empty()
+            return
+        else:
+            self.len = len(ds_xr["time"]) // self.stride_temporal
+
+        self.exclude = {"LWMASK", "LANDCOV", "_indices", "quality_flag"}
+        self.channels_file = list(ds_xr.keys())
+
+        self.geoinfo_channels = stream_info.get("geoinfos", [])
+        self.geoinfo_idx = [self.channels_file.index(ch) for ch in self.geoinfo_channels]
+
+        # cache geoinfos
+        if len(self.geoinfo_channels) != 0:
+            self.geoinfo_data = np.stack(
+                [np.array(ds_xr[ch], dtype=np32) for ch in self.geoinfo_channels]
+            )
+            self._geoinfo_flat = self.geoinfo_data.transpose([1, 2, 0]).reshape(
+                (-1, len(self.geoinfo_channels))
+            )
+
+        # select/filter requested target channels
+        self.target_idx, self.target_channels = self.select_channels(ds_xr, "target")
+
+        self.source_channels = stream_info.get("source", [])
+        self.source_idx = [self.channels_file.index(ch) for ch in self.source_channels]
+
+        ds_name = stream_info["name"]
+        _logger.info(f"{ds_name}: target channels: {self.target_channels}")
+
+        self.properties = {
+            "stream_id": 0,
+        }
+
+        self.mean, self.stdev = self._create_statistics()
+        self.mean_geoinfo, self.stdev_geoinfo = (
+            self.mean[self.geoinfo_idx],
+            self.stdev[self.geoinfo_idx],
+        )
+        print(f"geoinfo_channels: {self.geoinfo_channels}, _geoinfo_flat shape: {getattr(self, '_geoinfo_flat', 'NOT SET')}")
+        # Close xarray, force lazy zarr open in workers
+        ds_xr.close()
+        ds_full.close()
+        self._ds = None
+
+    def _open_ds(self):
+        store = zarr.open(self._zarr_path, mode="r")
+        return store["seviri"]
+
+    @property
+    def ds(self):
+        if self._ds is None:
+            self._ds = self._open_ds()
+        return self._ds
+
+    @ds.setter
+    def ds(self, value):
+        self._ds = value
+
+    def _create_statistics(self):
+        statistics = Path(self.stream_info["metadata"]) / "statistics_global.npz"
+        df_stats = _assemble_statistics_from_npz(statistics)
+
+        mean, stdev = [], []
+
+        for ch in self.channels_file:
+            if ch in self.exclude:
+                mean.append(0.0)
+                stdev.append(1.0)
+            else:
+                mean.append(df_stats[ch]["mean"])
+                stdev.append(df_stats[ch]["std"])
+
+        mean = np.array(mean)
+        stdev = np.array(stdev)
+
+        return mean, stdev
+
+    @override
+    def init_empty(self) -> None:
+        super().init_empty()
+        self._ds = None
+        self.len = 0
+
+    @override
+    def length(self) -> int:
+        return self.len
+
+    @override
+    def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
+        """
+        Get data for window (for either source or target, through public interface)
+        """
+        print(f"geoinfo_channels: {self.geoinfo_channels}, _geoinfo_flat shape: {getattr(self, '_geoinfo_flat', 'NOT SET')}")
+
+        (t_idxs, dtr) = self._get_dataset_idxs(idx)
+
+        if self._ds is None and self.len == 0:
+            return ReaderData.empty(
+                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+            )
+
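+        # a resolved window can still be empty: no timesteps selected or no channels requested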
+        if len(t_idxs) == 0 or len(channels_idx) == 0:
+            return ReaderData.empty(
+                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+            )
+
+        assert t_idxs[0] >= 0, "index must be non-negative"
+
+        # Convert to actual zarr indices (accounting for time offset and stride)
+        didx_start = self._time_offset + t_idxs[0] * self.stride_temporal
+        didx_end = self._time_offset + t_idxs[-1] * self.stride_temporal + 1
+
+        sel_channels = [self.channels_file[i] for i in channels_idx]
+
+        # Access zarr directly with numpy advanced indexing
+        data_list = []
+        for ch in sel_channels:
+            # zarr array: shape is (time, lat, lon)
+            ch_data = self.ds[ch][didx_start : didx_end : self.stride_temporal, self._lat_idx, :][
+                :, :, self._lon_idx
+            ]
+            data_list.append(ch_data)
+
+        data = np.stack(data_list, axis=-1)  # shape: (n_times, n_lats, n_lons, n_channels)
+
+        n_times = data.shape[0]
+        n_lats = data.shape[1]
+        n_lons = data.shape[2]
+        n_spatial = n_lats * n_lons
+
+        # flatten along the time dimension
+        data = data.reshape((n_times * n_spatial, len(channels_idx)))
+
+        # prepare geoinfos
+        if len(self.geoinfo_channels) != 0:
+            geoinfos = np.tile(self._geoinfo_flat, (n_times, 1))
+        else:
+            geoinfos = np.zeros((n_spatial * n_times, 0), dtype=np.float32)
+
+        # construct lat/lon coords
+        lat2d, lon2d = np.meshgrid(
+            self.latitudes,
+            self.longitudes,
+            indexing="ij",
+        )
+        lat_flat = lat2d.reshape(-1)
+        lon_flat = lon2d.reshape(-1)
+
+        # Tile spatial coordinates for each timestep
+        coords = np.tile(np.column_stack((lat_flat, lon_flat)), (n_times, 1))
+
+        # Use cached time values
+        time_indices = slice(
+            t_idxs[0] * self.stride_temporal,
+            t_idxs[-1] * self.stride_temporal + 1,
+            self.stride_temporal,
+        )
+        datetimes = np.repeat(self._time_values[time_indices], n_spatial)
+
+        rd = ReaderData(
+            coords=coords,
+            geoinfos=geoinfos,
+            data=data,
+            datetimes=datetimes,
+        )
+        check_reader_data(rd, dtr)
+
+        return rd
+
+    def select_channels(self, ds, ch_type: str) -> tuple[NDArray[np.int64], list[str]]:
+        """Select channels based on stream info for either source or target."""
+
+        channels = self.stream_info.get(ch_type)
+        assert channels is not None, f"{ch_type} channels need to be specified"
+
+        if len(channels) == 0:
+            stream_name = self.stream_info["name"]
+            _logger.warning(f"No {ch_type} channels specified for {stream_name}.")
+            chs_idx = np.empty(shape=[0], dtype=int)
+            channels = []
+        else:
+            chs_idx = np.sort([self.channels_file.index(ch) for ch in channels])
+            channels = [self.channels_file[i] for i in chs_idx]
+
+        return np.array(chs_idx), channels
+
+
+def _clip_lat(lats: NDArray) -> NDArray[np.float32]:
+    """Clip latitudes to [-90, 90], reflecting out-of-range values back into the valid range."""
+    return (2 * np.clip(lats, -90.0, 90.0) - lats).astype(np.float32)
+
+
+def _clip_lon(lons: NDArray) -> NDArray[np.float32]:
+    """Wrap longitudes periodically into the range [-180, 180)."""
+    return ((lons + 180.0) % 360.0 - 180.0).astype(np.float32)
+
+
+def _assemble_statistics_from_npz(src: str | Path) -> dict[str, dict[str, float]]:
+    """
+    Loads statistics saved with `save_statistics_npz`.
+
+    Returns:
+        dict[var_name, dict[stat_name, value]]
+    """
+    out: dict[str, dict[str, float]] = {}
+
+    # If it's path-like, normalize to Path; otherwise assume it's file-like
+    if isinstance(src, str | Path):
+        src = Path(src)
+
+    with np.load(src, allow_pickle=True) as z:
+        variables = list(z["variables"])
+        stat_names = [k for k in z.files if k != "variables"]
+
+        for i, var in enumerate(variables):
+            out[str(var)] = {}
+            for stat in stat_names:
+                out[str(var)][stat] = np.asarray(z[stat][i]).item()
+
+    return out
diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py
index 86049d389..dada1eb30 100644
--- a/src/weathergen/datasets/multi_stream_data_sampler.py
+++ b/src/weathergen/datasets/multi_stream_data_sampler.py
@@ -24,6 +24,7 @@
 )
 from weathergen.datasets.data_reader_fesom import DataReaderFesom
 from weathergen.datasets.data_reader_obs import DataReaderObs
+from weathergen.datasets.data_reader_seviri import DataReaderSeviri
 from weathergen.datasets.masking import Masker
 from weathergen.datasets.stream_data import StreamData, spoof
 from weathergen.datasets.tokenizer_masking import TokenizerMasking
@@ -155,6 +156,8 @@ def __init__(
                 dataset = DataReaderAnemoi
             case "fesom":
                 dataset = DataReaderFesom
+            case "msg_lst":
+                dataset = DataReaderSeviri
             case type_name:
                 dataset = get_extra_reader(type_name)
                 if dataset is None:

From dd60960807d3e374c53a183e313df04f3deed7a2 Mon Sep 17 00:00:00 2001
From: MWesselkamp
Date: Fri, 13 Feb 2026 08:44:40 +0100
Subject: [PATCH 2/2] chore: linting and integration tests

---
 .../evaluate/src/weathergen/evaluate/plotting/plotter.py | 2 +-
 src/weathergen/datasets/data_reader_seviri.py            | 7 +++----
 src/weathergen/utils/cli.py                              | 2 +-
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py
index 44ec60a3b..e3cdfefa6 100644
--- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py
+++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py
@@ -607,7 +607,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]:
                     image_paths += names
 
                 if image_paths:
-                    image_paths=sorted(image_paths)
+                    image_paths = sorted(image_paths)
                     images = [Image.open(path) for path in image_paths]
                     images[0].save(
                         f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{region}_{var}.gif",
diff --git a/src/weathergen/datasets/data_reader_seviri.py b/src/weathergen/datasets/data_reader_seviri.py
index 73f8205c1..e6e43aa52 100644
--- a/src/weathergen/datasets/data_reader_seviri.py
+++ b/src/weathergen/datasets/data_reader_seviri.py
@@ -47,7 +47,9 @@ def __init__(
 
         # set sampling parameters
         self.stride_temporal = stream_info["temporal_stride"] # downsample to six hourly timesteps
-        self.stride_spatial = stream_info["spatial_stride"] # use every 8th point to reduce memory usage on workers
+        self.stride_spatial = stream_info[
+            "spatial_stride"
+        ]  # use every 8th point to reduce memory usage on workers
 
         index_path = Path(stream_info["metadata"]) / stream_info["scene"]
         self.spatial_indices = np.load(index_path)["seviri_indices"]
@@ -162,7 +164,6 @@ def __init__(
             self.stdev[self.geoinfo_idx],
         )
-        print(f"geoinfo_channels: {self.geoinfo_channels}, _geoinfo_flat shape: {getattr(self, '_geoinfo_flat', 'NOT SET')}")
         # Close xarray, force lazy zarr open in workers
         ds_xr.close()
         ds_full.close()
         self._ds = None
@@ -214,8 +215,6 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
         """
         Get data for window (for either source or target, through public interface)
         """
-        print(f"geoinfo_channels: {self.geoinfo_channels}, _geoinfo_flat shape: {getattr(self, '_geoinfo_flat', 'NOT SET')}")
-
         (t_idxs, dtr) = self._get_dataset_idxs(idx)
 
         if self._ds is None and self.len == 0:
diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py
index 1c7cba6a8..2bd9fe2a2 100644
--- a/src/weathergen/utils/cli.py
+++ b/src/weathergen/utils/cli.py
@@ -14,7 +14,7 @@ class Stage(enum.StrEnum):
 def get_main_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(allow_abbrev=False)
     subparsers = parser.add_subparsers(dest="stage")
-    
+
     train_parser = subparsers.add_parser(
         Stage.train,
         help="Train a WeatherGenerator configuration from the ground up.",