From 505efdebec7e56b848b6a86f39f19f7958f4143c Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Mon, 2 Feb 2026 21:53:09 +0100 Subject: [PATCH 01/38] Push to test --- config/default_config.yml | 19 +- src/weathergen/model/encoder.py | 27 ++ src/weathergen/train/lr_scheduler.py | 35 +- src/weathergen/train/optimizer.py | 564 +++++++++++++++++++++++++++ src/weathergen/train/trainer.py | 93 ++++- tests/test_optimizer.py | 447 +++++++++++++++++++++ 6 files changed, 1161 insertions(+), 24 deletions(-) create mode 100644 src/weathergen/train/optimizer.py create mode 100644 tests/test_optimizer.py diff --git a/config/default_config.yml b/config/default_config.yml index 4bff1abfd..12b18084d 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -161,14 +161,25 @@ training_config: parallel_scaling_policy: "sqrt" optimizer: + # Optimizer type: "adamw" (default) or "muon_adamw" (Muon for hidden weights, AdamW for embeddings/heads) + type: "adamw" grad_clip: 1.0 weight_decay: 0.1 log_grad_norms: False - adamw : + adamw: # parameters are scaled by number of DDP workers - beta1 : 0.975 - beta2 : 0.9875 - eps : 2e-08 + beta1: 0.975 + beta2: 0.9875 + eps: 2e-08 + muon: + # Learning rate multiplier for Muon relative to base LR (muon_lr = base_lr * lr_multiplier) + lr_multiplier: 20.0 + # Momentum factor for Muon SGD + momentum: 0.95 + # Use Nesterov momentum + nesterov: true + # Weight decay for Muon parameters (uses optimizer.weight_decay if not specified) + weight_decay: 0.1 losses : { "physical": { diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 47e059014..9cb2e624b 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -153,7 +153,34 @@ def assimilate_local_project_chunked(self, tokens, tokens_global, cell_lens, q_c # combined cell lens for all tokens in batch across all input steps zero_pad = torch.zeros(1, device=tokens.device, dtype=torch.int32) +<<<<<<< Updated upstream # subdivision factor for required splitting +======= + # Identify non-empty cells (sorted indices from torch.where) + non_empty_mask = cell_lens > 0 + num_non_empty = non_empty_mask.sum().item() + non_empty_indices = torch.where(non_empty_mask)[0] + + if num_non_empty == 0: + assert False, "No non-empty cells found - cannot process empty input" + + # Gather cell_lens and tokens_global for non-empty cells only + # non_empty_indices is sorted, so output will be in original cell order + cell_lens_non_empty = cell_lens[non_empty_indices] + tokens_global_non_empty = tokens_global[non_empty_indices] + + # Reorder tokens: gather tokens for non-empty cells in their original order + cumsum_orig = torch.cat([zero_pad, cell_lens.cumsum(0)]) + token_indices = torch.cat( + [ + torch.arange(cumsum_orig[idx], cumsum_orig[idx + 1], device=tokens.device) + for idx in non_empty_indices + ] + ) + tokens_reordered = tokens[token_indices] + + # Fixed number of chunks based on healpix cells (same for all ranks) +>>>>>>> Stashed changes clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) tokens_global_unmasked = [] posteriors = [] diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index e85cd1abf..db90eb65f 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -208,18 +208,18 @@ def step(self): if self.i_step > 0 else self.lr_max_scaled ) - for g in self.optimizer.param_groups: - g["lr"] = self.lr + self._set_param_group_lrs(self.lr) elif self.policy_decay == "constant" and phase_decay: cur_lr = self.lr self.lr = self.lr_max_scaled # make sure lr_max_scaled rate is used if warm-up end is not lr_max_scaled if cur_lr < self.lr: - for g in self.optimizer.param_groups: - g["lr"] = self.lr + self._set_param_group_lrs(self.lr) else: self.cur_scheduler.step() self.lr = self.cur_scheduler.get_last_lr()[0] + # Apply per-group LR multipliers after scheduler step + self._apply_lr_multipliers() # switch scheduler when learning rate regime completed if self.i_step == self.n_steps_warmup: @@ -237,6 +237,33 @@ def step(self): return self.lr + def _set_param_group_lrs(self, base_lr: float): + """ + Set learning rates for all parameter groups, applying per-group multipliers. + + For Muon+AdamW composite optimizers, Muon parameter groups have an lr_multiplier + that scales their learning rate relative to the base LR. + + Args: + base_lr: The base learning rate to set. + """ + for g in self.optimizer.param_groups: + lr_multiplier = g.get("lr_multiplier", 1.0) + g["lr"] = base_lr * lr_multiplier + + def _apply_lr_multipliers(self): + """ + Apply per-group LR multipliers after a scheduler step. + + The scheduler sets the same LR for all groups, so we need to scale + Muon groups by their lr_multiplier afterwards. + """ + for g in self.optimizer.param_groups: + if g.get("is_muon", False): + lr_multiplier = g.get("lr_multiplier", 1.0) + # Scale Muon groups relative to base LR + g["lr"] = self.lr * lr_multiplier + def get_lr(self): return self.lr diff --git a/src/weathergen/train/optimizer.py b/src/weathergen/train/optimizer.py new file mode 100644 index 000000000..b142d9dc2 --- /dev/null +++ b/src/weathergen/train/optimizer.py @@ -0,0 +1,564 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Optimizer module for WeatherGenerator. + +Provides support for: +- Standard AdamW optimizer +- Hybrid Muon+AdamW optimizer (Muon for 2D hidden weights, AdamW for embeddings/heads) + +The Muon optimizer uses orthogonalization of gradients for improved training dynamics +on transformer hidden layer weights. See: https://arxiv.org/abs/2407.01490 +""" + +import logging +from typing import Any + +import numpy as np +import torch +from torch.optim import Optimizer + +logger = logging.getLogger(__name__) + + +# Patterns identifying parameters that should use AdamW (not Muon) +# These include embeddings, prediction heads, and other 1D or special parameters +ADAMW_PATTERNS = [ + "embed_target_coords", + "embeds.", + "embed.", + "unembed", + "pred_heads", + "latent_heads", + "q_cells", + "bilin", + "class_token", + "register_token", + "norm", + "bias", +] + + +def classify_muon_params( + model: torch.nn.Module, +) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter], list[str], list[str]]: + """ + Classify model parameters into Muon-eligible and AdamW-eligible groups. + + Muon is applied to 2D hidden layer weights (attention Q/K/V/O, MLP linear layers). + AdamW is applied to embeddings, output heads, 1D parameters, and biases. + + Args: + model: The model whose parameters to classify. + + Returns: + A tuple of (muon_params, adamw_params, muon_names, adamw_names). + """ + muon_params: list[torch.nn.Parameter] = [] + adamw_params: list[torch.nn.Parameter] = [] + muon_names: list[str] = [] + adamw_names: list[str] = [] + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + name_lower = name.lower() + + # 1D parameters (biases, layer norm weights) -> AdamW + if param.ndim < 2: + adamw_params.append(param) + adamw_names.append(name) + continue + + # Check if parameter matches any AdamW pattern + is_adamw = any(pattern in name_lower for pattern in ADAMW_PATTERNS) + + if is_adamw: + adamw_params.append(param) + adamw_names.append(name) + else: + # 2D hidden weights -> Muon + muon_params.append(param) + muon_names.append(name) + + return muon_params, adamw_params, muon_names, adamw_names + + +def _scale_adamw_betas( + beta1_base: float, + beta2_base: float, + eps_base: float, + batch_size_total: int, +) -> tuple[float, float, float]: + """ + Scale AdamW hyperparameters based on batch size following SDE scaling rules. + + See: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + + Args: + beta1_base: Base beta1 value (target for batch_size_total=1). + beta2_base: Base beta2 value (target for batch_size_total=1). + eps_base: Base epsilon value. + batch_size_total: Total effective batch size across all ranks. + + Returns: + Tuple of (scaled_beta1, scaled_beta2, scaled_eps). + """ + kappa = batch_size_total + beta1 = max(0.5, 1.0 - kappa * (1.0 - beta1_base)) + beta2 = 1.0 - kappa * (1.0 - beta2_base) + eps = eps_base / np.sqrt(kappa) + return beta1, beta2, eps + + +def create_optimizer( + model: torch.nn.Module, + optimizer_cfg: Any, + lr_cfg: Any, + batch_size_total: int, +) -> Optimizer: + """ + Factory function to create the appropriate optimizer based on config. + + Args: + model: The model to optimize. + optimizer_cfg: Optimizer configuration containing type and hyperparameters. + lr_cfg: Learning rate configuration containing lr_start. + batch_size_total: Total effective batch size across all ranks. + + Returns: + The configured optimizer (AdamW or CompositeOptimizer). + """ + optimizer_type = optimizer_cfg.get("type", "adamw") + initial_lr = lr_cfg.lr_start + weight_decay = optimizer_cfg.weight_decay + + # Scale AdamW betas based on batch size + adamw_cfg = optimizer_cfg.adamw + beta1, beta2, eps = _scale_adamw_betas( + adamw_cfg.beta1, + adamw_cfg.beta2, + adamw_cfg.get("eps", 2e-08), + batch_size_total, + ) + + if optimizer_type == "adamw": + logger.info("Creating AdamW optimizer") + return torch.optim.AdamW( + model.parameters(), + lr=initial_lr, + weight_decay=weight_decay, + betas=(beta1, beta2), + eps=eps, + ) + + elif optimizer_type == "muon_adamw": + logger.info("Creating Muon+AdamW composite optimizer") + return _create_muon_adamw_optimizer( + model=model, + optimizer_cfg=optimizer_cfg, + initial_lr=initial_lr, + weight_decay=weight_decay, + adamw_betas=(beta1, beta2), + adamw_eps=eps, + batch_size_total=batch_size_total, + ) + + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + +def _create_muon_adamw_optimizer( + model: torch.nn.Module, + optimizer_cfg: Any, + initial_lr: float, + weight_decay: float, + adamw_betas: tuple[float, float], + adamw_eps: float, + batch_size_total: int, +) -> "CompositeOptimizer": + """ + Create a Muon+AdamW composite optimizer. + + Args: + model: The model to optimize. + optimizer_cfg: Optimizer configuration. + initial_lr: Initial learning rate (for AdamW; Muon uses multiplied version). + weight_decay: Weight decay coefficient. + adamw_betas: Scaled (beta1, beta2) for AdamW. + adamw_eps: Scaled epsilon for AdamW. + batch_size_total: Total effective batch size. + + Returns: + CompositeOptimizer wrapping Muon and AdamW. + """ + muon_cfg = optimizer_cfg.get("muon", {}) + lr_multiplier = muon_cfg.get("lr_multiplier", 20.0) + muon_momentum = muon_cfg.get("momentum", 0.95) + muon_nesterov = muon_cfg.get("nesterov", True) + muon_weight_decay = muon_cfg.get("weight_decay", weight_decay) + + # Classify parameters + muon_params, adamw_params, muon_names, adamw_names = classify_muon_params(model) + + logger.info(f"Muon parameters ({len(muon_params)}): {muon_names[:5]}...") + logger.info(f"AdamW parameters ({len(adamw_params)}): {adamw_names[:5]}...") + + # Create parameter groups for AdamW + # Include both AdamW-only params and mark them appropriately + adamw_param_groups = [ + { + "params": adamw_params, + "lr": initial_lr, + "is_muon": False, + "lr_multiplier": 1.0, + } + ] + + # Create AdamW optimizer for embeddings/heads + adamw_optimizer = torch.optim.AdamW( + adamw_param_groups, + lr=initial_lr, + weight_decay=weight_decay, + betas=adamw_betas, + eps=adamw_eps, + ) + + # Create Muon optimizer for hidden weights + muon_lr = initial_lr * lr_multiplier + + # Parameter groups for Muon + muon_param_groups = [ + { + "params": muon_params, + "lr": muon_lr, + "is_muon": True, + "lr_multiplier": lr_multiplier, + } + ] + + # Try to use PyTorch's built-in Muon if available (PyTorch >= 2.9) + muon_optimizer = _create_muon_optimizer( + param_groups=muon_param_groups, + lr=muon_lr, + momentum=muon_momentum, + nesterov=muon_nesterov, + weight_decay=muon_weight_decay, + ) + + return CompositeOptimizer( + muon_optimizer=muon_optimizer, + adamw_optimizer=adamw_optimizer, + muon_lr_multiplier=lr_multiplier, + ) + + +def _create_muon_optimizer( + param_groups: list[dict], + lr: float, + momentum: float, + nesterov: bool, + weight_decay: float, +) -> Optimizer: + """ + Create a Muon optimizer, using PyTorch's built-in version if available. + + Falls back to custom implementation for older PyTorch versions. + """ + # Try PyTorch's built-in Muon (available in PyTorch >= 2.9) + if hasattr(torch.optim, "Muon"): + logger.info("Using torch.optim.Muon") + return torch.optim.Muon( + param_groups, + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + ) + else: + logger.info("Using custom Muon implementation (torch.optim.Muon not available)") + return MuonCustom( + param_groups, + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + ) + + +class CompositeOptimizer: + """ + Composite optimizer that combines Muon and AdamW for different parameter groups. + + Muon is used for 2D hidden layer weights, AdamW for embeddings and heads. + This class wraps both optimizers and provides a unified interface. + + Note: This class does not inherit from torch.optim.Optimizer to avoid + conflicts with state management. It provides the same interface. + """ + + def __init__( + self, + muon_optimizer: Optimizer, + adamw_optimizer: Optimizer, + muon_lr_multiplier: float = 20.0, + ): + """ + Initialize the composite optimizer. + + Args: + muon_optimizer: Optimizer for Muon-eligible parameters. + adamw_optimizer: Optimizer for AdamW-eligible parameters. + muon_lr_multiplier: LR multiplier for Muon relative to base LR. + """ + self.muon_optimizer = muon_optimizer + self.adamw_optimizer = adamw_optimizer + self.muon_lr_multiplier = muon_lr_multiplier + + # Combine param_groups from both optimizers for unified interface + self._param_groups = muon_optimizer.param_groups + adamw_optimizer.param_groups + + @property + def param_groups(self) -> list: + """Return combined param groups from both optimizers.""" + return self._param_groups + + @param_groups.setter + def param_groups(self, value): + """Set param groups (needed for lr_scheduler compatibility).""" + self._param_groups = value + + def step(self, closure=None): + """ + Perform a single optimization step. + + Args: + closure: A closure that reevaluates the model and returns the loss. + Optional for most optimizers. + + Returns: + Loss value if closure is provided, None otherwise. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + self.muon_optimizer.step() + self.adamw_optimizer.step() + + return loss + + def zero_grad(self, set_to_none: bool = True): + """ + Reset gradients of all optimized parameters. + + Args: + set_to_none: If True, set gradients to None instead of zero. + This can improve memory efficiency. + """ + self.muon_optimizer.zero_grad(set_to_none=set_to_none) + self.adamw_optimizer.zero_grad(set_to_none=set_to_none) + + def state_dict(self) -> dict: + """ + Return the state of both optimizers as a single dictionary. + + Returns: + Dictionary containing state from both Muon and AdamW optimizers. + """ + return { + "muon": self.muon_optimizer.state_dict(), + "adamw": self.adamw_optimizer.state_dict(), + "muon_lr_multiplier": self.muon_lr_multiplier, + "optimizer_type": "composite_muon_adamw", + } + + def load_state_dict(self, state_dict: dict): + """ + Load optimizer state from a dictionary. + + Args: + state_dict: Dictionary containing saved optimizer state. + """ + if ( + "optimizer_type" in state_dict + and state_dict["optimizer_type"] == "composite_muon_adamw" + ): + self.muon_optimizer.load_state_dict(state_dict["muon"]) + self.adamw_optimizer.load_state_dict(state_dict["adamw"]) + self.muon_lr_multiplier = state_dict.get("muon_lr_multiplier", self.muon_lr_multiplier) + else: + # Fallback: try to load as regular optimizer state + # This handles migration from pure AdamW checkpoints + logger.warning( + "Loading non-composite state dict into CompositeOptimizer. " + "This may not work correctly - optimizer state may be lost." + ) + + @property + def state(self) -> dict: + """ + Return combined state from both optimizers. + """ + combined_state = {} + muon_state = self.muon_optimizer.state + adamw_state = self.adamw_optimizer.state + combined_state.update(muon_state) + combined_state.update(adamw_state) + return combined_state + + +class MuonCustom(Optimizer): + """ + Custom Muon optimizer implementation for PyTorch versions without torch.optim.Muon. + + Muon applies Newton-Schulz orthogonalization to gradients before the SGD update, + which helps with optimization of transformer hidden layer weights. + + Reference: https://arxiv.org/abs/2407.01490 + """ + + def __init__( + self, + params, + lr: float = 0.02, + momentum: float = 0.95, + nesterov: bool = True, + weight_decay: float = 0.0, + ns_steps: int = 5, + ): + """ + Initialize the Muon optimizer. + + Args: + params: Iterable of parameters to optimize or dicts defining param groups. + lr: Learning rate. + momentum: Momentum factor. + nesterov: Whether to use Nesterov momentum. + weight_decay: Weight decay (L2 penalty). + ns_steps: Number of Newton-Schulz iterations for orthogonalization. + """ + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + ns_steps=ns_steps, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """ + Perform a single optimization step. + + Args: + closure: A closure that reevaluates the model and returns the loss. + + Returns: + Loss value if closure is provided, None otherwise. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + momentum = group["momentum"] + nesterov = group["nesterov"] + lr = group["lr"] + weight_decay = group["weight_decay"] + ns_steps = group.get("ns_steps", 5) + + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad + + # Apply weight decay + if weight_decay != 0: + grad = grad.add(p, alpha=weight_decay) + + # Apply Newton-Schulz orthogonalization for 2D+ tensors + if p.ndim >= 2: + grad = self._newton_schulz_orthogonalize(grad, ns_steps) + + # Get or initialize momentum buffer + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(grad) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + # Apply update + p.add_(grad, alpha=-lr) + + return loss + + def _newton_schulz_orthogonalize(self, grad: torch.Tensor, ns_steps: int) -> torch.Tensor: + """ + Apply Newton-Schulz iteration to orthogonalize the gradient. + + This projects the gradient onto the manifold of orthogonal matrices, + which helps with optimization stability for large matrices. + + Args: + grad: Gradient tensor to orthogonalize. + ns_steps: Number of Newton-Schulz iterations. + + Returns: + Orthogonalized gradient tensor. + """ + # Reshape to 2D if needed + original_shape = grad.shape + if grad.ndim > 2: + grad = grad.view(grad.shape[0], -1) + + # Transpose if needed to ensure we have more rows than columns + transposed = False + if grad.shape[0] < grad.shape[1]: + grad = grad.T + transposed = True + + # Normalize + grad = grad / (grad.norm() + 1e-7) + + # Newton-Schulz iteration: X_{k+1} = X_k (3I - X_k^T X_k) / 2 + # This converges to an orthogonal matrix + for _ in range(ns_steps): + grad = grad @ ( + 1.5 * torch.eye(grad.shape[1], device=grad.device, dtype=grad.dtype) + - 0.5 * grad.T @ grad + ) + + # Restore original orientation + if transposed: + grad = grad.T + + # Reshape back to original + grad = grad.view(original_shape) + + return grad diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 5e16593e1..969f683ac 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -31,6 +31,7 @@ from weathergen.model.utils import apply_fct_to_blocks, set_to_eval from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler +from weathergen.train.optimizer import CompositeOptimizer, create_optimizer from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( extract_batch_metadata, @@ -46,6 +47,24 @@ logger = logging.getLogger(__name__) +<<<<<<< Updated upstream +======= +DEBUG = False +if DEBUG: + + def debug_barrier(name, rank): + """Simple checkpoint function - call at key points in training loop""" + _debug_start_time = time.time() + torch.cuda.synchronize() + elapsed = time.time() - _debug_start_time + print(f"[{elapsed:8.2f}s] [Rank {rank}] CHECKPOINT: {name}", flush=True) +else: + + def debug_barrier(name, rank): + return + + +>>>>>>> Stashed changes class Trainer(TrainerBase): def __init__(self, train_log_freq: Config): TrainerBase.__init__(self) @@ -292,22 +311,14 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): logger.warning("Trainable parameters are inaccurate with FSDP enabled.") # self.model.print_num_parameters() - # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ - # aiming for beta1=0.9 and beta2=0.95 following the MAE paper - # https://arxiv.org/pdf/2111.06377 - kappa = self.get_batch_size_total(self.batch_size_per_gpu) - # aiming for beta1 = 0.9 at one node, ie kappa=B=4 - beta1 = max(0.5, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta1)) - # aiming for beta2 = 0.95 at one node, ie B=4 - beta2 = 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2) - eps = self.training_cfg.optimizer.adamw.get("eps", 2e-08) / np.sqrt(kappa) - - self.optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=self.training_cfg.learning_rate_scheduling.lr_start, - weight_decay=self.training_cfg.optimizer.weight_decay, - betas=(beta1, beta2), - eps=eps, + # Create optimizer using factory function + # Supports both standard AdamW and hybrid Muon+AdamW configurations + # See: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + self.optimizer = create_optimizer( + model=self.model, + optimizer_cfg=self.training_cfg.optimizer, + lr_cfg=self.training_cfg.learning_rate_scheduling, + batch_size_total=self.get_batch_size_total(self.batch_size_per_gpu), ) self.grad_scaler = torch.amp.GradScaler("cuda") @@ -625,6 +636,12 @@ def _get_full_model_state_dict(self): def _get_full_optimizer_state_dict(self): is_rank_zero = is_root() + + # Handle CompositeOptimizer (Muon+AdamW) separately + if isinstance(self.optimizer, CompositeOptimizer): + return self._get_full_composite_optimizer_state_dict(is_rank_zero) + + # Standard optimizer (AdamW) handling sharded_sd = self.optimizer.state_dict() sharded_state = sharded_sd["state"] full_state = {} @@ -653,6 +670,50 @@ def _get_full_optimizer_state_dict(self): else: return {} + def _get_full_composite_optimizer_state_dict(self, is_rank_zero: bool): + """ + Get full optimizer state dict for CompositeOptimizer (Muon+AdamW). + + Handles DTensor consolidation for both sub-optimizers. + """ + + def consolidate_optimizer_state(optimizer): + """Consolidate sharded state from a single optimizer.""" + sharded_sd = optimizer.state_dict() + sharded_state = sharded_sd["state"] + full_state = {} + for group_id, sharded_group in sharded_state.items(): + group_state = {} + for attr, sharded_tensor in sharded_group.items(): + if isinstance(sharded_tensor, DTensor): + full_tensor = sharded_tensor.full_tensor() + else: + full_tensor = sharded_tensor + if is_rank_zero: + group_state[attr] = full_tensor.cpu() + else: + del full_tensor + if is_rank_zero: + full_state[group_id] = group_state + else: + del group_state + if is_rank_zero: + return { + "param_groups": sharded_sd["param_groups"], + "state": full_state, + } + return {} + + if is_rank_zero: + return { + "optimizer_type": "composite_muon_adamw", + "muon": consolidate_optimizer_state(self.optimizer.muon_optimizer), + "adamw": consolidate_optimizer_state(self.optimizer.adamw_optimizer), + "muon_lr_multiplier": self.optimizer.muon_lr_multiplier, + } + else: + return {} + def save_model(self, mini_epoch: int, name=None): # Saving at mini_epoch == max_mini_epoch means that we are saving the latest checkpoint. max_mini_epoch = self.training_cfg.num_mini_epochs diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 000000000..6fd6422f8 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,447 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Tests for the optimizer module.""" + +import pytest +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +from weathergen.train.optimizer import ( + ADAMW_PATTERNS, + CompositeOptimizer, + MuonCustom, + classify_muon_params, + create_optimizer, +) + + +class DummyTransformerBlock(nn.Module): + """Simple transformer-like model for testing parameter classification.""" + + def __init__(self, dim: int = 64, num_heads: int = 4): + super().__init__() + self.dim = dim + self.num_heads = num_heads + + # Attention components (should be Muon-eligible) + self.proj_heads_q = nn.Linear(dim, dim, bias=False) + self.proj_heads_k = nn.Linear(dim, dim, bias=False) + self.proj_heads_v = nn.Linear(dim, dim, bias=False) + self.proj_out = nn.Linear(dim, dim, bias=False) + + # MLP components (should be Muon-eligible) + self.mlp_fc1 = nn.Linear(dim, dim * 4, bias=False) + self.mlp_fc2 = nn.Linear(dim * 4, dim, bias=False) + + # Embeddings (should be AdamW) + self.embed_target_coords = nn.Linear(3, dim, bias=False) + self.embeds = nn.Embedding(100, dim) + + # Prediction heads (should be AdamW) + self.pred_heads = nn.Linear(dim, 10, bias=False) + self.latent_heads = nn.Linear(dim, dim, bias=False) + + # Biases and norms (should be AdamW) + self.bias = nn.Parameter(torch.zeros(dim)) + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + return x + + +class SimpleMLP(nn.Module): + """Simple MLP for testing optimizer steps.""" + + def __init__(self, input_dim: int = 10, hidden_dim: int = 32, output_dim: int = 5): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + self.embed = nn.Embedding(100, hidden_dim) + + def forward(self, x): + return self.fc2(torch.relu(self.fc1(x))) + + +@pytest.fixture +def dummy_model(): + """Create a dummy transformer model for testing.""" + return DummyTransformerBlock(dim=64, num_heads=4) + + +@pytest.fixture +def simple_model(): + """Create a simple MLP model for testing optimizer steps.""" + return SimpleMLP(input_dim=10, hidden_dim=32, output_dim=5) + + +@pytest.fixture +def optimizer_cfg(): + """Create a standard optimizer config.""" + return OmegaConf.create({ + "type": "adamw", + "grad_clip": 1.0, + "weight_decay": 0.1, + "adamw": { + "beta1": 0.975, + "beta2": 0.9875, + "eps": 2e-08, + }, + "muon": { + "lr_multiplier": 20.0, + "momentum": 0.95, + "nesterov": True, + "weight_decay": 0.1, + }, + }) + + +@pytest.fixture +def lr_cfg(): + """Create a standard LR config.""" + return OmegaConf.create({ + "lr_start": 1e-6, + "lr_max": 5e-5, + }) + + +class TestClassifyMuonParams: + """Tests for the classify_muon_params function.""" + + def test_classification_separates_params(self, dummy_model): + """Test that parameters are correctly separated into Muon and AdamW groups.""" + muon_params, adamw_params, muon_names, adamw_names = classify_muon_params(dummy_model) + + # Check that all trainable params are classified + total_params = sum(1 for p in dummy_model.parameters() if p.requires_grad) + assert len(muon_params) + len(adamw_params) == total_params + + # Check names match params count + assert len(muon_params) == len(muon_names) + assert len(adamw_params) == len(adamw_names) + + def test_attention_weights_are_muon(self, dummy_model): + """Test that attention Q/K/V/O weights are classified as Muon-eligible.""" + _, _, muon_names, _ = classify_muon_params(dummy_model) + + # These should be in Muon group + expected_muon = ["proj_heads_q", "proj_heads_k", "proj_heads_v", "proj_out"] + for name in expected_muon: + assert any(name in muon_name for muon_name in muon_names), f"{name} should be Muon" + + def test_mlp_weights_are_muon(self, dummy_model): + """Test that MLP linear weights are classified as Muon-eligible.""" + _, _, muon_names, _ = classify_muon_params(dummy_model) + + # MLP weights should be Muon + assert any("mlp_fc1" in name for name in muon_names) + assert any("mlp_fc2" in name for name in muon_names) + + def test_embeddings_are_adamw(self, dummy_model): + """Test that embedding parameters are classified as AdamW-eligible.""" + _, _, _, adamw_names = classify_muon_params(dummy_model) + + # These should be in AdamW group + expected_adamw = ["embed_target_coords", "embeds"] + for name in expected_adamw: + assert any(name in adamw_name for adamw_name in adamw_names), f"{name} should be AdamW" + + def test_pred_heads_are_adamw(self, dummy_model): + """Test that prediction heads are classified as AdamW-eligible.""" + _, _, _, adamw_names = classify_muon_params(dummy_model) + + assert any("pred_heads" in name for name in adamw_names) + assert any("latent_heads" in name for name in adamw_names) + + def test_1d_params_are_adamw(self, dummy_model): + """Test that 1D parameters (biases, norm weights) are AdamW-eligible.""" + _, adamw_params, _, adamw_names = classify_muon_params(dummy_model) + + # Check that bias and norm params are in AdamW + assert any("bias" in name for name in adamw_names) + assert any("norm" in name for name in adamw_names) + + # All 1D params should be in AdamW + for param in adamw_params: + if param.ndim < 2: + assert True # 1D params are correctly in AdamW + + def test_frozen_params_excluded(self, dummy_model): + """Test that frozen parameters are excluded from classification.""" + # Freeze some parameters + dummy_model.proj_heads_q.weight.requires_grad = False + dummy_model.embed_target_coords.weight.requires_grad = False + + muon_params, adamw_params, muon_names, adamw_names = classify_muon_params(dummy_model) + + # Frozen params should not appear + assert "proj_heads_q.weight" not in muon_names + assert "embed_target_coords.weight" not in adamw_names + + # Total should be reduced + total_trainable = sum(1 for p in dummy_model.parameters() if p.requires_grad) + assert len(muon_params) + len(adamw_params) == total_trainable + + +class TestCreateOptimizer: + """Tests for the create_optimizer factory function.""" + + def test_creates_adamw_by_default(self, simple_model, optimizer_cfg, lr_cfg): + """Test that AdamW is created when type is 'adamw'.""" + optimizer_cfg.type = "adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + assert isinstance(optimizer, torch.optim.AdamW) + + def test_creates_composite_for_muon_adamw(self, simple_model, optimizer_cfg, lr_cfg): + """Test that CompositeOptimizer is created when type is 'muon_adamw'.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + assert isinstance(optimizer, CompositeOptimizer) + + def test_raises_for_unknown_type(self, simple_model, optimizer_cfg, lr_cfg): + """Test that unknown optimizer type raises ValueError.""" + optimizer_cfg.type = "unknown" + + with pytest.raises(ValueError, match="Unknown optimizer type"): + create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + def test_batch_size_scaling(self, simple_model, optimizer_cfg, lr_cfg): + """Test that betas are scaled based on batch size.""" + optimizer_cfg.type = "adamw" + + opt_small = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=1) + opt_large = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=16) + + # Larger batch should have different betas (closer to target) + beta1_small = opt_small.param_groups[0]["betas"][0] + beta1_large = opt_large.param_groups[0]["betas"][0] + + # With larger batch, beta1 should be smaller (more momentum decay) + assert beta1_large < beta1_small + + +class TestCompositeOptimizer: + """Tests for the CompositeOptimizer class.""" + + def test_step_updates_both_optimizers(self, simple_model, optimizer_cfg, lr_cfg): + """Test that step() updates parameters from both optimizers.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # Create dummy input and compute loss + x = torch.randn(4, 10) + output = simple_model(x) + loss = output.sum() + + # Store initial params + initial_params = {name: p.clone() for name, p in simple_model.named_parameters()} + + # Backward and step + loss.backward() + optimizer.step() + + # Check that params changed + params_changed = False + for name, p in simple_model.named_parameters(): + if not torch.equal(p, initial_params[name]): + params_changed = True + break + + assert params_changed + + def test_zero_grad_clears_both(self, simple_model, optimizer_cfg, lr_cfg): + """Test that zero_grad() clears gradients from both optimizers.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # Create gradients + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + + # Verify grads exist + has_grads = any(p.grad is not None for p in simple_model.parameters()) + assert has_grads + + # Zero grads + optimizer.zero_grad() + + # Verify grads are cleared + for p in simple_model.parameters(): + assert p.grad is None or p.grad.abs().sum() == 0 + + def test_state_dict_roundtrip(self, simple_model, optimizer_cfg, lr_cfg): + """Test that state dict can be saved and loaded.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # Take a step to populate state + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer.step() + + # Save state + state_dict = optimizer.state_dict() + + # Verify structure + assert "optimizer_type" in state_dict + assert state_dict["optimizer_type"] == "composite_muon_adamw" + assert "muon" in state_dict + assert "adamw" in state_dict + assert "muon_lr_multiplier" in state_dict + + # Create new optimizer and load state + optimizer2 = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + optimizer2.load_state_dict(state_dict) + + # Take another step - should not raise + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer2.step() + + def test_param_groups_combined(self, simple_model, optimizer_cfg, lr_cfg): + """Test that param_groups contains groups from both optimizers.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # Should have groups from both Muon and AdamW + assert len(optimizer.param_groups) >= 2 + + # Check that is_muon flag exists + has_muon_group = any(g.get("is_muon", False) for g in optimizer.param_groups) + has_adamw_group = any(not g.get("is_muon", True) for g in optimizer.param_groups) + + assert has_muon_group + assert has_adamw_group + + +class TestMuonCustom: + """Tests for the custom Muon optimizer implementation.""" + + def test_step_updates_params(self, simple_model): + """Test that Muon step updates parameters.""" + # Get only 2D params that will have gradients (fc1, fc2 weights) + # Exclude embedding since it's not used in the forward pass + params = [ + p for name, p in simple_model.named_parameters() + if p.ndim >= 2 and "embed" not in name + ] + optimizer = MuonCustom(params, lr=0.01, momentum=0.95) + + # Create dummy gradients + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + + # Store initial values + initial_values = [p.clone() for p in params] + + # Step + optimizer.step() + + # Check params with gradients changed + for i, p in enumerate(params): + if p.grad is not None: + assert not torch.equal(p, initial_values[i]), f"Param {i} was not updated" + + def test_momentum_buffer_created(self, simple_model): + """Test that momentum buffer is created after first step.""" + # Get params that will have gradients + params = [ + p for name, p in simple_model.named_parameters() + if p.ndim >= 2 and "embed" not in name + ] + optimizer = MuonCustom(params, lr=0.01, momentum=0.95) + + # Initially no state + assert all(len(optimizer.state[p]) == 0 for p in params) + + # Create gradients and step + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer.step() + + # Now should have momentum buffer for params with gradients + for p in params: + if p.grad is not None: + assert "momentum_buffer" in optimizer.state[p] + + def test_weight_decay_applied(self): + """Test that weight decay is applied to parameters.""" + # Simple 2D parameter + param = nn.Parameter(torch.ones(4, 4)) + optimizer = MuonCustom([param], lr=0.1, momentum=0.0, weight_decay=0.1) + + # Set gradient to zero (only weight decay should affect) + param.grad = torch.zeros_like(param) + + initial_norm = param.norm().item() + optimizer.step() + final_norm = param.norm().item() + + # Weight decay should reduce norm (since grad=0, only decay acts) + assert final_norm < initial_norm + + def test_nesterov_momentum(self): + """Test that Nesterov momentum produces different results than standard momentum.""" + torch.manual_seed(42) + + # Create two identical params + param1 = nn.Parameter(torch.randn(4, 4)) + param2 = nn.Parameter(param1.clone()) + + opt_standard = MuonCustom([param1], lr=0.1, momentum=0.9, nesterov=False) + opt_nesterov = MuonCustom([param2], lr=0.1, momentum=0.9, nesterov=True) + + # Same gradient + grad = torch.randn(4, 4) + param1.grad = grad.clone() + param2.grad = grad.clone() + + # Multiple steps + for _ in range(3): + opt_standard.step() + opt_nesterov.step() + param1.grad = grad.clone() + param2.grad = grad.clone() + + # Results should differ + assert not torch.allclose(param1, param2) + + +class TestAdamWPatterns: + """Tests for the ADAMW_PATTERNS constant.""" + + def test_patterns_match_expected_names(self): + """Test that patterns match the expected parameter name patterns.""" + expected_patterns = [ + "embed_target_coords", + "embeds.", + "embed.", + "pred_heads", + "latent_heads", + "q_cells", + "bilin", + "norm", + "bias", + ] + + for pattern in expected_patterns: + assert pattern in ADAMW_PATTERNS + + def test_class_token_in_patterns(self): + """Test that class_token and register_token are in patterns.""" + assert "class_token" in ADAMW_PATTERNS + assert "register_token" in ADAMW_PATTERNS From a0e5b385d0e7f7c9d8a729c149cf2c045b64e312 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Mon, 2 Feb 2026 20:57:17 +0000 Subject: [PATCH 02/38] Fix merge issue --- src/weathergen/model/encoder.py | 28 ---------------------------- src/weathergen/train/trainer.py | 18 ------------------ 2 files changed, 46 deletions(-) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 9cb2e624b..c51124b80 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -153,34 +153,6 @@ def assimilate_local_project_chunked(self, tokens, tokens_global, cell_lens, q_c # combined cell lens for all tokens in batch across all input steps zero_pad = torch.zeros(1, device=tokens.device, dtype=torch.int32) -<<<<<<< Updated upstream - # subdivision factor for required splitting -======= - # Identify non-empty cells (sorted indices from torch.where) - non_empty_mask = cell_lens > 0 - num_non_empty = non_empty_mask.sum().item() - non_empty_indices = torch.where(non_empty_mask)[0] - - if num_non_empty == 0: - assert False, "No non-empty cells found - cannot process empty input" - - # Gather cell_lens and tokens_global for non-empty cells only - # non_empty_indices is sorted, so output will be in original cell order - cell_lens_non_empty = cell_lens[non_empty_indices] - tokens_global_non_empty = tokens_global[non_empty_indices] - - # Reorder tokens: gather tokens for non-empty cells in their original order - cumsum_orig = torch.cat([zero_pad, cell_lens.cumsum(0)]) - token_indices = torch.cat( - [ - torch.arange(cumsum_orig[idx], cumsum_orig[idx + 1], device=tokens.device) - for idx in non_empty_indices - ] - ) - tokens_reordered = tokens[token_indices] - - # Fixed number of chunks based on healpix cells (same for all ranks) ->>>>>>> Stashed changes clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) tokens_global_unmasked = [] posteriors = [] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 969f683ac..2b70c7d6e 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -47,24 +47,6 @@ logger = logging.getLogger(__name__) -<<<<<<< Updated upstream -======= -DEBUG = False -if DEBUG: - - def debug_barrier(name, rank): - """Simple checkpoint function - call at key points in training loop""" - _debug_start_time = time.time() - torch.cuda.synchronize() - elapsed = time.time() - _debug_start_time - print(f"[{elapsed:8.2f}s] [Rank {rank}] CHECKPOINT: {name}", flush=True) -else: - - def debug_barrier(name, rank): - return - - ->>>>>>> Stashed changes class Trainer(TrainerBase): def __init__(self, train_log_freq: Config): TrainerBase.__init__(self) From 7e95beffbd23943c9f829643b1b9ea1fae1ffa73 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Mon, 2 Feb 2026 22:01:14 +0100 Subject: [PATCH 03/38] Claude fixing things --- src/weathergen/train/optimizer.py | 45 +++++++++++++++----------- tests/test_optimizer.py | 52 +++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 18 deletions(-) diff --git a/src/weathergen/train/optimizer.py b/src/weathergen/train/optimizer.py index b142d9dc2..b099140e4 100644 --- a/src/weathergen/train/optimizer.py +++ b/src/weathergen/train/optimizer.py @@ -294,15 +294,14 @@ def _create_muon_optimizer( ) -class CompositeOptimizer: +class CompositeOptimizer(Optimizer): """ Composite optimizer that combines Muon and AdamW for different parameter groups. Muon is used for 2D hidden layer weights, AdamW for embeddings and heads. This class wraps both optimizers and provides a unified interface. - Note: This class does not inherit from torch.optim.Optimizer to avoid - conflicts with state management. It provides the same interface. + Inherits from Optimizer for compatibility with PyTorch LR schedulers. """ def __init__( @@ -323,18 +322,23 @@ def __init__( self.adamw_optimizer = adamw_optimizer self.muon_lr_multiplier = muon_lr_multiplier - # Combine param_groups from both optimizers for unified interface - self._param_groups = muon_optimizer.param_groups + adamw_optimizer.param_groups + # Manually initialize Optimizer base class attributes without calling __init__ + # This avoids the param_groups setup that would conflict with our combined groups + from collections import OrderedDict, defaultdict - @property - def param_groups(self) -> list: - """Return combined param groups from both optimizers.""" - return self._param_groups + self.defaults = {} + self._optimizer_step_pre_hooks = OrderedDict() + self._optimizer_step_post_hooks = OrderedDict() + self._optimizer_state_dict_pre_hooks = OrderedDict() + self._optimizer_state_dict_post_hooks = OrderedDict() + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + self._optimizer_load_state_dict_post_hooks = OrderedDict() + + # Combined param_groups from both optimizers + self.param_groups = muon_optimizer.param_groups + adamw_optimizer.param_groups - @param_groups.setter - def param_groups(self, value): - """Set param groups (needed for lr_scheduler compatibility).""" - self._param_groups = value + # State is a combined view (we override the property below) + self._state = defaultdict(dict) def step(self, closure=None): """ @@ -408,14 +412,19 @@ def load_state_dict(self, state_dict: dict): def state(self) -> dict: """ Return combined state from both optimizers. + + This provides a unified view of optimizer state for checkpointing. """ - combined_state = {} - muon_state = self.muon_optimizer.state - adamw_state = self.adamw_optimizer.state - combined_state.update(muon_state) - combined_state.update(adamw_state) + combined_state = dict(self._state) + combined_state.update(self.muon_optimizer.state) + combined_state.update(self.adamw_optimizer.state) return combined_state + @state.setter + def state(self, value): + """Set state (needed for Optimizer base class compatibility).""" + self._state = value + class MuonCustom(Optimizer): """ diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 6fd6422f8..eb73a989f 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -421,6 +421,58 @@ def test_nesterov_momentum(self): assert not torch.allclose(param1, param2) +class TestLRSchedulerCompatibility: + """Tests for LR scheduler compatibility with CompositeOptimizer.""" + + def test_works_with_onecyclelr(self, simple_model, optimizer_cfg, lr_cfg): + """Test that CompositeOptimizer works with OneCycleLR scheduler.""" + from torch.optim.lr_scheduler import OneCycleLR + + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # This should not raise TypeError (isinstance check) + # cycle_momentum=False since CompositeOptimizer has mixed defaults + scheduler = OneCycleLR( + optimizer, + max_lr=0.01, + total_steps=100, + cycle_momentum=False, + ) + + # Take a few steps + for _ in range(5): + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + def test_works_with_linearlr(self, simple_model, optimizer_cfg, lr_cfg): + """Test that CompositeOptimizer works with LinearLR scheduler.""" + from torch.optim.lr_scheduler import LinearLR + + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # This should not raise TypeError + scheduler = LinearLR( + optimizer, + start_factor=0.1, + total_iters=100, + ) + + # Take a few steps + for _ in range(5): + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + class TestAdamWPatterns: """Tests for the ADAMW_PATTERNS constant.""" From 2e1bd76b0a7175c84c4faa28381f60956e1074f4 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Mon, 2 Feb 2026 22:05:34 +0100 Subject: [PATCH 04/38] Fixing Betas expected everywhere --- src/weathergen/train/optimizer.py | 14 +++++++++++++- tests/test_optimizer.py | 6 +++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/weathergen/train/optimizer.py b/src/weathergen/train/optimizer.py index b099140e4..933896904 100644 --- a/src/weathergen/train/optimizer.py +++ b/src/weathergen/train/optimizer.py @@ -326,7 +326,13 @@ def __init__( # This avoids the param_groups setup that would conflict with our combined groups from collections import OrderedDict, defaultdict - self.defaults = {} + # Set defaults with betas for LR scheduler compatibility (OneCycleLR checks this) + # Use AdamW's betas since that's the more common scheduler interaction + adamw_betas = adamw_optimizer.defaults.get("betas", (0.9, 0.999)) + self.defaults = { + "betas": adamw_betas, + "momentum": muon_optimizer.defaults.get("momentum", 0.95), + } self._optimizer_step_pre_hooks = OrderedDict() self._optimizer_step_post_hooks = OrderedDict() self._optimizer_state_dict_pre_hooks = OrderedDict() @@ -334,6 +340,12 @@ def __init__( self._optimizer_load_state_dict_pre_hooks = OrderedDict() self._optimizer_load_state_dict_post_hooks = OrderedDict() + # Ensure all param groups have betas for OneCycleLR compatibility + # OneCycleLR with cycle_momentum=True tries to modify betas on ALL groups + for group in muon_optimizer.param_groups: + if "betas" not in group: + group["betas"] = adamw_betas + # Combined param_groups from both optimizers self.param_groups = muon_optimizer.param_groups + adamw_optimizer.param_groups diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index eb73a989f..fd24508a4 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -431,13 +431,13 @@ def test_works_with_onecyclelr(self, simple_model, optimizer_cfg, lr_cfg): optimizer_cfg.type = "muon_adamw" optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) - # This should not raise TypeError (isinstance check) - # cycle_momentum=False since CompositeOptimizer has mixed defaults + # This should not raise TypeError (isinstance check) or ValueError (momentum check) + # CompositeOptimizer now has proper defaults with betas and momentum scheduler = OneCycleLR( optimizer, max_lr=0.01, total_steps=100, - cycle_momentum=False, + cycle_momentum=True, # Default - requires betas or momentum in defaults ) # Take a few steps From 84fa7d156bc3ddc83fcf7442023aad932d85da27 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Tue, 3 Feb 2026 14:32:43 +0100 Subject: [PATCH 05/38] First commit --- config/default_config.yml | 15 ++ src/weathergen/model/attention.py | 152 +++++++++++++ src/weathergen/model/engines.py | 99 ++++++++- src/weathergen/model/layers.py | 73 ++++++- tests/test_layer_scale.py | 350 ++++++++++++++++++++++++++++++ 5 files changed, 680 insertions(+), 9 deletions(-) create mode 100644 tests/test_layer_scale.py diff --git a/config/default_config.yml b/config/default_config.yml index 12b18084d..7a2a55845 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -85,6 +85,21 @@ freeze_modules: "" norm_type: "LayerNorm" +# Residual scaling (LayerScale / ReZero) +# Options: null (disabled), 0.0 (ReZero), 1e-5 (LayerScale default) +# LayerScale applies per-channel learned scaling before residual addition +# ReZero initializes scaling to 0 for gradual signal introduction +layer_scale_init: null + +# Stochastic Depth rates per component (0.0 = disabled) +# Randomly drops entire residual paths during training for regularization +# Rates increase linearly with depth: 0.0 for early layers, up to specified rate for deeper layers +stochastic_depth: + ae_local: 0.0 + ae_global: 0.0 + ae_aggregation: 0.0 + forecasting: 0.0 + ##################################### streams_directory: "./config/streams/era5_1deg/" diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 99606bdce..0743a1af0 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -13,6 +13,7 @@ from flash_attn import flash_attn_func, flash_attn_varlen_func from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from weathergen.model.layers import LayerScale, StochasticDepth from weathergen.model.norms import AdaLayerNorm, RMSNorm @@ -31,6 +32,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiSelfAttentionHeadVarlen, self).__init__() @@ -66,6 +69,16 @@ def __init__( self.dtype = attention_dtype + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported at the moment" def forward(self, x, x_lens, ada_ln_aux=None): @@ -99,6 +112,14 @@ def forward(self, x, x_lens, ada_ln_aux=None): out = self.proj_out(outs.flatten(-2, -1)) + # Apply LayerScale before residual + if self.layer_scale is not None: + out = self.layer_scale(out) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + out = self.drop_path(out) + if self.with_residual: out = out + x_in @@ -119,6 +140,8 @@ def __init__( softcap=0.0, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiSelfAttentionHeadVarlenFlex, self).__init__() @@ -149,6 +172,16 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported at the moment" def att(qs, ks, vs, x_mask): @@ -174,6 +207,15 @@ def forward(self, x, x_lens=None): outs = self.compiled_flex_attention(qs, ks, vs).transpose(1, 2).squeeze() out = self.dropout(self.proj_out(outs.flatten(-2, -1))) + + # Apply LayerScale before residual + if self.layer_scale is not None: + out = self.layer_scale(out) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + out = self.drop_path(out) + if self.with_residual: out = out + x_in @@ -197,6 +239,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -230,6 +274,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported." # define block mask @@ -256,6 +311,15 @@ def forward(self, x, ada_ln_aux=None): outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) out = self.proj_out(self.dropout(outs.flatten(-2, -1))) + + # Apply LayerScale before residual + if self.layer_scale is not None: + out = self.layer_scale(out) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + out = self.drop_path(out) + if self.with_residual: out = x_in + out @@ -278,6 +342,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiCrossAttentionHeadVarlen, self).__init__() @@ -318,6 +384,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed_q, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported at the moment" def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): @@ -355,6 +432,15 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): assert False outs = self.proj_out(outs.flatten(-2, -1)) + + # Apply LayerScale before residual + if self.layer_scale is not None: + outs = self.layer_scale(outs) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + outs = self.drop_path(outs) + if self.with_residual: outs = x_q_in + outs @@ -378,6 +464,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiCrossAttentionHeadVarlenSlicedQ, self).__init__() @@ -425,6 +513,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed_q, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported at the moment" def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): @@ -466,6 +565,15 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): ] outs = self.proj_out(torch.stack(outs).transpose(1, 0).flatten(-2, -1)) + + # Apply LayerScale before residual + if self.layer_scale is not None: + outs = self.layer_scale(outs) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + outs = self.drop_path(outs) + if self.with_residual: outs = x_q_in + outs.reshape(x_q_in.shape) @@ -487,6 +595,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiSelfAttentionHead, self).__init__() @@ -521,6 +631,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + if with_flash: self.att = torch.nn.functional.scaled_dot_product_attention else: @@ -546,6 +667,15 @@ def forward(self, x, ada_ln_aux=None): outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=dropout_rate) out = self.proj_out(outs.flatten(-2, -1)) + + # Apply LayerScale before residual + if self.layer_scale is not None: + out = self.layer_scale(out) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + out = self.drop_path(out) + if self.with_residual: out = out + x_in @@ -566,6 +696,8 @@ def __init__( norm_type="LayerNorm", norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiCrossAttentionHead, self).__init__() @@ -602,6 +734,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed_q, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + self.att = torch.nn.functional.scaled_dot_product_attention self.softmax = torch.nn.Softmax(dim=-1) @@ -624,6 +767,15 @@ def forward(self, x_q, x_kv): outs = self.att(qs, ks, vs).transpose(2, 1) outs = self.dropout(self.proj_out(outs.flatten(-2, -1))) + + # Apply LayerScale before residual + if self.layer_scale is not None: + outs = self.layer_scale(outs) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + outs = self.drop_path(outs) + if self.with_residual: outs = x_q_in + outs diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 24fb794a8..57035780b 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -138,7 +138,16 @@ def __init__(self, cf: Config) -> None: self.cf = cf self.ae_local_blocks = torch.nn.ModuleList() - for _ in range(self.cf.ae_local_num_blocks): + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + max_drop_rate = stochastic_depth_cfg.get("ae_local", 0.0) if stochastic_depth_cfg else 0.0 + num_blocks = self.cf.ae_local_num_blocks + + for i in range(num_blocks): + # Linear scaling of drop rate with depth + drop_rate = max_drop_rate * (i / max(num_blocks - 1, 1)) if num_blocks > 1 else 0.0 + self.ae_local_blocks.append( MultiSelfAttentionHeadVarlen( self.cf.ae_local_dim_embed, @@ -149,6 +158,8 @@ def __init__(self, cf: Config) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) self.ae_local_blocks.append( @@ -159,6 +170,8 @@ def __init__(self, cf: Config) -> None: dropout_rate=self.cf.ae_local_dropout_rate, norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) @@ -181,6 +194,15 @@ def __init__(self, cf: Config) -> None: self.cf = cf self.ae_adapter = torch.nn.ModuleList() + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + # Use ae_local rate for adapter (transition layer) + max_drop_rate = stochastic_depth_cfg.get("ae_local", 0.0) if stochastic_depth_cfg else 0.0 + ae_adapter_num_blocks = cf.get("ae_adapter_num_blocks", 2) + + # First block + drop_rate = 0.0 if ae_adapter_num_blocks <= 1 else 0.0 self.ae_adapter.append( MultiCrossAttentionHeadVarlenSlicedQ( self.cf.ae_global_dim_embed, @@ -195,11 +217,19 @@ def __init__(self, cf: Config) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) - ae_adapter_num_blocks = cf.get("ae_adapter_num_blocks", 2) - for _ in range(ae_adapter_num_blocks - 1): + for i in range(ae_adapter_num_blocks - 1): + # Linear scaling of drop rate with depth + drop_rate = ( + max_drop_rate * ((i + 1) / max(ae_adapter_num_blocks - 1, 1)) + if ae_adapter_num_blocks > 1 + else 0.0 + ) + self.ae_adapter.append( MLP( self.cf.ae_global_dim_embed, @@ -208,6 +238,8 @@ def __init__(self, cf: Config) -> None: dropout_rate=self.cf.ae_adapter_dropout_rate, norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) self.ae_adapter.append( @@ -224,6 +256,8 @@ def __init__(self, cf: Config) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) @@ -257,12 +291,23 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: self.ae_aggregation_blocks = torch.nn.ModuleList() + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + max_drop_rate = ( + stochastic_depth_cfg.get("ae_aggregation", 0.0) if stochastic_depth_cfg else 0.0 + ) + num_blocks = self.cf.ae_aggregation_num_blocks + global_rate = int(1 / self.cf.ae_aggregation_att_dense_rate) - for i in range(self.cf.ae_aggregation_num_blocks): + for i in range(num_blocks): + # Linear scaling of drop rate with depth + drop_rate = max_drop_rate * (i / max(num_blocks - 1, 1)) if num_blocks > 1 else 0.0 + ## Alternate between local and global attention # as controlled by cf.ae_dense_local_att_dense_rate # Last block is always global attention - if i % global_rate == 0 or i + 1 == self.cf.ae_aggregation_num_blocks: + if i % global_rate == 0 or i + 1 == num_blocks: self.ae_aggregation_blocks.append( MultiSelfAttentionHeadVarlen( self.cf.ae_global_dim_embed, @@ -273,6 +318,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) else: @@ -289,6 +336,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) # MLP block @@ -301,6 +350,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: hidden_factor=self.cf.ae_aggregation_mlp_hidden_factor, norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) @@ -329,12 +380,21 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: self.ae_global_blocks = torch.nn.ModuleList() + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + max_drop_rate = stochastic_depth_cfg.get("ae_global", 0.0) if stochastic_depth_cfg else 0.0 + num_blocks = self.cf.ae_global_num_blocks + global_rate = int(1 / self.cf.ae_global_att_dense_rate) - for i in range(self.cf.ae_global_num_blocks): + for i in range(num_blocks): + # Linear scaling of drop rate with depth + drop_rate = max_drop_rate * (i / max(num_blocks - 1, 1)) if num_blocks > 1 else 0.0 + ## Alternate between local and global attention # as controlled by cf.ae_global_att_dense_rate # Last block is always global attention - if i % global_rate == 0 or i + 1 == self.cf.ae_global_num_blocks: + if i % global_rate == 0 or i + 1 == num_blocks: self.ae_global_blocks.append( MultiSelfAttentionHead( self.cf.ae_global_dim_embed, @@ -345,6 +405,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) else: @@ -360,6 +422,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) # MLP block @@ -372,6 +436,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: hidden_factor=self.cf.ae_global_mlp_hidden_factor, norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) if self.cf.get("ae_global_trailing_layer_norm", False): @@ -400,9 +466,20 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + max_drop_rate = ( + stochastic_depth_cfg.get("forecasting", 0.0) if stochastic_depth_cfg else 0.0 + ) + num_blocks = self.cf.fe_num_blocks + global_rate = int(1 / self.cf.forecast_att_dense_rate) if mode_cfg.get("forecast", {}).get("policy") is not None: - for i in range(self.cf.fe_num_blocks): + for i in range(num_blocks): + # Linear scaling of drop rate with depth + drop_rate = max_drop_rate * (i / max(num_blocks - 1, 1)) if num_blocks > 1 else 0.0 + # Alternate between global and local attention if (i % global_rate == 0) or i + 1 == self.cf.ae_global_num_blocks: self.fe_blocks.append( @@ -416,6 +493,8 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) else: @@ -432,6 +511,8 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) # Add MLP block @@ -444,6 +525,8 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) # Optionally, add LayerNorm after i-th layer diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..7238d4799 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -14,6 +14,51 @@ from weathergen.model.norms import AdaLayerNorm, RMSNorm +class LayerScale(nn.Module): + """Per-channel learnable scaling, as in CaiT (Touvron et al., 2021). + + Applies a learned per-channel scaling factor to the input. When used before + residual connections, it allows the network to gradually incorporate new + layer contributions during training. + + Args: + dim: Number of channels/features to scale. + init_value: Initial value for the scaling factors. Use 1e-5 for LayerScale + or 0.0 for ReZero initialization. + """ + + def __init__(self, dim: int, init_value: float = 1e-5): + super().__init__() + self.gamma = nn.Parameter(init_value * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.gamma + + +class StochasticDepth(nn.Module): + """Stochastic Depth / DropPath regularization (Huang et al., 2016). + + Randomly drops entire residual paths during training. This acts as a form + of regularization and enables training deeper networks. + + Args: + drop_prob: Probability of dropping the path. 0.0 means no dropping. + """ + + def __init__(self, drop_prob: float = 0.0): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.training or self.drop_prob == 0.0: + return x + keep_prob = 1.0 - self.drop_prob + # Per-sample dropout (batch dimension) + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob) + return x * mask / keep_prob # Scale to maintain expected value + + class NamedLinear(torch.nn.Module): def __init__(self, name: str | None = None, **kwargs): super(NamedLinear, self).__init__() @@ -43,8 +88,16 @@ def __init__( dim_aux=None, norm_eps=1e-5, name: str | None = None, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): - """Constructor""" + """Constructor + + Args: + layer_scale_init: If not None, applies LayerScale with this init value. + Use 1e-5 for LayerScale, 0.0 for ReZero. + stochastic_depth_rate: Probability of dropping this block during training. + """ super(MLP, self).__init__() @@ -79,12 +132,30 @@ def __init__( self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_out, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + def forward(self, *args): x, x_in, aux = args[0], args[0], args[-1] for i, layer in enumerate(self.layers): x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + # Apply LayerScale before residual + if self.layer_scale is not None: + x = self.layer_scale(x) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + x = self.drop_path(x) + if self.with_residual: if x.shape[-1] == x_in.shape[-1]: x = x_in + x diff --git a/tests/test_layer_scale.py b/tests/test_layer_scale.py new file mode 100644 index 000000000..a62002ac1 --- /dev/null +++ b/tests/test_layer_scale.py @@ -0,0 +1,350 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Unit tests for LayerScale and StochasticDepth modules.""" + +import pytest +import torch + +from weathergen.model.layers import LayerScale, MLP, StochasticDepth + + +class TestLayerScale: + """Tests for the LayerScale module.""" + + def test_init_value(self): + """Test that gamma is initialized to the specified value.""" + dim = 64 + init_value = 1e-5 + layer_scale = LayerScale(dim, init_value) + + assert layer_scale.gamma.shape == (dim,) + assert torch.allclose(layer_scale.gamma, torch.full((dim,), init_value)) + + def test_init_value_rezero(self): + """Test ReZero initialization (init_value=0).""" + dim = 64 + layer_scale = LayerScale(dim, init_value=0.0) + + assert torch.allclose(layer_scale.gamma, torch.zeros(dim)) + + def test_forward_scaling(self): + """Test that forward applies per-channel scaling.""" + dim = 64 + batch_size = 8 + seq_len = 16 + init_value = 0.5 + + layer_scale = LayerScale(dim, init_value) + x = torch.randn(batch_size, seq_len, dim) + + out = layer_scale(x) + + expected = x * init_value + assert torch.allclose(out, expected) + + def test_forward_with_learned_gamma(self): + """Test forward with modified gamma values.""" + dim = 64 + layer_scale = LayerScale(dim, init_value=1.0) + + # Modify gamma + with torch.no_grad(): + layer_scale.gamma.fill_(2.0) + + x = torch.randn(8, 16, dim) + out = layer_scale(x) + + expected = x * 2.0 + assert torch.allclose(out, expected) + + def test_gradient_flow(self): + """Test that gradients flow through LayerScale.""" + dim = 64 + layer_scale = LayerScale(dim, init_value=1e-5) + x = torch.randn(8, 16, dim, requires_grad=True) + + out = layer_scale(x) + loss = out.sum() + loss.backward() + + assert x.grad is not None + assert layer_scale.gamma.grad is not None + + def test_output_shape(self): + """Test that output shape matches input shape.""" + dim = 64 + layer_scale = LayerScale(dim, init_value=1e-5) + + for shape in [(8, dim), (8, 16, dim), (8, 16, 32, dim)]: + x = torch.randn(*shape) + out = layer_scale(x) + assert out.shape == x.shape + + +class TestStochasticDepth: + """Tests for the StochasticDepth module.""" + + def test_init(self): + """Test initialization with drop probability.""" + drop_prob = 0.1 + sd = StochasticDepth(drop_prob) + assert sd.drop_prob == drop_prob + + def test_eval_mode_no_drop(self): + """Test that eval mode never drops (identity).""" + drop_prob = 0.9 # High drop prob + sd = StochasticDepth(drop_prob) + sd.eval() + + x = torch.randn(8, 16, 64) + out = sd(x) + + assert torch.equal(out, x) + + def test_train_mode_zero_prob(self): + """Test that zero drop probability is identity in train mode.""" + sd = StochasticDepth(drop_prob=0.0) + sd.train() + + x = torch.randn(8, 16, 64) + out = sd(x) + + assert torch.equal(out, x) + + def test_train_mode_high_prob(self): + """Test that very high drop probability drops most samples in train mode.""" + sd = StochasticDepth(drop_prob=0.99) + sd.train() + + torch.manual_seed(42) + x = torch.ones(100, 16, 64) + out = sd(x) + + # With 99% drop, most samples should be zero + zero_samples = (out.sum(dim=(1, 2)) == 0).sum().item() + assert zero_samples > 90 # At least 90 out of 100 should be dropped + + def test_expected_value_preservation(self): + """Test that expected value is preserved during training.""" + drop_prob = 0.3 + sd = StochasticDepth(drop_prob) + sd.train() + + torch.manual_seed(42) + x = torch.ones(1000, 16, 64) + + # Run many times to average + outputs = [] + for _ in range(1000): + outputs.append(sd(x).mean().item()) + + mean_output = sum(outputs) / len(outputs) + # Expected value should be approximately 1.0 (the input value) + assert abs(mean_output - 1.0) < 0.1 # Allow 10% tolerance + + def test_per_sample_dropping(self): + """Test that dropping is per-sample in batch dimension.""" + drop_prob = 0.5 + sd = StochasticDepth(drop_prob) + sd.train() + + torch.manual_seed(42) + batch_size = 100 + x = torch.ones(batch_size, 16, 64) + + out = sd(x) + + # Check that samples are either scaled or zero + sample_sums = out.sum(dim=(1, 2)) + expected_sum_scaled = 16 * 64 / (1 - drop_prob) + + for s in sample_sums: + # Each sample should be either 0 or scaled + assert s.item() == 0.0 or abs(s.item() - expected_sum_scaled) < 1e-4 + + def test_gradient_flow(self): + """Test that gradients flow through StochasticDepth.""" + sd = StochasticDepth(drop_prob=0.5) + sd.train() + + torch.manual_seed(42) # Ensure some samples are kept + x = torch.randn(8, 16, 64, requires_grad=True) + + out = sd(x) + loss = out.sum() + loss.backward() + + # Gradient should exist for kept samples + assert x.grad is not None + + def test_output_shape(self): + """Test that output shape matches input shape.""" + sd = StochasticDepth(drop_prob=0.5) + sd.train() + + for shape in [(8, 64), (8, 16, 64), (8, 16, 32, 64)]: + x = torch.randn(*shape) + out = sd(x) + assert out.shape == x.shape + + +class TestMLPWithLayerScaleAndStochasticDepth: + """Integration tests for MLP with LayerScale and StochasticDepth.""" + + def test_mlp_with_layer_scale(self): + """Test MLP with LayerScale enabled.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=1e-5, + ) + + assert mlp.layer_scale is not None + assert isinstance(mlp.layer_scale, LayerScale) + + x = torch.randn(8, 16, 64) + out = mlp(x) + + assert out.shape == x.shape + + def test_mlp_with_stochastic_depth(self): + """Test MLP with StochasticDepth enabled.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + stochastic_depth_rate=0.1, + ) + + assert mlp.drop_path is not None + assert isinstance(mlp.drop_path, StochasticDepth) + + mlp.train() + x = torch.randn(8, 16, 64) + out = mlp(x) + + assert out.shape == x.shape + + def test_mlp_with_both(self): + """Test MLP with both LayerScale and StochasticDepth.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=1e-5, + stochastic_depth_rate=0.1, + ) + + assert mlp.layer_scale is not None + assert mlp.drop_path is not None + + mlp.train() + x = torch.randn(8, 16, 64) + out = mlp(x) + + assert out.shape == x.shape + + def test_mlp_without_features(self): + """Test MLP with neither feature (backward compatibility).""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + ) + + assert mlp.layer_scale is None + assert mlp.drop_path is None + + x = torch.randn(8, 16, 64) + out = mlp(x) + + assert out.shape == x.shape + + def test_mlp_layer_scale_in_state_dict(self): + """Test that LayerScale parameters appear in state_dict.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=1e-5, + ) + + state_dict = mlp.state_dict() + assert "layer_scale.gamma" in state_dict + + def test_mlp_gradient_flow_with_features(self): + """Test gradient flow through MLP with LayerScale and StochasticDepth.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=1e-5, + stochastic_depth_rate=0.1, + ) + mlp.train() + + torch.manual_seed(42) + x = torch.randn(8, 16, 64, requires_grad=True) + + out = mlp(x) + loss = out.sum() + loss.backward() + + assert x.grad is not None + assert mlp.layer_scale.gamma.grad is not None + + +class TestReZero: + """Tests specifically for ReZero initialization (layer_scale_init=0).""" + + def test_rezero_initial_output(self): + """Test that ReZero initially outputs just the residual.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=0.0, # ReZero + ) + + x = torch.randn(8, 16, 64) + out = mlp(x) + + # With ReZero, initial output should be approximately equal to input + # (since layer_scale starts at 0, the layer contribution is 0) + assert torch.allclose(out, x, atol=1e-5) + + def test_rezero_gradual_learning(self): + """Test that ReZero allows gradual learning of layer scale.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=0.0, + ) + + # Initially gamma is 0 + assert torch.allclose(mlp.layer_scale.gamma, torch.zeros(64)) + + # After gradient update, gamma should change + x = torch.randn(8, 16, 64) + target = torch.randn(8, 16, 64) + + optimizer = torch.optim.SGD(mlp.parameters(), lr=0.1) + + for _ in range(10): + optimizer.zero_grad() + out = mlp(x) + loss = ((out - target) ** 2).mean() + loss.backward() + optimizer.step() + + # Gamma should now be non-zero + assert not torch.allclose(mlp.layer_scale.gamma, torch.zeros(64)) From c98c746ae292f2a61196bd677dfcd364496fe8c7 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Tue, 3 Feb 2026 15:00:11 +0100 Subject: [PATCH 06/38] use existing implementation --- src/weathergen/train/optimizer.py | 182 ++++++++++++++++++------------ 1 file changed, 111 insertions(+), 71 deletions(-) diff --git a/src/weathergen/train/optimizer.py b/src/weathergen/train/optimizer.py index 933896904..5dd047b05 100644 --- a/src/weathergen/train/optimizer.py +++ b/src/weathergen/train/optimizer.py @@ -438,14 +438,104 @@ def state(self, value): self._state = value +def _zeropower_via_newtonschulz5(grad: torch.Tensor, steps: int) -> torch.Tensor: + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of grad. + + Uses quintic iteration with coefficients selected to maximize the slope at zero. + This produces something like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), + rather than exact UV^T, but this doesn't hurt model performance. + + Reference: https://github.com/KellerJordan/Muon + + Args: + grad: Gradient tensor (must be at least 2D). + steps: Number of Newton-Schulz iterations. + + Returns: + Orthogonalized gradient tensor. + """ + assert grad.ndim >= 2 + coef_a, coef_b, coef_c = (3.4445, -4.7750, 2.0315) + x = grad.bfloat16() + + # Transpose if more rows than columns (NS works better on wide matrices) + if grad.size(-2) > grad.size(-1): + x = x.mT + + # Normalize by spectral norm (approximated by Frobenius norm for stability) + x = x / (x.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Perform Newton-Schulz iterations with quintic coefficients + for _ in range(steps): + xxt = x @ x.mT + poly = coef_b * xxt + coef_c * xxt @ xxt + x = coef_a * x + poly @ x + + # Restore original orientation + if grad.size(-2) > grad.size(-1): + x = x.mT + + return x + + +def _muon_update( + grad: torch.Tensor, + momentum_buffer: torch.Tensor, + beta: float = 0.95, + ns_steps: int = 5, + nesterov: bool = True, +) -> torch.Tensor: + """ + Compute Muon update: momentum + orthogonalization + scaling. + + Args: + grad: Parameter gradient. + momentum_buffer: Momentum buffer (modified in-place). + beta: Momentum coefficient. + ns_steps: Number of Newton-Schulz iterations. + nesterov: Whether to use Nesterov momentum. + + Returns: + The update to apply to parameters. + """ + # Momentum accumulation using lerp for numerical stability + momentum_buffer.lerp_(grad, 1 - beta) + + # Compute update (Nesterov or standard momentum) + if nesterov: + update = grad.lerp(momentum_buffer, beta) + else: + update = momentum_buffer.clone() + + # Reshape for orthogonalization if needed (e.g., conv filters) + original_shape = update.shape + if update.ndim == 4: + update = update.view(len(update), -1) + + # Apply Newton-Schulz orthogonalization + update = _zeropower_via_newtonschulz5(update, steps=ns_steps) + + # Scale by sqrt(max(1, rows/cols)) to preserve gradient magnitude + update = update * max(1, update.size(-2) / update.size(-1)) ** 0.5 + + # Restore original shape and dtype + return update.to(grad.dtype).view(original_shape) + + class MuonCustom(Optimizer): """ - Custom Muon optimizer implementation for PyTorch versions without torch.optim.Muon. + Custom Muon optimizer implementation based on Keller Jordan's reference. + + Muon (MomentUm Orthogonalized by Newton-schulz) internally runs standard SGD-momentum, + then performs an orthogonalization post-processing step where each 2D parameter's update + is replaced with the nearest orthogonal matrix via Newton-Schulz iteration. - Muon applies Newton-Schulz orthogonalization to gradients before the SGD update, - which helps with optimization of transformer hidden layer weights. + Reference: https://github.com/KellerJordan/Muon + https://kellerjordan.github.io/posts/muon/ - Reference: https://arxiv.org/abs/2407.01490 + Note: Muon should only be used for hidden weight layers. Embeddings, output heads, + biases, and layer norms should use AdamW. """ def __init__( @@ -462,10 +552,10 @@ def __init__( Args: params: Iterable of parameters to optimize or dicts defining param groups. - lr: Learning rate. - momentum: Momentum factor. + lr: Learning rate (in units of spectral norm per update). + momentum: Momentum factor (0.95 is typically good). nesterov: Whether to use Nesterov momentum. - weight_decay: Weight decay (L2 penalty). + weight_decay: Decoupled weight decay (like AdamW). ns_steps: Number of Newton-Schulz iterations for orthogonalization. """ if lr < 0.0: @@ -511,75 +601,25 @@ def step(self, closure=None): if p.grad is None: continue - grad = p.grad - - # Apply weight decay - if weight_decay != 0: - grad = grad.add(p, alpha=weight_decay) - - # Apply Newton-Schulz orthogonalization for 2D+ tensors - if p.ndim >= 2: - grad = self._newton_schulz_orthogonalize(grad, ns_steps) - - # Get or initialize momentum buffer + # Initialize momentum buffer if needed state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(grad) + # Compute Muon update + update = _muon_update( + p.grad, + state["momentum_buffer"], + beta=momentum, + ns_steps=ns_steps, + nesterov=nesterov, + ) - if nesterov: - grad = grad.add(buf, alpha=momentum) - else: - grad = buf + # Apply decoupled weight decay FIRST (like AdamW) + if weight_decay != 0: + p.mul_(1 - lr * weight_decay) - # Apply update - p.add_(grad, alpha=-lr) + # Apply the orthogonalized update + p.add_(update.view(p.shape), alpha=-lr) return loss - - def _newton_schulz_orthogonalize(self, grad: torch.Tensor, ns_steps: int) -> torch.Tensor: - """ - Apply Newton-Schulz iteration to orthogonalize the gradient. - - This projects the gradient onto the manifold of orthogonal matrices, - which helps with optimization stability for large matrices. - - Args: - grad: Gradient tensor to orthogonalize. - ns_steps: Number of Newton-Schulz iterations. - - Returns: - Orthogonalized gradient tensor. - """ - # Reshape to 2D if needed - original_shape = grad.shape - if grad.ndim > 2: - grad = grad.view(grad.shape[0], -1) - - # Transpose if needed to ensure we have more rows than columns - transposed = False - if grad.shape[0] < grad.shape[1]: - grad = grad.T - transposed = True - - # Normalize - grad = grad / (grad.norm() + 1e-7) - - # Newton-Schulz iteration: X_{k+1} = X_k (3I - X_k^T X_k) / 2 - # This converges to an orthogonal matrix - for _ in range(ns_steps): - grad = grad @ ( - 1.5 * torch.eye(grad.shape[1], device=grad.device, dtype=grad.dtype) - - 0.5 * grad.T @ grad - ) - - # Restore original orientation - if transposed: - grad = grad.T - - # Reshape back to original - grad = grad.view(original_shape) - - return grad From c2495b52fbd1afb2b7a3f5b236b5cc4be06df67a Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:55:20 +0000 Subject: [PATCH 07/38] Add Layerscale etc to default config --- config/default_config.yml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 7a2a55845..f14687b04 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -89,16 +89,18 @@ norm_type: "LayerNorm" # Options: null (disabled), 0.0 (ReZero), 1e-5 (LayerScale default) # LayerScale applies per-channel learned scaling before residual addition # ReZero initializes scaling to 0 for gradual signal introduction -layer_scale_init: null +# LayerScale - small init for stable gradients +layer_scale_init: 1e-5 # Stochastic Depth rates per component (0.0 = disabled) # Randomly drops entire residual paths during training for regularization # Rates increase linearly with depth: 0.0 for early layers, up to specified rate for deeper layers -stochastic_depth: - ae_local: 0.0 - ae_global: 0.0 - ae_aggregation: 0.0 - forecasting: 0.0 +# Stochastic Depth - light regularization +stochastic_depth: + ae_local: 0.0 # Keep local attention stable + ae_global: 0.1 # Light dropout on global + ae_aggregation: 0.05 + forecasting: 0.1 ##################################### From 68aa2f4c5ff2af551ef8b2810358ff5a22a627a6 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux <24638638+sophie-xhonneux@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:04:16 +0100 Subject: [PATCH 08/38] Make JEPA default config for testing --- config/default_config.yml | 112 +++++++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 45 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index f14687b04..c94caeb0a 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,8 +11,8 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 512 #1024 -ae_local_num_blocks: 2 +ae_local_dim_embed: 1024 +ae_local_num_blocks: 4 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -25,7 +25,7 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 512 #1024 #2048 +ae_global_dim_embed: 2048 ae_global_num_blocks: 2 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 @@ -37,7 +37,7 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 2 +ae_aggregation_num_blocks: 8 ae_aggregation_num_heads: 32 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True @@ -45,7 +45,7 @@ ae_aggregation_att_dense_rate: 1.0 ae_aggregation_block_factor: 64 ae_aggregation_mlp_hidden_factor: 2 -decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning pred_adapter_kv: False pred_self_attention: True pred_dyadic_dims: False @@ -63,13 +63,15 @@ fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False healpix_level: 5 with_mixed_precision: True with_flash_attention: True compile_model: False -with_fsdp: True +with_fsdp: False +ddp_find_unused_parameters: False attention_dtype: bf16 mixed_precision_dtype: bf16 mlp_norm_eps: 1e-5 @@ -84,6 +86,8 @@ latent_noise_deterministic_latents: True freeze_modules: "" norm_type: "LayerNorm" +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore # Residual scaling (LayerScale / ReZero) # Options: null (disabled), 0.0 (ReZero), 1e-5 (LayerScale default) @@ -101,15 +105,12 @@ stochastic_depth: ae_global: 0.1 # Light dropout on global ae_aggregation: 0.05 forecasting: 0.1 - ##################################### streams_directory: "./config/streams/era5_1deg/" +# streams_directory: "./config/streams/era5_nppatms_synop/" streams: ??? -# type of zarr_store -zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore - general: # mutable parameters @@ -123,7 +124,7 @@ general: # model_path, # run_path, # path_shared_ - + multiprocessing_method: "fork" desc: "" @@ -141,19 +142,13 @@ data_loading : num_workers: 12 rng_seed: ??? - repeat_data_in_mini_epoch : False - - # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with - # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. - # If this happens, you can disable the flag, but performance will drop on GH200. - memory_pinning: True # config for training training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["masking"] + training_mode: ["student_teacher"] num_mini_epochs: 32 samples_per_mini_epoch: 4096 @@ -164,6 +159,8 @@ training_config: time_window_step: 06:00:00 time_window_len: 06:00:00 + + window_offset_prediction : 0 learning_rate_scheduling : lr_start: 1e-6 @@ -179,8 +176,8 @@ training_config: optimizer: # Optimizer type: "adamw" (default) or "muon_adamw" (Muon for hidden weights, AdamW for embeddings/heads) - type: "adamw" - grad_clip: 1.0 + type: "muon_adamw" + grad_clip: 0.5 weight_decay: 0.1 log_grad_norms: False adamw: @@ -199,24 +196,55 @@ training_config: weight_decay: 0.1 losses : { - "physical": { - type: LossPhysical, - loss_fcts: { "mse": { }, }, + "student-teacher": { + enabled: True, + type: LossLatentSSLStudentTeacher, + weight: 1.0, + loss_fcts : { + "JEPA": { + 'weight': 4, "loss_extra_args": {}, "out_dim": 2048, "head": transformer, + "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768, + "dropout_rate": 0.1, + target_source_correspondence: {0 : {0 : "subset"} }, }, - } + }, + target_and_aux_calc: { "EMATeacher" : + { ema_ramp_up_ratio : 0.09, + ema_halflife_in_thousands: 1e-3, + model_param_overrides : { + training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} + }, + } + } + } + } model_input: { - "forecasting" : { - # masking strategy: "random", "healpix", "forecast" - masking_strategy: "forecast", + "random_easy" : { + # masking strategy: "random", "forecast" + masking_strategy: "random", + num_samples: 1, + num_steps_input: 1, + masking_strategy_config : { + diffusion_rn : True, + rate : 0.6, + rate_sampling: False }, - } + }, + } + + target_input: { + "random_easy_target" : { + masking_strategy: "healpix", + num_samples: 1, + masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False }, + }, + } forecast : - time_step: 06:00:00 - num_steps: 2 - offset: 1 - policy: "fixed" + time_step: 00:00:00 + num_steps: 0 + policy: null # validation config; full validation config is merge of training and validation config @@ -230,24 +258,18 @@ validation_config: # whether to track the exponential moving average of weights for validation validate_with_ema: - enabled : True + enabled : False ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 - # parameters for validation samples that are written to disk - output : { - # number of samples that are written - num_samples: 0, - # write samples in normalized model space - normalized_samples: False, - # output streams to write; default all - streams: null, - } + # number of validation samples that are written to disk + write_num_samples: 0 + # output streams to write; default all + output_streams: null # run validation before training starts (mainly for model development) validate_before_training: False - - + # test config; full test config is merge of validation and test config # test config is used by default when running inference @@ -273,7 +295,7 @@ wgtags: # issue number. # Expected values are lowercase strings with no spaces, just underscores: # Examples: "rollout_ablation_grid" - exp: null + exp: jepa_muon_layerscale # *** Experiment-specific tags *** # All extra tags (including lists, dictionaries, etc.) are treated # as strings by mlflow, so treat all extra tags as simple string key: value pairs. From c0ce9dde83bef953720441045ecf2a35fb52d62f Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 17:06:43 +0100 Subject: [PATCH 09/38] Add assert to prevent silent errors --- src/weathergen/train/optimizer.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/weathergen/train/optimizer.py b/src/weathergen/train/optimizer.py index 5dd047b05..65890b76a 100644 --- a/src/weathergen/train/optimizer.py +++ b/src/weathergen/train/optimizer.py @@ -284,7 +284,10 @@ def _create_muon_optimizer( weight_decay=weight_decay, ) else: - logger.info("Using custom Muon implementation (torch.optim.Muon not available)") + logger.warning( + "Using custom Muon implementation (torch.optim.Muon not available). " + "NOTE: This implementation does NOT support FSDP2. Use DDP or single-GPU training." + ) return MuonCustom( param_groups, lr=lr, @@ -536,6 +539,12 @@ class MuonCustom(Optimizer): Note: Muon should only be used for hidden weight layers. Embeddings, output heads, biases, and layer norms should use AdamW. + + WARNING: This implementation does NOT support FSDP2 (Fully Sharded Data Parallel). + The Newton-Schulz orthogonalization requires the FULL gradient matrix, but FSDP2 + shards gradients across GPUs. Computing `X @ X.T` on a sharded gradient gives + mathematically incorrect results. For FSDP2 support, see the distributed version + in the reference implementation: https://github.com/KellerJordan/Muon/blob/master/muon.py """ def __init__( @@ -590,6 +599,22 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + # Check for FSDP2 on first step (DTensor indicates sharded parameters) + if not hasattr(self, "_fsdp_checked"): + self._fsdp_checked = True + for group in self.param_groups: + for p in group["params"]: + # FSDP2 uses DTensor for sharding + is_fsdp2 = hasattr(p, "_local_tensor") or type(p).__name__ == "DTensor" + assert not is_fsdp2, ( + "MuonCustom does not support FSDP2 (Fully Sharded Data Parallel). " + "The Newton-Schulz orthogonalization requires full gradients, but " + "FSDP2 shards gradients across GPUs, leading to incorrect results. " + "Options: (1) Use DDP instead of FSDP, (2) Use AdamW optimizer, " + "(3) Use torch.optim.Muon (PyTorch >= 2.9) if it supports FSDP. " + "Reference FSDP impl: github.com/KellerJordan/Muon/blob/master/muon.py" + ) + for group in self.param_groups: momentum = group["momentum"] nesterov = group["nesterov"] From 71d2cce021d86e2763e595e91b2c3243abdc3a6e Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 19:53:44 +0100 Subject: [PATCH 10/38] Add collapse monitoring --- config/config_jepa.yml | 26 +- src/weathergen/train/collapse_monitor.py | 356 +++++++++++++++ .../train/target_and_aux_ssl_teacher.py | 20 + src/weathergen/train/trainer.py | 96 ++++ tests/test_collapse_monitor.py | 410 ++++++++++++++++++ 5 files changed, 907 insertions(+), 1 deletion(-) create mode 100644 src/weathergen/train/collapse_monitor.py create mode 100644 tests/test_collapse_monitor.py diff --git a/config/config_jepa.yml b/config/config_jepa.yml index fc27da8c9..f10f16445 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -130,10 +130,34 @@ data_loading : # config for training training_config: - + # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["student_teacher"] + # Collapse monitoring for SSL training (JEPA/DINO/iBOT) + # Detects representation collapse via various metrics + collapse_monitoring: + enabled: true + compute_frequency: 100 # batches between metric computations + log_frequency: 100 # batches between metric logging + metrics: + effective_rank: + enabled: true + tensor_source: "both" # "student", "teacher", or "both" + sample_size: 2048 # max samples for SVD (0 = no sampling) + singular_values: + enabled: true + top_k: 10 + tensor_source: "both" + sample_size: 2048 + dimension_variance: + enabled: true + tensor_source: "both" # cheap to compute, good early indicator + prototype_entropy: + enabled: true # only applies to DINO + ema_beta: + enabled: true + num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py new file mode 100644 index 000000000..b739908e9 --- /dev/null +++ b/src/weathergen/train/collapse_monitor.py @@ -0,0 +1,356 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Collapse monitoring metrics for SSL training (JEPA/DINO). + +This module implements metrics to detect representation collapse during self-supervised learning: +- Effective Rank (RankMe): Entropy of normalized singular values +- Singular Value Spectrum: Top-k singular values and concentration ratio +- Per-Dimension Variance: Min/mean/max variance across embedding dimensions +- Prototype Entropy: Normalized entropy of DINO prototype assignments +- EMA Beta: Current teacher momentum value + +References: +- RankMe (ICML 2023): https://arxiv.org/abs/2210.02885 +- C-JEPA (NeurIPS 2024): https://arxiv.org/abs/2410.19560 +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from typing import Any + +import torch + +logger = logging.getLogger(__name__) + + +class CollapseMonitor: + """ + Monitor for detecting representation collapse during SSL training. + + Computes and caches various collapse indicators that can be logged + at configurable intervals to minimize computational overhead. + """ + + def __init__(self, config: dict[str, Any], device: torch.device) -> None: + """ + Initialize the collapse monitor. + + Args: + config: Configuration dictionary with collapse_monitoring settings. + device: Device to use for computations. + """ + self.device = device + self.enabled = config.get("enabled", False) + self.compute_frequency = config.get("compute_frequency", 100) + self.log_frequency = config.get("log_frequency", 100) + + # Metric configurations + metrics_config = config.get("metrics", {}) + + self.effective_rank_config = metrics_config.get("effective_rank", {}) + self.singular_values_config = metrics_config.get("singular_values", {}) + self.dimension_variance_config = metrics_config.get("dimension_variance", {}) + self.prototype_entropy_config = metrics_config.get("prototype_entropy", {}) + self.ema_beta_config = metrics_config.get("ema_beta", {}) + + # Cache for accumulating metrics between log intervals + self._metrics_cache: dict[str, list[float]] = defaultdict(list) + + def should_compute(self, step: int) -> bool: + """Check if metrics should be computed at this step.""" + return self.enabled and step % self.compute_frequency == 0 + + def should_log(self, step: int) -> bool: + """Check if metrics should be logged at this step.""" + return self.enabled and step % self.log_frequency == 0 + + def compute_metrics( + self, + student_latent: torch.Tensor | None = None, + teacher_latent: torch.Tensor | None = None, + prototype_probs: torch.Tensor | None = None, + ema_beta: float | None = None, + loss_type: str | None = None, + ) -> dict[str, float]: + """ + Compute all enabled collapse monitoring metrics. + + Args: + student_latent: Student model latent representations [B, N, D] or [B, D]. + teacher_latent: Teacher model latent representations [B, N, D] or [B, D]. + prototype_probs: Post-softmax prototype assignment probabilities [B, K] (DINO only). + ema_beta: Current EMA momentum value. + loss_type: Type of SSL loss ("JEPA" or "DINO"). + + Returns: + Dictionary of computed metrics. + """ + if not self.enabled: + return {} + + metrics: dict[str, float] = {} + + # Determine which tensors to monitor based on config + tensors_to_monitor: dict[str, torch.Tensor | None] = {} + + effective_rank_source = self.effective_rank_config.get("tensor_source", "both") + sv_source = self.singular_values_config.get("tensor_source", "both") + var_source = self.dimension_variance_config.get("tensor_source", "both") + + # Build tensor dict based on what's requested + if effective_rank_source in ("student", "both") or sv_source in ( + "student", + "both", + ) or var_source in ("student", "both"): + tensors_to_monitor["student"] = student_latent + + if effective_rank_source in ("teacher", "both") or sv_source in ( + "teacher", + "both", + ) or var_source in ("teacher", "both"): + tensors_to_monitor["teacher"] = teacher_latent + + # Compute effective rank + if self.effective_rank_config.get("enabled", True): + sample_size = self.effective_rank_config.get("sample_size", 2048) + for name, tensor in tensors_to_monitor.items(): + if tensor is not None: + source = self.effective_rank_config.get("tensor_source", "both") + if source == "both" or source == name: + eff_rank = self._compute_effective_rank(tensor, sample_size) + metrics[f"collapse.{name}.effective_rank"] = eff_rank + + # Compute singular value spectrum + if self.singular_values_config.get("enabled", True): + top_k = self.singular_values_config.get("top_k", 10) + sample_size = self.singular_values_config.get("sample_size", 2048) + for name, tensor in tensors_to_monitor.items(): + if tensor is not None: + source = self.singular_values_config.get("tensor_source", "both") + if source == "both" or source == name: + sv_metrics = self._compute_singular_values(tensor, top_k, sample_size) + for key, value in sv_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value + + # Compute per-dimension variance + if self.dimension_variance_config.get("enabled", True): + for name, tensor in tensors_to_monitor.items(): + if tensor is not None: + source = self.dimension_variance_config.get("tensor_source", "both") + if source == "both" or source == name: + var_metrics = self._compute_dimension_variance(tensor) + for key, value in var_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value + + # Compute prototype entropy (DINO only) + if ( + self.prototype_entropy_config.get("enabled", True) + and prototype_probs is not None + and loss_type == "DINO" + ): + entropy = self._compute_prototype_entropy(prototype_probs) + metrics["collapse.dino.prototype_entropy"] = entropy + + # Log EMA beta + if self.ema_beta_config.get("enabled", True) and ema_beta is not None: + metrics["collapse.ema_beta"] = ema_beta + + # Cache metrics for averaging + for key, value in metrics.items(): + self._metrics_cache[key].append(value) + + return metrics + + def get_cached_metrics(self) -> dict[str, float]: + """ + Get averaged cached metrics and clear the cache. + + Returns: + Dictionary of averaged metrics since last call. + """ + averaged_metrics: dict[str, float] = {} + for key, values in self._metrics_cache.items(): + if values: + averaged_metrics[key] = sum(values) / len(values) + + self._metrics_cache.clear() + return averaged_metrics + + def _flatten_to_samples(self, z: torch.Tensor) -> torch.Tensor: + """ + Flatten patch dimension into sample dimension. + + Treats [B, N, D] as [B*N, D] where each patch is an independent sample. + This is consistent with C-JEPA/VICReg approach. + + Args: + z: Tensor of shape [B, N, D] or [B, D]. + + Returns: + Tensor of shape [B*N, D] or [B, D]. + """ + if z.ndim == 3: + return z.reshape(-1, z.shape[-1]) + return z + + def _sample_rows(self, z: torch.Tensor, sample_size: int) -> torch.Tensor: + """ + Randomly sample rows to reduce SVD computation cost. + + Args: + z: Tensor of shape [N, D]. + sample_size: Maximum number of samples (0 = no sampling). + + Returns: + Sampled tensor of shape [min(N, sample_size), D]. + """ + if sample_size <= 0 or z.shape[0] <= sample_size: + return z + + indices = torch.randperm(z.shape[0], device=z.device)[:sample_size] + return z[indices] + + def _compute_effective_rank(self, z: torch.Tensor, sample_size: int = 2048) -> float: + """ + Compute effective rank via entropy of normalized singular values (RankMe). + + The effective rank measures how many dimensions are actually being used + in the representation. A low effective rank indicates collapse. + + Args: + z: Latent representations [B, N, D] or [B, D]. + sample_size: Maximum samples for SVD computation. + + Returns: + Effective rank (exp of entropy of normalized singular values). + """ + z = self._flatten_to_samples(z.detach()) + z = self._sample_rows(z, sample_size) + + # Center the data + z_centered = z - z.mean(dim=0, keepdim=True) + + # Compute SVD + try: + _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) + except RuntimeError: + # SVD can fail on degenerate matrices + logger.warning("SVD failed in effective rank computation") + return 0.0 + + # Normalize singular values to get a probability distribution + s_normalized = s / (s.sum() + 1e-8) + + # Compute entropy + entropy = -torch.sum(s_normalized * torch.log(s_normalized + 1e-8)) + + # Effective rank is exp(entropy) + effective_rank = torch.exp(entropy) + + return effective_rank.item() + + def _compute_singular_values( + self, z: torch.Tensor, top_k: int = 10, sample_size: int = 2048 + ) -> dict[str, float]: + """ + Compute top-k singular values and concentration ratio. + + The concentration ratio (top SV / sum of all SVs) indicates how much + variance is captured by the largest singular value. High concentration + suggests dimensional collapse. + + Args: + z: Latent representations [B, N, D] or [B, D]. + top_k: Number of top singular values to return. + sample_size: Maximum samples for SVD computation. + + Returns: + Dictionary with top-k singular values and concentration ratio. + """ + z = self._flatten_to_samples(z.detach()) + z = self._sample_rows(z, sample_size) + + # Center the data + z_centered = z - z.mean(dim=0, keepdim=True) + + # Compute SVD + try: + _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) + except RuntimeError: + logger.warning("SVD failed in singular value computation") + return {} + + metrics: dict[str, float] = {} + + # Top-k singular values + for i in range(min(top_k, len(s))): + metrics[f"singular_value_{i}"] = s[i].item() + + # Concentration ratio (top SV / sum) + s_sum = s.sum() + 1e-8 + metrics["sv_concentration"] = (s[0] / s_sum).item() + + return metrics + + def _compute_dimension_variance(self, z: torch.Tensor) -> dict[str, float]: + """ + Compute per-dimension variance statistics. + + Low minimum variance indicates "dead" dimensions that are not being used. + Large variance ratio (max/min) suggests imbalanced dimension usage. + + Args: + z: Latent representations [B, N, D] or [B, D]. + + Returns: + Dictionary with var_min, var_mean, var_max. + """ + z = self._flatten_to_samples(z.detach()) + + # Compute variance along sample dimension + var_per_dim = z.var(dim=0) + + return { + "var_min": var_per_dim.min().item(), + "var_mean": var_per_dim.mean().item(), + "var_max": var_per_dim.max().item(), + } + + def _compute_prototype_entropy(self, probs: torch.Tensor) -> float: + """ + Compute normalized entropy of DINO prototype assignments. + + Low entropy indicates collapse to few prototypes. Entropy is normalized + to [0, 1] range where 1 means uniform distribution. + + Args: + probs: Post-softmax prototype assignment probabilities [B, K]. + + Returns: + Normalized entropy in [0, 1]. + """ + probs = probs.detach() + + # Average across batch to get prototype usage distribution + avg_probs = probs.mean(dim=0) + + # Compute entropy + entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-8)) + + # Normalize by maximum possible entropy (uniform distribution) + num_prototypes = probs.shape[1] + max_entropy = torch.log(torch.tensor(float(num_prototypes), device=probs.device)) + + normalized_entropy = entropy / (max_entropy + 1e-8) + + return normalized_entropy.item() diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index cfb252f86..99b1f2860 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -77,6 +77,26 @@ def to_device(self, device) -> EMATeacher: module.to(device) return self + def get_current_beta(self, cur_step: int) -> float: + """ + Get current EMA beta value for monitoring. + + The beta value determines how much the teacher model is updated towards + the student model at each step. Higher beta means slower teacher updates. + + Args: + cur_step: Current training step (typically istep * batch_size). + + Returns: + Current EMA beta value. + """ + halflife_steps = self.ema_model.halflife_steps + rampup_ratio = self.ema_model.rampup_ratio + if rampup_ratio is not None: + halflife_steps = min(halflife_steps, cur_step / 1e3 * rampup_ratio) + beta = 0.5 ** (self.batch_size / max(halflife_steps * 1e3, 1e-6)) + return beta + def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): return_dict = {} diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index e949dc1cc..163ce53a8 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -29,8 +29,10 @@ init_model_and_shard, ) from weathergen.model.utils import apply_fct_to_blocks, set_to_eval +from weathergen.train.collapse_monitor import CollapseMonitor from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler +from weathergen.train.target_and_aux_ssl_teacher import EMATeacher from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( extract_batch_metadata, @@ -74,6 +76,7 @@ def __init__(self, train_log_freq: Config): self.batch_size_per_gpu = -1 self.batch_size_validation_per_gpu = -1 self.batch_size_test_per_gpu = -1 + self.collapse_monitor: CollapseMonitor | None = None def get_batch_size_total(self, batch_size_per_gpu) -> int: """ @@ -146,6 +149,10 @@ def init(self, cf: Config, devices): self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) + # Initialize collapse monitor for SSL training + collapse_config = self.training_cfg.get("collapse_monitoring", {}) + self.collapse_monitor = CollapseMonitor(collapse_config, None) # device set later in run() + def get_target_aux_calculators(self, mode_cfg): """ Get target_aux_calculators for given mode_cfg @@ -227,6 +234,9 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): device_type = torch.accelerator.current_accelerator() self.device = torch.device(f"{device_type}:{cf.local_rank}") + # Update collapse monitor device + self.collapse_monitor.device = self.device + # create data loaders self.dataset = MultiStreamDataSampler(cf, self.training_cfg, stage=TRAIN) self.dataset_val = MultiStreamDataSampler(cf, self.validation_cfg, stage=VAL) @@ -501,9 +511,16 @@ def train(self, mini_epoch): if self.validate_with_ema: self.ema_model.update(self.cf.general.istep * batch_size_total, batch_size_total) + # Compute collapse monitoring metrics + if self.collapse_monitor.should_compute(self.cf.general.istep): + self._compute_collapse_metrics(preds, targets_and_auxs) + self._log_terminal(bidx, mini_epoch, TRAIN) if bidx % self.train_log_freq.metrics == 0: self._log(TRAIN) + # Log collapse metrics + if self.collapse_monitor.should_log(self.cf.general.istep): + self._log_collapse_metrics(TRAIN) # save model checkpoint (with designation _latest) if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: @@ -775,3 +792,82 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): logger.info("\n") self.t_start = time.time() + + def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: + """ + Extract latent tensors from predictions and targets, then compute collapse metrics. + + This method extracts the student and teacher latent representations from the + SSL training outputs and passes them to the collapse monitor. + """ + # Get student latents from predictions (first forecast step) + student_latent = None + teacher_latent = None + prototype_probs = None + ema_beta = None + loss_type = None + + # Find SSL loss type and extract latents + for _loss_name, target_aux in targets_and_auxs.items(): + # Check if this is an EMATeacher-based loss + if hasattr(target_aux, "latent") and target_aux.latent: + # Get the first timestep's latent dict + target_latent_dict = target_aux.latent[0] if target_aux.latent else {} + + # Determine the SSL loss type (JEPA, DINO, iBOT) + for ssl_type in ["JEPA", "DINO", "iBOT"]: + if ssl_type in target_latent_dict: + loss_type = ssl_type + # Get teacher latent + teacher_latent_data = target_latent_dict[ssl_type] + if isinstance(teacher_latent_data, list) and len(teacher_latent_data) > 0: + teacher_latent = teacher_latent_data[0] + elif isinstance(teacher_latent_data, dict): + # Handle LatentState or dict + teacher_latent = teacher_latent_data.get( + "latent", teacher_latent_data + ) + else: + teacher_latent = teacher_latent_data + break + + # Get student latents from predictions + if preds.latent and len(preds.latent) > 0: + pred_latent_dict = preds.latent[0] + for ssl_type in ["JEPA", "DINO", "iBOT"]: + if ssl_type in pred_latent_dict: + student_latent_data = pred_latent_dict[ssl_type] + if isinstance(student_latent_data, list) and len(student_latent_data) > 0: + student_latent = student_latent_data[0] + elif isinstance(student_latent_data, dict): + student_latent = student_latent_data.get("latent", student_latent_data) + else: + student_latent = student_latent_data + loss_type = ssl_type + break + + # Get EMA beta from target_and_aux_calculators + for _calc_name, calculator in self.target_and_aux_calculators.items(): + if isinstance(calculator, EMATeacher): + batch_size_total = self.get_batch_size_total(self.batch_size_per_gpu) + step = batch_size_total * self.cf.general.istep + ema_beta = calculator.get_current_beta(step) + break + + # Ensure tensors are properly formatted + if student_latent is not None and isinstance(student_latent, torch.Tensor): + self.collapse_monitor.compute_metrics( + student_latent=student_latent, + teacher_latent=teacher_latent if isinstance(teacher_latent, torch.Tensor) else None, + prototype_probs=prototype_probs, + ema_beta=ema_beta, + loss_type=loss_type, + ) + + def _log_collapse_metrics(self, stage: Stage) -> None: + """ + Log cached collapse monitoring metrics. + """ + metrics = self.collapse_monitor.get_cached_metrics() + if metrics and is_root(): + self.train_logger.log_metrics(stage, metrics) diff --git a/tests/test_collapse_monitor.py b/tests/test_collapse_monitor.py new file mode 100644 index 000000000..5656205f9 --- /dev/null +++ b/tests/test_collapse_monitor.py @@ -0,0 +1,410 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Unit tests for collapse monitoring metrics.""" + +import pytest +import torch + +from weathergen.train.collapse_monitor import CollapseMonitor + + +@pytest.fixture +def default_config(): + """Default enabled config for collapse monitoring.""" + return { + "enabled": True, + "compute_frequency": 100, + "log_frequency": 100, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "both", + "sample_size": 2048, + }, + "singular_values": { + "enabled": True, + "top_k": 10, + "tensor_source": "both", + "sample_size": 2048, + }, + "dimension_variance": { + "enabled": True, + "tensor_source": "both", + }, + "prototype_entropy": { + "enabled": True, + }, + "ema_beta": { + "enabled": True, + }, + }, + } + + +@pytest.fixture +def monitor(default_config): + """Create a collapse monitor with default config.""" + device = torch.device("cpu") + return CollapseMonitor(default_config, device) + + +class TestCollapseMonitorInitialization: + """Test CollapseMonitor initialization.""" + + def test_disabled_monitor(self): + """Test that disabled monitor doesn't compute metrics.""" + config = {"enabled": False} + monitor = CollapseMonitor(config, torch.device("cpu")) + assert not monitor.enabled + assert not monitor.should_compute(100) + assert not monitor.should_log(100) + + def test_enabled_monitor(self, default_config): + """Test that enabled monitor computes at correct intervals.""" + monitor = CollapseMonitor(default_config, torch.device("cpu")) + assert monitor.enabled + assert monitor.should_compute(0) + assert monitor.should_compute(100) + assert not monitor.should_compute(50) + + def test_frequency_settings(self): + """Test custom frequency settings.""" + config = { + "enabled": True, + "compute_frequency": 50, + "log_frequency": 200, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + assert monitor.should_compute(50) + assert monitor.should_compute(100) # 100 is a multiple of 50 + assert not monitor.should_compute(75) # 75 is not a multiple of 50 + assert monitor.should_log(200) + assert not monitor.should_log(100) + + +class TestEffectiveRank: + """Test effective rank computation.""" + + def test_full_rank_matrix(self, monitor): + """Full rank random matrix should have effective rank close to min(N, D).""" + torch.manual_seed(42) + # Create a full-rank matrix with orthogonal columns + dim = 64 + num_samples = 128 + z = torch.randn(num_samples, dim) + # Make it more orthogonal via QR decomposition + q, _ = torch.linalg.qr(z.T) + z = q.T # Now z is [dim, dim] with orthogonal rows + z = torch.cat([z, torch.randn(num_samples - dim, dim)], dim=0) + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # For a full-rank matrix, effective rank should be significant portion of D + assert eff_rank > dim * 0.3, f"Expected effective rank > {dim * 0.3}, got {eff_rank}" + + def test_low_rank_matrix(self, monitor): + """Low rank matrix should have effective rank close to actual rank.""" + torch.manual_seed(42) + # Create a rank-5 matrix + actual_rank = 5 + num_samples, dim = 128, 64 + u_mat = torch.randn(num_samples, actual_rank) + v_mat = torch.randn(actual_rank, dim) + z = u_mat @ v_mat + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Effective rank should be close to actual rank + assert eff_rank < actual_rank * 2, ( + f"Expected effective rank < {actual_rank * 2}, got {eff_rank}" + ) + assert eff_rank > actual_rank * 0.5, ( + f"Expected effective rank > {actual_rank * 0.5}, got {eff_rank}" + ) + + def test_collapsed_matrix(self, monitor): + """Completely collapsed matrix should have effective rank ~1.""" + num_samples, dim = 128, 64 + # All rows are the same (rank 1) + row = torch.randn(1, dim) + z = row.expand(num_samples, dim).clone() + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Effective rank should be very close to 1 + assert eff_rank < 2, f"Expected effective rank < 2, got {eff_rank}" + + def test_3d_tensor_flattening(self, monitor): + """Test that [B, N, D] tensors are properly flattened.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + z = torch.randn(batch_size, num_patches, dim) + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Should compute without error and return reasonable value + assert 1 <= eff_rank <= dim + + +class TestSingularValues: + """Test singular value spectrum computation.""" + + def test_top_k_singular_values(self, monitor): + """Test that top-k singular values are correctly computed.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + + sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + + # Check that we got top-5 singular values + assert "singular_value_0" in sv_metrics + assert "singular_value_4" in sv_metrics + assert "singular_value_5" not in sv_metrics + + # Singular values should be in descending order + for i in range(4): + assert sv_metrics[f"singular_value_{i}"] >= sv_metrics[f"singular_value_{i + 1}"] + + def test_concentration_ratio(self, monitor): + """Test singular value concentration ratio.""" + torch.manual_seed(42) + # Create a rank-1 matrix where first SV dominates + num_samples, dim = 128, 64 + # Use outer product to create a truly rank-1 dominated matrix + u_vec = torch.randn(num_samples, 1) + v_vec = torch.randn(1, dim) + z = u_vec @ v_vec * 10 + torch.randn(num_samples, dim) * 0.01 # Strong rank-1 component + + sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + + # Concentration should be high when one SV dominates + assert "sv_concentration" in sv_metrics + assert sv_metrics["sv_concentration"] > 0.8 # First SV dominates strongly + + def test_uniform_singular_values(self, monitor): + """Test with approximately uniform singular values.""" + torch.manual_seed(42) + # Create orthogonal matrix with equal singular values + dim = 64 + q, _ = torch.linalg.qr(torch.randn(dim, dim)) + z = q * 10 # Scale uniformly + + sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + + # Concentration should be low (close to 1/D) + assert sv_metrics["sv_concentration"] < 0.1 + + +class TestDimensionVariance: + """Test per-dimension variance computation.""" + + def test_random_matrix_balanced_variance(self, monitor): + """Random matrix should have balanced variance across dimensions.""" + torch.manual_seed(42) + num_samples, dim = 1024, 64 + z = torch.randn(num_samples, dim) + + var_metrics = monitor._compute_dimension_variance(z) + + # All variances should be close to 1 for standard normal + assert abs(var_metrics["var_mean"] - 1.0) < 0.2 + # Variance ratio should be small for random matrix + var_ratio = var_metrics["var_max"] / (var_metrics["var_min"] + 1e-8) + assert var_ratio < 5 # Balanced dimensions + + def test_dead_dimensions(self, monitor): + """Test detection of dead (zero-variance) dimensions.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + # Kill some dimensions (set to constant) + z[:, :10] = 0.5 + + var_metrics = monitor._compute_dimension_variance(z) + + # Minimum variance should be very close to 0 (dead dimensions) + assert var_metrics["var_min"] < 1e-6 + + def test_imbalanced_dimensions(self, monitor): + """Test with highly imbalanced dimension variances.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + # Scale some dimensions much more than others + z[:, 0] *= 100 + z[:, 1:10] *= 0.01 + + var_metrics = monitor._compute_dimension_variance(z) + + # Large variance ratio indicates imbalance + var_ratio = var_metrics["var_max"] / (var_metrics["var_min"] + 1e-8) + assert var_ratio > 1000 + + +class TestPrototypeEntropy: + """Test DINO prototype entropy computation.""" + + def test_uniform_prototype_distribution(self, monitor): + """Uniform prototype distribution should have entropy ~1.""" + batch_size, num_prototypes = 128, 64 + # Uniform distribution + probs = torch.ones(batch_size, num_prototypes) / num_prototypes + + entropy = monitor._compute_prototype_entropy(probs) + + # Normalized entropy should be close to 1 + assert abs(entropy - 1.0) < 0.01 + + def test_single_prototype_collapse(self, monitor): + """Collapse to single prototype should have entropy ~0.""" + batch_size, num_prototypes = 128, 64 + # All mass on first prototype + probs = torch.zeros(batch_size, num_prototypes) + probs[:, 0] = 1.0 + + entropy = monitor._compute_prototype_entropy(probs) + + # Normalized entropy should be close to 0 + assert entropy < 0.01 + + def test_partial_collapse(self, monitor): + """Partial collapse should have intermediate entropy.""" + batch_size, num_prototypes = 128, 64 + # Only 4 prototypes used uniformly (much stronger collapse) + probs = torch.zeros(batch_size, num_prototypes) + probs[:, :4] = 0.25 # Only 4 out of 64 prototypes + + entropy = monitor._compute_prototype_entropy(probs) + + # Entropy should be between 0 and 1 (log(4)/log(64) ≈ 0.33) + assert 0.2 < entropy < 0.5 + + +class TestMetricsCaching: + """Test metrics caching and averaging.""" + + def test_cache_accumulation(self, monitor): + """Test that metrics are properly cached.""" + torch.manual_seed(42) + z1 = torch.randn(64, 32) + z2 = torch.randn(64, 32) + + # Compute metrics twice + monitor.compute_metrics(student_latent=z1) + monitor.compute_metrics(student_latent=z2) + + # Cache should contain averaged values + cached = monitor.get_cached_metrics() + assert "collapse.student.effective_rank" in cached + + def test_cache_clear(self, monitor): + """Test that cache is cleared after get_cached_metrics.""" + torch.manual_seed(42) + z = torch.randn(64, 32) + + monitor.compute_metrics(student_latent=z) + _ = monitor.get_cached_metrics() + + # Second call should return empty + cached = monitor.get_cached_metrics() + assert len(cached) == 0 + + +class TestIntegration: + """Integration tests with both student and teacher tensors.""" + + def test_full_metrics_computation(self, monitor): + """Test computing all metrics with both student and teacher.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + student = torch.randn(batch_size, num_patches, dim) + teacher = torch.randn(batch_size, num_patches, dim) + + metrics = monitor.compute_metrics( + student_latent=student, + teacher_latent=teacher, + ema_beta=0.999, + loss_type="JEPA", + ) + + # Check that both student and teacher metrics are computed + assert "collapse.student.effective_rank" in metrics + assert "collapse.teacher.effective_rank" in metrics + assert "collapse.student.var_min" in metrics + assert "collapse.teacher.var_min" in metrics + assert "collapse.ema_beta" in metrics + assert metrics["collapse.ema_beta"] == 0.999 + + def test_dino_prototype_entropy(self, monitor): + """Test DINO prototype entropy computation.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + num_prototypes = 128 + student = torch.randn(batch_size, num_patches, dim) + probs = torch.softmax(torch.randn(batch_size, num_prototypes), dim=-1) + + metrics = monitor.compute_metrics( + student_latent=student, + prototype_probs=probs, + loss_type="DINO", + ) + + assert "collapse.dino.prototype_entropy" in metrics + assert 0 <= metrics["collapse.dino.prototype_entropy"] <= 1 + + def test_disabled_metrics(self): + """Test that disabled metrics are not computed.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": {"enabled": False}, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": True, "tensor_source": "student"}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + z = torch.randn(64, 32) + metrics = monitor.compute_metrics(student_latent=z) + + # Only dimension variance should be computed + assert "collapse.student.var_min" in metrics + assert "collapse.student.effective_rank" not in metrics + assert "collapse.student.singular_value_0" not in metrics + + +class TestSampling: + """Test row sampling for SVD computations.""" + + def test_sampling_reduces_computation(self, monitor): + """Test that sampling works for large tensors.""" + torch.manual_seed(42) + num_samples, dim = 10000, 64 + z = torch.randn(num_samples, dim) + + # With sampling + eff_rank_sampled = monitor._compute_effective_rank(z, sample_size=1024) + # Without sampling + eff_rank_full = monitor._compute_effective_rank(z, sample_size=0) + + # Results should be in same ballpark + assert abs(eff_rank_sampled - eff_rank_full) < eff_rank_full * 0.3 + + def test_no_sampling_when_small(self, monitor): + """Test that small tensors aren't sampled.""" + torch.manual_seed(42) + num_samples, dim = 100, 64 + z = torch.randn(num_samples, dim) + + # Sample size larger than N + sampled = monitor._sample_rows(z, sample_size=1024) + assert sampled.shape[0] == num_samples # No sampling occurred From 1d296117ca92d398e4e3b4785cfd9236a27c1ad0 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 20:14:17 +0100 Subject: [PATCH 11/38] Fix bug --- src/weathergen/train/trainer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 163ce53a8..7514860a0 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -811,8 +811,14 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: for _loss_name, target_aux in targets_and_auxs.items(): # Check if this is an EMATeacher-based loss if hasattr(target_aux, "latent") and target_aux.latent: - # Get the first timestep's latent dict - target_latent_dict = target_aux.latent[0] if target_aux.latent else {} + # Handle both cases: + # 1. latent is a list[dict] (as per TargetAuxOutput dataclass) + # 2. latent is a dict (as set directly by EMATeacher) + if isinstance(target_aux.latent, list): + target_latent_dict = target_aux.latent[0] if target_aux.latent else {} + else: + # EMATeacher sets latent directly as a dict + target_latent_dict = target_aux.latent # Determine the SSL loss type (JEPA, DINO, iBOT) for ssl_type in ["JEPA", "DINO", "iBOT"]: From bc92ae7ed53422180d7c2278b2a9a854551e4080 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 20:25:03 +0100 Subject: [PATCH 12/38] Fix SVD computation failing --- src/weathergen/train/collapse_monitor.py | 34 +++++++++++++++++++++--- src/weathergen/train/trainer.py | 18 +++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py index b739908e9..3b89d0dec 100644 --- a/src/weathergen/train/collapse_monitor.py +++ b/src/weathergen/train/collapse_monitor.py @@ -199,6 +199,10 @@ def _flatten_to_samples(self, z: torch.Tensor) -> torch.Tensor: Returns: Tensor of shape [B*N, D] or [B, D]. """ + # Convert to float32 for SVD compatibility (bfloat16/float16 can fail) + if z.dtype in (torch.bfloat16, torch.float16): + z = z.float() + if z.ndim == 3: return z.reshape(-1, z.shape[-1]) return z @@ -237,15 +241,26 @@ def _compute_effective_rank(self, z: torch.Tensor, sample_size: int = 2048) -> f z = self._flatten_to_samples(z.detach()) z = self._sample_rows(z, sample_size) + # Validate tensor before SVD + if z.numel() == 0: + logger.warning("Empty tensor in effective rank computation") + return 0.0 + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for effective rank computation") + return 0.0 + if z.shape[0] < 2 or z.shape[1] < 2: + logger.warning(f"Tensor too small for SVD: shape={z.shape}") + return 0.0 + # Center the data z_centered = z - z.mean(dim=0, keepdim=True) # Compute SVD try: _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) - except RuntimeError: + except RuntimeError as e: # SVD can fail on degenerate matrices - logger.warning("SVD failed in effective rank computation") + logger.warning(f"SVD failed in effective rank computation: {e}, shape={z.shape}") return 0.0 # Normalize singular values to get a probability distribution @@ -280,14 +295,25 @@ def _compute_singular_values( z = self._flatten_to_samples(z.detach()) z = self._sample_rows(z, sample_size) + # Validate tensor before SVD + if z.numel() == 0: + logger.warning("Empty tensor in singular value computation") + return {} + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for singular value computation") + return {} + if z.shape[0] < 2 or z.shape[1] < 2: + logger.warning(f"Tensor too small for SVD: shape={z.shape}") + return {} + # Center the data z_centered = z - z.mean(dim=0, keepdim=True) # Compute SVD try: _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) - except RuntimeError: - logger.warning("SVD failed in singular value computation") + except RuntimeError as e: + logger.warning(f"SVD failed in singular value computation: {e}, shape={z.shape}") return {} metrics: dict[str, float] = {} diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 7514860a0..af89e3598 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -860,6 +860,19 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: ema_beta = calculator.get_current_beta(step) break + # Debug logging for tensor extraction + if student_latent is not None: + shape = student_latent.shape if isinstance(student_latent, torch.Tensor) else "N/A" + logger.debug(f"Collapse monitor - student: type={type(student_latent)}, shape={shape}") + else: + logger.debug("Collapse monitor - student_latent is None") + + if teacher_latent is not None: + shape = teacher_latent.shape if isinstance(teacher_latent, torch.Tensor) else "N/A" + logger.debug(f"Collapse monitor - teacher: type={type(teacher_latent)}, shape={shape}") + else: + logger.debug("Collapse monitor - teacher_latent is None") + # Ensure tensors are properly formatted if student_latent is not None and isinstance(student_latent, torch.Tensor): self.collapse_monitor.compute_metrics( @@ -869,6 +882,11 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: ema_beta=ema_beta, loss_type=loss_type, ) + else: + logger.debug( + f"Collapse monitor - skipping compute_metrics: " + f"student_latent is {'None' if student_latent is None else type(student_latent)}" + ) def _log_collapse_metrics(self, stage: Stage) -> None: """ From 7693c1903fd3f33e533d7422ae6e0a578b247161 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 21:15:15 +0100 Subject: [PATCH 13/38] Reduce variables logged --- config/config_jepa.yml | 1 - src/weathergen/train/collapse_monitor.py | 17 ++++---- tests/test_collapse_monitor.py | 49 ++++++++++++++---------- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/config/config_jepa.yml b/config/config_jepa.yml index f10f16445..464cc7e60 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -147,7 +147,6 @@ training_config: sample_size: 2048 # max samples for SVD (0 = no sampling) singular_values: enabled: true - top_k: 10 tensor_source: "both" sample_size: 2048 dimension_variance: diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py index 3b89d0dec..d77060ad8 100644 --- a/src/weathergen/train/collapse_monitor.py +++ b/src/weathergen/train/collapse_monitor.py @@ -132,13 +132,12 @@ def compute_metrics( # Compute singular value spectrum if self.singular_values_config.get("enabled", True): - top_k = self.singular_values_config.get("top_k", 10) sample_size = self.singular_values_config.get("sample_size", 2048) for name, tensor in tensors_to_monitor.items(): if tensor is not None: source = self.singular_values_config.get("tensor_source", "both") if source == "both" or source == name: - sv_metrics = self._compute_singular_values(tensor, top_k, sample_size) + sv_metrics = self._compute_singular_values(tensor, sample_size) for key, value in sv_metrics.items(): metrics[f"collapse.{name}.{key}"] = value @@ -275,10 +274,10 @@ def _compute_effective_rank(self, z: torch.Tensor, sample_size: int = 2048) -> f return effective_rank.item() def _compute_singular_values( - self, z: torch.Tensor, top_k: int = 10, sample_size: int = 2048 + self, z: torch.Tensor, sample_size: int = 2048 ) -> dict[str, float]: """ - Compute top-k singular values and concentration ratio. + Compute singular value statistics and concentration ratio. The concentration ratio (top SV / sum of all SVs) indicates how much variance is captured by the largest singular value. High concentration @@ -286,11 +285,10 @@ def _compute_singular_values( Args: z: Latent representations [B, N, D] or [B, D]. - top_k: Number of top singular values to return. sample_size: Maximum samples for SVD computation. Returns: - Dictionary with top-k singular values and concentration ratio. + Dictionary with sv_min, sv_max, sv_mean, and sv_concentration. """ z = self._flatten_to_samples(z.detach()) z = self._sample_rows(z, sample_size) @@ -318,9 +316,10 @@ def _compute_singular_values( metrics: dict[str, float] = {} - # Top-k singular values - for i in range(min(top_k, len(s))): - metrics[f"singular_value_{i}"] = s[i].item() + # Singular value statistics + metrics["sv_min"] = s.min().item() + metrics["sv_max"] = s.max().item() + metrics["sv_mean"] = s.mean().item() # Concentration ratio (top SV / sum) s_sum = s.sum() + 1e-8 diff --git a/tests/test_collapse_monitor.py b/tests/test_collapse_monitor.py index 5656205f9..6a6f2ed8c 100644 --- a/tests/test_collapse_monitor.py +++ b/tests/test_collapse_monitor.py @@ -30,7 +30,6 @@ def default_config(): }, "singular_values": { "enabled": True, - "top_k": 10, "tensor_source": "both", "sample_size": 2048, }, @@ -152,22 +151,23 @@ def test_3d_tensor_flattening(self, monitor): class TestSingularValues: """Test singular value spectrum computation.""" - def test_top_k_singular_values(self, monitor): - """Test that top-k singular values are correctly computed.""" + def test_singular_value_statistics(self, monitor): + """Test that singular value statistics are correctly computed.""" torch.manual_seed(42) num_samples, dim = 128, 64 z = torch.randn(num_samples, dim) - sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + sv_metrics = monitor._compute_singular_values(z, sample_size=0) - # Check that we got top-5 singular values - assert "singular_value_0" in sv_metrics - assert "singular_value_4" in sv_metrics - assert "singular_value_5" not in sv_metrics + # Check that we got min, max, mean statistics + assert "sv_min" in sv_metrics + assert "sv_max" in sv_metrics + assert "sv_mean" in sv_metrics + assert "sv_concentration" in sv_metrics - # Singular values should be in descending order - for i in range(4): - assert sv_metrics[f"singular_value_{i}"] >= sv_metrics[f"singular_value_{i + 1}"] + # Max should be >= mean >= min + assert sv_metrics["sv_max"] >= sv_metrics["sv_mean"] + assert sv_metrics["sv_mean"] >= sv_metrics["sv_min"] def test_concentration_ratio(self, monitor): """Test singular value concentration ratio.""" @@ -179,24 +179,31 @@ def test_concentration_ratio(self, monitor): v_vec = torch.randn(1, dim) z = u_vec @ v_vec * 10 + torch.randn(num_samples, dim) * 0.01 # Strong rank-1 component - sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + sv_metrics = monitor._compute_singular_values(z, sample_size=0) # Concentration should be high when one SV dominates assert "sv_concentration" in sv_metrics assert sv_metrics["sv_concentration"] > 0.8 # First SV dominates strongly + # Max should be much larger than min for rank-1 dominated matrix + assert sv_metrics["sv_max"] > sv_metrics["sv_min"] * 10 + def test_uniform_singular_values(self, monitor): - """Test with approximately uniform singular values.""" + """Test with random matrix (spread singular values).""" torch.manual_seed(42) - # Create orthogonal matrix with equal singular values - dim = 64 - q, _ = torch.linalg.qr(torch.randn(dim, dim)) - z = q * 10 # Scale uniformly + # Random matrix will have spread singular values + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + + sv_metrics = monitor._compute_singular_values(z, sample_size=0) - sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + # Concentration should be relatively low for random matrix + assert sv_metrics["sv_concentration"] < 0.2 - # Concentration should be low (close to 1/D) - assert sv_metrics["sv_concentration"] < 0.1 + # All statistics should be positive + assert sv_metrics["sv_min"] > 0 + assert sv_metrics["sv_max"] > 0 + assert sv_metrics["sv_mean"] > 0 class TestDimensionVariance: @@ -379,7 +386,7 @@ def test_disabled_metrics(self): # Only dimension variance should be computed assert "collapse.student.var_min" in metrics assert "collapse.student.effective_rank" not in metrics - assert "collapse.student.singular_value_0" not in metrics + assert "collapse.student.sv_max" not in metrics class TestSampling: From 7f8de00c84696f9c476fe441a59de693039c5c08 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 22:03:09 +0100 Subject: [PATCH 14/38] Fix EMA beta value computation --- src/weathergen/model/ema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index 141947863..08f367116 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -63,8 +63,8 @@ def update(self, cur_step, batch_size): # determine correct interpolation params halflife_steps = self.halflife_steps if self.rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / 1e3 * self.rampup_ratio) - beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6)) + halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) + beta = 0.5 ** (batch_size / max(halflife_steps, 1e-6)) for name, p_ema in self.ema_model.named_parameters(): p_src = self.src_params.get(name, None) From c3eb019adfdfcf5e79d48b0b00710f367fe10caa Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 22:06:59 +0100 Subject: [PATCH 15/38] Refactor get_current_beta to ema.py --- src/weathergen/model/ema.py | 24 +++++++++++++++---- .../train/target_and_aux_ssl_teacher.py | 19 +-------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index 08f367116..f126e0d83 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -55,16 +55,32 @@ def requires_grad_(self, flag: bool): for p in self.ema_model.parameters(): p.requires_grad = flag + def get_current_beta(self, cur_step: int) -> float: + """ + Get current EMA beta value for monitoring. + + The beta value determines how much the teacher model is updated towards + the student model at each step. Higher beta means slower teacher updates. + + Args: + cur_step: Current training step (typically istep * batch_size). + + Returns: + Current EMA beta value. + """ + halflife_steps = self.ema_model.halflife_steps + if self.rampup_ratio is not None: + halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) + beta = 0.5 ** (self.batch_size / max(halflife_steps, 1e-6)) + return beta + @torch.no_grad() def update(self, cur_step, batch_size): # ensure model remains sharded if self.is_model_sharded: self.ema_model.reshard() # determine correct interpolation params - halflife_steps = self.halflife_steps - if self.rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) - beta = 0.5 ** (batch_size / max(halflife_steps, 1e-6)) + beta = self.get_current_beta(cur_step) for name, p_ema in self.ema_model.named_parameters(): p_src = self.src_params.get(name, None) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 99b1f2860..05213931a 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -78,24 +78,7 @@ def to_device(self, device) -> EMATeacher: return self def get_current_beta(self, cur_step: int) -> float: - """ - Get current EMA beta value for monitoring. - - The beta value determines how much the teacher model is updated towards - the student model at each step. Higher beta means slower teacher updates. - - Args: - cur_step: Current training step (typically istep * batch_size). - - Returns: - Current EMA beta value. - """ - halflife_steps = self.ema_model.halflife_steps - rampup_ratio = self.ema_model.rampup_ratio - if rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / 1e3 * rampup_ratio) - beta = 0.5 ** (self.batch_size / max(halflife_steps * 1e3, 1e-6)) - return beta + return self.ema_model.get_current_beta(cur_step) def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): From 59a0a8972c1bc933ab94c2691cf3b3cf3f684bab Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Wed, 4 Feb 2026 21:33:29 +0000 Subject: [PATCH 16/38] Sensible default for ema in jepa --- config/config_jepa.yml | 4 ++-- src/weathergen/model/ema.py | 4 +++- src/weathergen/train/target_and_aux_ssl_teacher.py | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/config/config_jepa.yml b/config/config_jepa.yml index 464cc7e60..f90a3a88f 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -205,8 +205,8 @@ training_config: }, }, target_and_aux_calc: { "EMATeacher" : - { ema_ramp_up_ratio : 0.09, - ema_halflife_in_thousands: 1e-3, + { ema_ramp_up_ratio : null, + ema_halflife_in_thousands: 1e-1, model_param_overrides : { training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} }, diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index f126e0d83..b42d756d4 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -30,6 +30,7 @@ def __init__( self.rampup_ratio = rampup_ratio self.ema_model = empty_model self.is_model_sharded = is_model_sharded + self.batch_size = 1 # Build a name → param map once self.src_params = dict(self.original_model.named_parameters()) @@ -68,7 +69,7 @@ def get_current_beta(self, cur_step: int) -> float: Returns: Current EMA beta value. """ - halflife_steps = self.ema_model.halflife_steps + halflife_steps = self.halflife_steps if self.rampup_ratio is not None: halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) beta = 0.5 ** (self.batch_size / max(halflife_steps, 1e-6)) @@ -80,6 +81,7 @@ def update(self, cur_step, batch_size): if self.is_model_sharded: self.ema_model.reshard() # determine correct interpolation params + self.batch_size = batch_size beta = self.get_current_beta(cur_step) for name, p_ema in self.ema_model.named_parameters(): diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 05213931a..76994221c 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -78,7 +78,8 @@ def to_device(self, device) -> EMATeacher: return self def get_current_beta(self, cur_step: int) -> float: - return self.ema_model.get_current_beta(cur_step) + beta = self.ema_model.get_current_beta(cur_step) + return beta def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): From 4a091c87158ab876b5813f2e1ed48d0651e6b12a Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux <24638638+sophie-xhonneux@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:05:17 +0100 Subject: [PATCH 17/38] New defaults --- config/default_config.yml | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index c94caeb0a..ab2fffe23 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -150,6 +150,29 @@ training_config: # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["student_teacher"] + # Collapse monitoring for SSL training (JEPA/DINO/iBOT) + # Detects representation collapse via various metrics + collapse_monitoring: + enabled: true + compute_frequency: 100 # batches between metric computations + log_frequency: 100 # batches between metric logging + metrics: + effective_rank: + enabled: true + tensor_source: "both" # "student", "teacher", or "both" + sample_size: 2048 # max samples for SVD (0 = no sampling) + singular_values: + enabled: true + tensor_source: "both" + sample_size: 2048 + dimension_variance: + enabled: true + tensor_source: "both" # cheap to compute, good early indicator + prototype_entropy: + enabled: true # only applies to DINO + ema_beta: + enabled: true + num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True @@ -187,7 +210,7 @@ training_config: eps: 2e-08 muon: # Learning rate multiplier for Muon relative to base LR (muon_lr = base_lr * lr_multiplier) - lr_multiplier: 20.0 + lr_multiplier: 30.0 # Momentum factor for Muon SGD momentum: 0.95 # Use Nesterov momentum @@ -209,8 +232,8 @@ training_config: }, }, target_and_aux_calc: { "EMATeacher" : - { ema_ramp_up_ratio : 0.09, - ema_halflife_in_thousands: 1e-3, + { ema_ramp_up_ratio : null, + ema_halflife_in_thousands: 1e-1, model_param_overrides : { training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} }, @@ -227,7 +250,7 @@ training_config: num_steps_input: 1, masking_strategy_config : { diffusion_rn : True, - rate : 0.6, + rate : 0.4, rate_sampling: False }, }, @@ -237,7 +260,7 @@ training_config: "random_easy_target" : { masking_strategy: "healpix", num_samples: 1, - masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False }, + masking_strategy_config : { rate : 0.4, hl_mask: 5, rate_sampling: False }, }, } @@ -268,7 +291,7 @@ validation_config: output_streams: null # run validation before training starts (mainly for model development) - validate_before_training: False + validate_before_training: 8 # test config; full test config is merge of validation and test config # test config is used by default when running inference From 32d951ba3950614d5fee4fce60f24362ca919050 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Fri, 6 Feb 2026 16:55:44 +0100 Subject: [PATCH 18/38] Implement Frozenteacher --- .../weathergen/evaluate/plotting/plotter.py | 2 +- src/weathergen/model/model_interface.py | 54 ++- .../train/target_and_aux_ssl_teacher.py | 194 +++++++- tests/test_encoder_teacher.py | 446 ++++++++++++++++++ 4 files changed, 667 insertions(+), 29 deletions(-) create mode 100644 tests/test_encoder_teacher.py diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 44ec60a3b..e3cdfefa6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -607,7 +607,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: image_paths += names if image_paths: - image_paths=sorted(image_paths) + image_paths = sorted(image_paths) images = [Image.open(path) for path in image_paths] images[0].save( f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{region}_{var}.gif", diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 24962da13..b6534aacd 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -21,7 +21,7 @@ ) from torch.distributed.tensor import distribute_tensor -from weathergen.common.config import Config, merge_configs +from weathergen.common.config import Config, load_run_config, merge_configs from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -34,7 +34,7 @@ from weathergen.model.model import Model, ModelParams from weathergen.model.utils import apply_fct_to_blocks, freeze_weights from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux -from weathergen.train.target_and_aux_ssl_teacher import EMATeacher +from weathergen.train.target_and_aux_ssl_teacher import EMATeacher, FrozenTeacher from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype @@ -327,7 +327,57 @@ def get_target_aux_calculator( batch_size = cf.get("world_size_original", cf.get("world_size")) * batch_size_per_gpu target_aux = EMATeacher(model, ema_model, batch_size, cf.training_config) + elif target_and_aux_calc == "FrozenTeacher": + target_aux = _create_frozen_teacher(cf, dataset, device, target_and_aux_calc_params) + else: raise NotImplementedError(f"{target_and_aux_calc} is not implemented") return target_aux + + +def _create_frozen_teacher(cf: Config, dataset, device, params: dict) -> FrozenTeacher: + """Create a FrozenTeacher from a pre-trained checkpoint. + + Args: + cf: Current training configuration. + dataset: Dataset for model creation. + device: Target device. + params: FrozenTeacher parameters from config, including: + - teacher_run_id (required): Run ID of the pre-trained teacher model. + - teacher_mini_epoch (optional): Mini-epoch to load. Default -1 (latest). + + Returns: + FrozenTeacher instance with loaded and frozen weights. + + Raises: + ValueError: If teacher_run_id is not provided. + """ + teacher_run_id = params.get("teacher_run_id") + teacher_mini_epoch = params.get("teacher_mini_epoch", -1) + + if teacher_run_id is None: + raise ValueError("FrozenTeacher requires 'teacher_run_id' in config") + + if is_root(): + logger.info( + f"Loading FrozenTeacher from run_id={teacher_run_id}, mini_epoch={teacher_mini_epoch}" + ) + + # Load teacher's config (contains full architecture) + teacher_config = load_run_config(teacher_run_id, teacher_mini_epoch, cf.get("model_path")) + + # Create model with teacher's architecture + teacher_model = get_model(teacher_config, "student", dataset, {}) + + # Load weights + teacher_model = load_model( + teacher_config, teacher_model, device, teacher_run_id, teacher_mini_epoch + ) + + # Freeze all parameters + for param in teacher_model.parameters(): + param.requires_grad = False + teacher_model.eval() + + return FrozenTeacher(teacher_model, cf.training_config) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index cfb252f86..290cd3af1 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING import torch @@ -20,19 +20,33 @@ ) from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, TargetAuxOutput +if TYPE_CHECKING: + pass -class EMATeacher(TargetAndAuxModuleBase): - def __init__(self, model, ema_model, batch_size, training_cfg, **kwargs): - # One of the issues is that the teacher model may have a different architecture - # to the student, e.g. JEPA. So we need quite a flexible way to instantiate the - # the teacher. Because of the device sharding etc that requires quite a bit of - # massaging we assume that the teacher creates the EMA model correctly. However, - # note that you cannot assume that model.state_dict equals ema_model.state_dict - self.ema_model = ema_model - self.batch_size = batch_size - # is a dict of TargetProcessing classes as we may use several in parallel +class EncoderTeacher(TargetAndAuxModuleBase): + """Abstract base class for SSL teachers that use an encoder to generate targets. + + This class provides the common functionality for teacher models in student-teacher + SSL training setups. Subclasses must implement `_forward_teacher()` to define + how the teacher model generates outputs. + Attributes: + teacher_model: The teacher model used to generate target representations. + postprocess_targets: Dict of postprocessing modules for each loss type. + """ + + def __init__(self, teacher_model, training_cfg, **kwargs): + """Initialize the EncoderTeacher. + + Args: + teacher_model: The teacher model (can be EMA model wrapper or frozen model). + training_cfg: Training configuration containing loss specifications. + **kwargs: Additional arguments passed to postprocessing setup. + """ + self.teacher_model = teacher_model + + # Parse SSL losses from config to set up target postprocessing losses_cfg = [ v.loss_fcts for k, v in training_cfg.losses.items() @@ -41,24 +55,37 @@ def __init__(self, model, ema_model, batch_size, training_cfg, **kwargs): # TODO: support multiple LossLatentSSLStudentTeacher loss terms self.postprocess_targets = get_target_postprocessing(losses_cfg[0], training_cfg, **kwargs) - self.reset() + def _forward_teacher(self, model_params, batch): + """Execute forward pass on the teacher model. - def reset(self, batch_size=None): - self.ema_model.reset() - if batch_size is not None: - self.batch_size = batch_size + Subclasses must implement this method to define their specific forward behavior. - def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: - return + Args: + model_params: Model parameters for the forward pass. + batch: Input batch. - def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: - if self.ema_model.is_model_sharded: - self.ema_model.ema_model.reshard() - self.ema_model.update(istep, self.batch_size) + Returns: + Model output with get_latent_prediction() method. + + Raises: + NotImplementedError: If not implemented by subclass. + """ + raise NotImplementedError("Subclasses must implement _forward_teacher()") - def compute(self, bidx, batch, model_params, model) -> tuple[Any, Any]: + def compute(self, istep, batch, model_params, model) -> TargetAuxOutput: + """Compute target representations from the teacher model. + + Args: + istep: Training step index. + batch: Input batch. + model_params: Model parameters. + model: Student model (not used, but part of interface). + + Returns: + TargetAuxOutput containing latent targets and auxiliary outputs. + """ with torch.no_grad(): - outputs = self.ema_model.forward_eval(model_params, batch).get_latent_prediction(0) + outputs = self._forward_teacher(model_params, batch).get_latent_prediction(0) targets = {} for loss_name, target_module in self.postprocess_targets.items(): targets[loss_name] = target_module(outputs[loss_name]) @@ -72,13 +99,128 @@ def compute(self, bidx, batch, model_params, model) -> tuple[Any, Any]: return targets_out - def to_device(self, device) -> EMATeacher: + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: + """Update state before backward pass. Default is no-op.""" + return + + def to_device(self, device) -> EncoderTeacher: + """Move postprocessors to the specified device. + + Args: + device: Target device. + + Returns: + Self for method chaining. + """ for _, module in self.postprocess_targets.items(): module.to(device) return self +class EMATeacher(EncoderTeacher): + """Teacher using Exponential Moving Average of student weights. + + This teacher maintains an EMA of the student model's weights and uses it + to generate target representations for SSL training. + """ + + def __init__(self, model, ema_model, batch_size, training_cfg, **kwargs): + """Initialize the EMATeacher. + + Args: + model: The student model (used for reference, weights copied to EMA). + ema_model: The EMA model wrapper that maintains averaged weights. + batch_size: Global batch size for EMA update scheduling. + training_cfg: Training configuration. + **kwargs: Additional arguments passed to parent. + + Note: + The teacher model may have a different architecture to the student, + e.g. for JEPA. The ema_model handles weight copying appropriately. + You cannot assume model.state_dict equals ema_model.state_dict. + """ + self.ema_model = ema_model + self.batch_size = batch_size + super().__init__(ema_model, training_cfg, **kwargs) + self.reset() + + def _forward_teacher(self, model_params, batch): + """Execute forward pass using EMA model's forward_eval method.""" + return self.ema_model.forward_eval(model_params, batch) + + def reset(self, batch_size=None): + """Reset EMA model weights to match current student weights. + + Args: + batch_size: Optional new batch size to use for EMA updates. + """ + self.ema_model.reset() + if batch_size is not None: + self.batch_size = batch_size + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + """Update EMA weights after optimizer step. + + Args: + istep: Current training step. + batch: Current batch (unused). + model: Student model (unused, EMA model tracks it internally). + **kwargs: Additional arguments (unused). + """ + if self.ema_model.is_model_sharded: + self.ema_model.ema_model.reshard() + self.ema_model.update(istep, self.batch_size) + + +class FrozenTeacher(EncoderTeacher): + """Teacher loaded from a pre-trained checkpoint with frozen weights. + + This teacher uses a model loaded from a previous training run. The weights + are frozen and never updated during training. This is useful for distillation + from a pre-trained model as described in arXiv:2509.24317. + """ + + def __init__(self, teacher_model, training_cfg, **kwargs): + """Initialize the FrozenTeacher. + + Args: + teacher_model: Pre-trained model to use as teacher. + training_cfg: Training configuration. + **kwargs: Additional arguments passed to parent. + """ + super().__init__(teacher_model, training_cfg, **kwargs) + + # Ensure all parameters are frozen + for param in self.teacher_model.parameters(): + param.requires_grad = False + + # Set to eval mode permanently + self.teacher_model.eval() + + def _forward_teacher(self, model_params, batch): + """Execute forward pass on the frozen teacher model.""" + return self.teacher_model(model_params, batch) + + def reset(self, batch_size=None): + """No-op: frozen teacher weights don't change.""" + pass + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + """No-op: frozen teacher weights don't change.""" + pass + + def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): + """Create postprocessing modules for each SSL loss type. + + Args: + target_losses: Dict of loss configurations keyed by loss name. + training_cfg: Training configuration. + **kwargs: Additional arguments (unused). + + Returns: + Dict mapping loss names to their postprocessing modules. + """ return_dict = {} for loss_name, conf in target_losses.items(): if loss_name == "iBOT": @@ -99,6 +241,6 @@ def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): elif loss_name == "JEPA": return_dict[loss_name] = JEPATargetProcessing() else: - # We skip losses that are not handled by the EMATeacher + # We skip losses that are not handled by the EncoderTeacher continue return return_dict diff --git a/tests/test_encoder_teacher.py b/tests/test_encoder_teacher.py new file mode 100644 index 000000000..f379a27ce --- /dev/null +++ b/tests/test_encoder_teacher.py @@ -0,0 +1,446 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Tests for EncoderTeacher class hierarchy (EMATeacher and FrozenTeacher).""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + +# Mock flash_attn before importing weathergen modules +sys.modules["flash_attn"] = MagicMock() + +from weathergen.train.target_and_aux_module_base import TargetAuxOutput # noqa: E402 + + +# ============================================================================= +# Fixtures for mock objects +# ============================================================================= + + +class MockLatentState: + """Mock latent state that get_latent_prediction returns.""" + + def __init__(self, data: dict): + self._data = data + + def __getitem__(self, key): + return self._data[key] + + +class MockModelOutput: + """Mock model output with get_latent_prediction method.""" + + def __init__(self, latent_data: dict): + self._latent_data = latent_data + + def get_latent_prediction(self, idx: int): + return self._latent_data + + +class MockSample: + """Mock sample with meta_info.""" + + def __init__(self): + self.meta_info = {"key": "value"} + + +class MockBatch: + """Mock batch for testing compute().""" + + def __init__(self, num_samples: int = 2): + self._samples = [MockSample() for _ in range(num_samples)] + + def get_samples(self): + return self._samples + + def get_output_len(self): + return 1 + + def get_output_idxs(self): + return [0] + + +class MockEMAModel: + """Mock EMA model for testing EMATeacher.""" + + def __init__(self, model: nn.Module): + self.model = model + self.ema_model = model + self.is_model_sharded = False + self._reset_called = False + self._update_called = False + self._update_args = None + + def reset(self): + self._reset_called = True + # Copy weights from model to ema_model (simulating real behavior) + with torch.no_grad(): + for p_ema, p_model in zip( + self.ema_model.parameters(), self.model.parameters() + ): + p_ema.copy_(p_model) + + def update(self, istep: int, batch_size: int): + self._update_called = True + self._update_args = (istep, batch_size) + # Simulate EMA update by slightly modifying weights + with torch.no_grad(): + for p in self.ema_model.parameters(): + p.mul_(0.999).add_(torch.randn_like(p) * 0.001) + + def forward_eval(self, model_params, batch): + return self.ema_model(model_params, batch) + + +@pytest.fixture +def simple_model(): + """Create a simple model for testing.""" + model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)) + return model + + +@pytest.fixture +def mock_training_cfg(): + """Create mock training config with JEPA loss.""" + from omegaconf import OmegaConf + + cfg = OmegaConf.create( + { + "losses": { + "ssl_loss": { + "type": "LossLatentSSLStudentTeacher", + "loss_fcts": {"JEPA": {"head": "identity", "out_dim": 256}}, + } + } + } + ) + return cfg + + +@pytest.fixture +def mock_ema_model(simple_model): + """Create mock EMA model wrapping simple_model.""" + return MockEMAModel(simple_model) + + +# ============================================================================= +# Interface Tests - Both EMATeacher and FrozenTeacher must pass these +# ============================================================================= + + +class TestEncoderTeacherInterface: + """Tests for the shared interface of EncoderTeacher subclasses.""" + + def test_ema_teacher_has_required_methods(self): + """Verify EMATeacher has all required interface methods.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + required_methods = [ + "reset", + "update_state_pre_backward", + "update_state_post_opt_step", + "compute", + "to_device", + ] + for method in required_methods: + assert hasattr(EMATeacher, method), f"EMATeacher missing method: {method}" + assert callable( + getattr(EMATeacher, method) + ), f"EMATeacher.{method} is not callable" + + def test_frozen_teacher_has_required_methods(self): + """Verify FrozenTeacher has all required interface methods.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + required_methods = [ + "reset", + "update_state_pre_backward", + "update_state_post_opt_step", + "compute", + "to_device", + ] + for method in required_methods: + assert hasattr( + FrozenTeacher, method + ), f"FrozenTeacher missing method: {method}" + assert callable( + getattr(FrozenTeacher, method) + ), f"FrozenTeacher.{method} is not callable" + + def test_ema_teacher_update_state_pre_backward_is_noop( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """Verify update_state_pre_backward returns None (no-op).""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + result = teacher.update_state_pre_backward( + istep=0, batch=MockBatch(), model=simple_model + ) + assert result is None + + def test_frozen_teacher_update_state_pre_backward_is_noop( + self, simple_model, mock_training_cfg + ): + """Verify FrozenTeacher.update_state_pre_backward returns None (no-op).""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Create a callable model for FrozenTeacher + teacher_model = MagicMock() + teacher_model.parameters.return_value = iter([]) + teacher_model.eval = MagicMock() + + teacher = FrozenTeacher(teacher_model, mock_training_cfg) + result = teacher.update_state_pre_backward( + istep=0, batch=MockBatch(), model=simple_model + ) + assert result is None + + def test_ema_teacher_to_device_moves_postprocessors( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """Verify to_device moves postprocessors to specified device.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + + # Track if .to() was called on postprocessors + for name, module in teacher.postprocess_targets.items(): + module.to = MagicMock(return_value=module) + + teacher.to_device("cpu") + + for name, module in teacher.postprocess_targets.items(): + module.to.assert_called_once_with("cpu") + + def test_frozen_teacher_to_device_moves_postprocessors( + self, simple_model, mock_training_cfg + ): + """Verify FrozenTeacher.to_device moves postprocessors.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + teacher_model = MagicMock() + teacher_model.parameters.return_value = iter([]) + teacher_model.eval = MagicMock() + + teacher = FrozenTeacher(teacher_model, mock_training_cfg) + + for name, module in teacher.postprocess_targets.items(): + module.to = MagicMock(return_value=module) + + teacher.to_device("cpu") + + for name, module in teacher.postprocess_targets.items(): + module.to.assert_called_once_with("cpu") + + +# ============================================================================= +# EMATeacher-specific Tests +# ============================================================================= + + +class TestEMATeacher: + """Tests specific to EMATeacher behavior.""" + + def test_ema_reset_calls_ema_model_reset( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """After reset, EMA model's reset method should be called.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + + # Reset is called in __init__, so reset the flag first + mock_ema_model._reset_called = False + + teacher.reset() + assert mock_ema_model._reset_called + + def test_ema_reset_can_update_batch_size( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """Reset can optionally update batch size.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + assert teacher.batch_size == 8 + + teacher.reset(batch_size=16) + assert teacher.batch_size == 16 + + def test_ema_update_post_opt_step_calls_ema_update( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """update_state_post_opt_step should call ema_model.update().""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + + teacher.update_state_post_opt_step( + istep=10, batch=MockBatch(), model=simple_model + ) + + assert mock_ema_model._update_called + assert mock_ema_model._update_args == (10, 8) + + +# ============================================================================= +# FrozenTeacher-specific Tests +# ============================================================================= + + +class TestFrozenTeacher: + """Tests specific to FrozenTeacher behavior.""" + + def test_frozen_teacher_init_freezes_parameters(self, mock_training_cfg): + """FrozenTeacher should freeze all model parameters on init.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Create a model with actual parameters + model = nn.Linear(10, 5) + assert all(p.requires_grad for p in model.parameters()) + + teacher = FrozenTeacher(model, mock_training_cfg) + + # All parameters should be frozen + assert all(not p.requires_grad for p in teacher.teacher_model.parameters()) + + def test_frozen_teacher_init_sets_eval_mode(self, mock_training_cfg): + """FrozenTeacher should set model to eval mode.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + model = nn.Linear(10, 5) + model.train() + assert model.training + + teacher = FrozenTeacher(model, mock_training_cfg) + + assert not teacher.teacher_model.training + + def test_frozen_reset_is_noop(self, mock_training_cfg): + """FrozenTeacher.reset() should not change weights.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + model = nn.Linear(10, 5) + teacher = FrozenTeacher(model, mock_training_cfg) + + # Get weights before reset + weights_before = { + k: v.clone() for k, v in teacher.teacher_model.state_dict().items() + } + + teacher.reset() + + # Weights should be unchanged + weights_after = teacher.teacher_model.state_dict() + for key in weights_before: + assert torch.equal(weights_before[key], weights_after[key]) + + def test_frozen_update_is_noop(self, mock_training_cfg): + """FrozenTeacher.update_state_post_opt_step() should not change weights.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + model = nn.Linear(10, 5) + teacher = FrozenTeacher(model, mock_training_cfg) + + # Get weights before update + weights_before = { + k: v.clone() for k, v in teacher.teacher_model.state_dict().items() + } + + teacher.update_state_post_opt_step( + istep=10, batch=MockBatch(), model=MagicMock() + ) + + # Weights should be unchanged + weights_after = teacher.teacher_model.state_dict() + for key in weights_before: + assert torch.equal(weights_before[key], weights_after[key]) + + def test_frozen_weights_require_no_grad(self, mock_training_cfg): + """All FrozenTeacher parameters should have requires_grad=False.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)) + teacher = FrozenTeacher(model, mock_training_cfg) + + for name, param in teacher.teacher_model.named_parameters(): + assert not param.requires_grad, f"Parameter {name} should have requires_grad=False" + + def test_frozen_model_in_eval_mode(self, mock_training_cfg): + """FrozenTeacher model should always be in eval mode.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + model = nn.Sequential( + nn.Linear(10, 10), nn.BatchNorm1d(10), nn.Linear(10, 5) + ) + model.train() # Start in train mode + + teacher = FrozenTeacher(model, mock_training_cfg) + + # Model should be in eval mode + assert not teacher.teacher_model.training + # All submodules should be in eval mode + for module in teacher.teacher_model.modules(): + assert not module.training + + +# ============================================================================= +# EncoderTeacher Base Class Tests +# ============================================================================= + + +class TestEncoderTeacherBaseClass: + """Tests for EncoderTeacher base class functionality.""" + + def test_encoder_teacher_exists(self): + """Verify EncoderTeacher base class exists.""" + from weathergen.train.target_and_aux_ssl_teacher import EncoderTeacher + + assert EncoderTeacher is not None + + def test_ema_teacher_inherits_from_encoder_teacher(self): + """Verify EMATeacher inherits from EncoderTeacher.""" + from weathergen.train.target_and_aux_ssl_teacher import ( + EMATeacher, + EncoderTeacher, + ) + + assert issubclass(EMATeacher, EncoderTeacher) + + def test_frozen_teacher_inherits_from_encoder_teacher(self): + """Verify FrozenTeacher inherits from EncoderTeacher.""" + from weathergen.train.target_and_aux_ssl_teacher import ( + EncoderTeacher, + FrozenTeacher, + ) + + assert issubclass(FrozenTeacher, EncoderTeacher) + + def test_encoder_teacher_has_forward_teacher_method(self): + """Verify EncoderTeacher has _forward_teacher method.""" + from weathergen.train.target_and_aux_ssl_teacher import EncoderTeacher + + assert hasattr(EncoderTeacher, "_forward_teacher") From 329825245306ed71e43efe1380fa6da2d6dbbe3c Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux <24638638+sophie-xhonneux@users.noreply.github.com> Date: Fri, 6 Feb 2026 17:28:51 +0100 Subject: [PATCH 19/38] Test config --- config/config_jepa.yml | 22 +++++++++++++--------- src/weathergen/model/model_interface.py | 2 +- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/config/config_jepa.yml b/config/config_jepa.yml index fc27da8c9..8896791e0 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -179,16 +179,20 @@ training_config: "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768, "dropout_rate": 0.1, target_source_correspondence: {0 : {0 : "subset"} }, + }, }, - }, - target_and_aux_calc: { "EMATeacher" : - { ema_ramp_up_ratio : 0.09, - ema_halflife_in_thousands: 1e-3, - model_param_overrides : { - training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} - }, - } - } + target_and_aux_calc: {FrozenTeacher: { + teacher_run_id: "zosrc8ti", # Required + teacher_mini_epoch: -1}}, + + # target_and_aux_calc: { "EMATeacher" : + # { ema_ramp_up_ratio : 0.09, + # ema_halflife_in_thousands: 1e-3, + # model_param_overrides : { + # training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} + # }, + # } + # } } } diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index b6534aacd..a71528012 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -152,7 +152,7 @@ def init_model_and_shard( if is_root(): logger.info(f"Continuing run with id={run_id_contd} at mini_epoch {mini_epoch_contd}.") model = load_model(cf, model, device, run_id_contd, mini_epoch_contd) - elif cf.get("load_chkpt", None).get("run_id", None): + elif cf.get("load_chkpt", {}).get("run_id", None): run_id = cf.load_chkpt.run_id mini_epoch = cf.load_chkpt.get("mini_epoch", -1) if is_root(): From b4c46b1dad681499e595ea20aa7d7d98fa059dd8 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Fri, 6 Feb 2026 17:32:12 +0100 Subject: [PATCH 20/38] Refactor frozen teacher creation --- src/weathergen/model/model_interface.py | 51 +----------------- .../train/target_and_aux_ssl_teacher.py | 54 +++++++++++++++++++ 2 files changed, 56 insertions(+), 49 deletions(-) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index a71528012..547ba3a90 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -21,7 +21,7 @@ ) from torch.distributed.tensor import distribute_tensor -from weathergen.common.config import Config, load_run_config, merge_configs +from weathergen.common.config import Config, merge_configs from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -328,56 +328,9 @@ def get_target_aux_calculator( target_aux = EMATeacher(model, ema_model, batch_size, cf.training_config) elif target_and_aux_calc == "FrozenTeacher": - target_aux = _create_frozen_teacher(cf, dataset, device, target_and_aux_calc_params) + target_aux = FrozenTeacher.from_pretrained(cf, dataset, device, target_and_aux_calc_params) else: raise NotImplementedError(f"{target_and_aux_calc} is not implemented") return target_aux - - -def _create_frozen_teacher(cf: Config, dataset, device, params: dict) -> FrozenTeacher: - """Create a FrozenTeacher from a pre-trained checkpoint. - - Args: - cf: Current training configuration. - dataset: Dataset for model creation. - device: Target device. - params: FrozenTeacher parameters from config, including: - - teacher_run_id (required): Run ID of the pre-trained teacher model. - - teacher_mini_epoch (optional): Mini-epoch to load. Default -1 (latest). - - Returns: - FrozenTeacher instance with loaded and frozen weights. - - Raises: - ValueError: If teacher_run_id is not provided. - """ - teacher_run_id = params.get("teacher_run_id") - teacher_mini_epoch = params.get("teacher_mini_epoch", -1) - - if teacher_run_id is None: - raise ValueError("FrozenTeacher requires 'teacher_run_id' in config") - - if is_root(): - logger.info( - f"Loading FrozenTeacher from run_id={teacher_run_id}, mini_epoch={teacher_mini_epoch}" - ) - - # Load teacher's config (contains full architecture) - teacher_config = load_run_config(teacher_run_id, teacher_mini_epoch, cf.get("model_path")) - - # Create model with teacher's architecture - teacher_model = get_model(teacher_config, "student", dataset, {}) - - # Load weights - teacher_model = load_model( - teacher_config, teacher_model, device, teacher_run_id, teacher_mini_epoch - ) - - # Freeze all parameters - for param in teacher_model.parameters(): - param.requires_grad = False - teacher_model.eval() - - return FrozenTeacher(teacher_model, cf.training_config) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 290cd3af1..7b7ac8a0d 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -9,6 +9,7 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING import torch @@ -23,6 +24,8 @@ if TYPE_CHECKING: pass +logger = logging.getLogger(__name__) + class EncoderTeacher(TargetAndAuxModuleBase): """Abstract base class for SSL teachers that use an encoder to generate targets. @@ -197,6 +200,57 @@ def __init__(self, teacher_model, training_cfg, **kwargs): # Set to eval mode permanently self.teacher_model.eval() + @classmethod + def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: + """Create a FrozenTeacher from a pre-trained checkpoint. + + Args: + cf: Current training configuration. + dataset: Dataset for model creation. + device: Target device. + params: FrozenTeacher parameters from config, including: + - teacher_run_id (required): Run ID of the pre-trained teacher model. + - teacher_mini_epoch (optional): Mini-epoch to load. Default -1 (latest). + + Returns: + FrozenTeacher instance with loaded and frozen weights. + + Raises: + ValueError: If teacher_run_id is not provided. + """ + # Lazy imports to avoid circular dependency with model_interface + from weathergen.common.config import load_run_config, merge_configs + from weathergen.model.model_interface import get_model, load_model + from weathergen.utils.distributed import is_root + + teacher_run_id = params.get("teacher_run_id") + teacher_mini_epoch = params.get("teacher_mini_epoch", -1) + + if teacher_run_id is None: + raise ValueError("FrozenTeacher requires 'teacher_run_id' in config") + + if is_root(): + logger.info( + f"Loading FrozenTeacher from run_id={teacher_run_id}, " + f"mini_epoch={teacher_mini_epoch}" + ) + + # Load teacher's config (contains full architecture) + teacher_config = load_run_config(teacher_run_id, teacher_mini_epoch, cf.get("model_path")) + + # Disable FSDP/DDP for frozen teacher - it's loaded as a simple non-sharded model + teacher_config = merge_configs(teacher_config, {"with_ddp": False, "with_fsdp": False}) + + # Create model with teacher's architecture + teacher_model = get_model(teacher_config, "student", dataset, {}) + + # Load weights + teacher_model = load_model( + teacher_config, teacher_model, device, teacher_run_id, teacher_mini_epoch + ) + + return cls(teacher_model, cf.training_config) + def _forward_teacher(self, model_params, batch): """Execute forward pass on the frozen teacher model.""" return self.teacher_model(model_params, batch) From 590d366fe0a8af12f78759adf4d8c957cc81a555 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Fri, 6 Feb 2026 17:38:19 +0100 Subject: [PATCH 21/38] Fix stuff --- .../train/target_and_aux_ssl_teacher.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 7b7ac8a0d..df2788323 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -183,16 +183,21 @@ class FrozenTeacher(EncoderTeacher): from a pre-trained model as described in arXiv:2509.24317. """ - def __init__(self, teacher_model, training_cfg, **kwargs): + def __init__(self, teacher_model, training_cfg, teacher_model_params=None, **kwargs): """Initialize the FrozenTeacher. Args: teacher_model: Pre-trained model to use as teacher. training_cfg: Training configuration. + teacher_model_params: Model parameters matching the teacher's architecture. + If None, will use the student's model_params (may cause dimension mismatch). **kwargs: Additional arguments passed to parent. """ super().__init__(teacher_model, training_cfg, **kwargs) + # Store teacher-specific model params + self.teacher_model_params = teacher_model_params + # Ensure all parameters are frozen for param in self.teacher_model.parameters(): param.requires_grad = False @@ -220,6 +225,7 @@ def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: """ # Lazy imports to avoid circular dependency with model_interface from weathergen.common.config import load_run_config, merge_configs + from weathergen.model.model import ModelParams from weathergen.model.model_interface import get_model, load_model from weathergen.utils.distributed import is_root @@ -249,11 +255,23 @@ def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: teacher_config, teacher_model, device, teacher_run_id, teacher_mini_epoch ) - return cls(teacher_model, cf.training_config) + # Create model params matching teacher's architecture + teacher_model_params = ModelParams(teacher_config).create(teacher_config) + teacher_model_params = teacher_model_params.to(device) + + return cls(teacher_model, cf.training_config, teacher_model_params=teacher_model_params) def _forward_teacher(self, model_params, batch): - """Execute forward pass on the frozen teacher model.""" - return self.teacher_model(model_params, batch) + """Execute forward pass on the frozen teacher model. + + Uses the teacher's own model_params instead of the student's to ensure + dimension compatibility. + """ + # Use teacher's model params if available, otherwise fall back to passed-in params + params_to_use = ( + self.teacher_model_params if self.teacher_model_params is not None else model_params + ) + return self.teacher_model(params_to_use, batch) def reset(self, batch_size=None): """No-op: frozen teacher weights don't change.""" From 64ae9f1ba19994b43088b4b0b73fdcf894289963 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Fri, 6 Feb 2026 17:41:33 +0100 Subject: [PATCH 22/38] Fix The issue is we're passing cf.training_config (current training config) but the teacher model's latent heads are defined by teacher_config. We need to pass the teacher's training config so the postprocessing keys match the teacher model's outputs. --- src/weathergen/train/target_and_aux_ssl_teacher.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index df2788323..1606182ba 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -259,7 +259,11 @@ def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: teacher_model_params = ModelParams(teacher_config).create(teacher_config) teacher_model_params = teacher_model_params.to(device) - return cls(teacher_model, cf.training_config, teacher_model_params=teacher_model_params) + # Use teacher's training config for postprocessing setup - the latent head names + # must match the teacher model's output keys + return cls( + teacher_model, teacher_config.training_config, teacher_model_params=teacher_model_params + ) def _forward_teacher(self, model_params, batch): """Execute forward pass on the frozen teacher model. From 4444b04dfa7e73fd59fc4fafafd21b95d00d5ef2 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Fri, 6 Feb 2026 17:50:17 +0100 Subject: [PATCH 23/38] Debug more The fix now: 1. FrozenTeacher inspects the teacher model's actual latent_heads attribute to determine what postprocessing is needed 2. Sets up JEPA/DINO/iBOT postprocessing based on what heads exist (using identity transform for all, with warnings for DINO/iBOT since full centering isn't supported for frozen teachers) 3. Tests updated to use models with latent_heads attributes --- .../train/target_and_aux_ssl_teacher.py | 58 ++++++++++++++--- tests/test_encoder_teacher.py | 65 +++++++++---------- 2 files changed, 80 insertions(+), 43 deletions(-) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 1606182ba..e49184341 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -188,16 +188,18 @@ def __init__(self, teacher_model, training_cfg, teacher_model_params=None, **kwa Args: teacher_model: Pre-trained model to use as teacher. - training_cfg: Training configuration. + training_cfg: Training configuration (used for parent class, may be None). teacher_model_params: Model parameters matching the teacher's architecture. If None, will use the student's model_params (may cause dimension mismatch). **kwargs: Additional arguments passed to parent. """ - super().__init__(teacher_model, training_cfg, **kwargs) - - # Store teacher-specific model params + # Don't call parent __init__ - we set up postprocessing based on actual model heads + self.teacher_model = teacher_model self.teacher_model_params = teacher_model_params + # Set up postprocessing based on teacher model's actual latent heads + self.postprocess_targets = self._setup_postprocessing_from_model(teacher_model) + # Ensure all parameters are frozen for param in self.teacher_model.parameters(): param.requires_grad = False @@ -205,6 +207,46 @@ def __init__(self, teacher_model, training_cfg, teacher_model_params=None, **kwa # Set to eval mode permanently self.teacher_model.eval() + def _setup_postprocessing_from_model(self, teacher_model): + """Set up postprocessing based on the teacher model's actual latent heads. + + This inspects the model's latent_heads to determine what postprocessing is needed, + rather than relying on training config which may not have the SSL loss structure. + """ + postprocess = {} + + # Get latent head names from the model + if hasattr(teacher_model, "latent_heads") and teacher_model.latent_heads is not None: + for head_name in teacher_model.latent_heads.keys(): + if head_name == "JEPA": + postprocess[head_name] = JEPATargetProcessing() + elif head_name == "DINO": + # DINO requires more config - use identity for now + # Full DINO postprocessing would need center_momentum, temps, etc. + logger.warning( + "DINO postprocessing for FrozenTeacher using identity transform. " + "Full DINO centering not supported for frozen teachers." + ) + postprocess[head_name] = JEPATargetProcessing() + elif head_name == "iBOT": + # iBOT requires more config - use identity for now + logger.warning( + "iBOT postprocessing for FrozenTeacher using identity transform. " + "Full iBOT centering not supported for frozen teachers." + ) + postprocess[head_name] = JEPATargetProcessing() + else: + # Unknown head type - use identity + postprocess[head_name] = JEPATargetProcessing() + + if not postprocess: + raise ValueError( + "FrozenTeacher model has no latent heads. " + "Ensure the teacher model was trained with SSL losses." + ) + + return postprocess + @classmethod def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: """Create a FrozenTeacher from a pre-trained checkpoint. @@ -259,11 +301,9 @@ def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: teacher_model_params = ModelParams(teacher_config).create(teacher_config) teacher_model_params = teacher_model_params.to(device) - # Use teacher's training config for postprocessing setup - the latent head names - # must match the teacher model's output keys - return cls( - teacher_model, teacher_config.training_config, teacher_model_params=teacher_model_params - ) + # FrozenTeacher sets up postprocessing by inspecting the model's latent heads, + # so we don't need to pass training_config + return cls(teacher_model, training_cfg=None, teacher_model_params=teacher_model_params) def _forward_teacher(self, model_params, batch): """Execute forward pass on the frozen teacher model. diff --git a/tests/test_encoder_teacher.py b/tests/test_encoder_teacher.py index f379a27ce..118324b65 100644 --- a/tests/test_encoder_teacher.py +++ b/tests/test_encoder_teacher.py @@ -135,6 +135,15 @@ def mock_ema_model(simple_model): return MockEMAModel(simple_model) +@pytest.fixture +def model_with_latent_heads(): + """Create a model with latent_heads attribute for FrozenTeacher testing.""" + model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)) + # Add latent_heads attribute to mimic real model structure + model.latent_heads = nn.ModuleDict({"JEPA": nn.Identity()}) + return model + + # ============================================================================= # Interface Tests - Both EMATeacher and FrozenTeacher must pass these # ============================================================================= @@ -194,17 +203,12 @@ def test_ema_teacher_update_state_pre_backward_is_noop( assert result is None def test_frozen_teacher_update_state_pre_backward_is_noop( - self, simple_model, mock_training_cfg + self, simple_model, model_with_latent_heads ): """Verify FrozenTeacher.update_state_pre_backward returns None (no-op).""" from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher - # Create a callable model for FrozenTeacher - teacher_model = MagicMock() - teacher_model.parameters.return_value = iter([]) - teacher_model.eval = MagicMock() - - teacher = FrozenTeacher(teacher_model, mock_training_cfg) + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) result = teacher.update_state_pre_backward( istep=0, batch=MockBatch(), model=simple_model ) @@ -230,16 +234,12 @@ def test_ema_teacher_to_device_moves_postprocessors( module.to.assert_called_once_with("cpu") def test_frozen_teacher_to_device_moves_postprocessors( - self, simple_model, mock_training_cfg + self, model_with_latent_heads ): """Verify FrozenTeacher.to_device moves postprocessors.""" from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher - teacher_model = MagicMock() - teacher_model.parameters.return_value = iter([]) - teacher_model.eval = MagicMock() - - teacher = FrozenTeacher(teacher_model, mock_training_cfg) + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) for name, module in teacher.postprocess_targets.items(): module.to = MagicMock(return_value=module) @@ -314,37 +314,34 @@ def test_ema_update_post_opt_step_calls_ema_update( class TestFrozenTeacher: """Tests specific to FrozenTeacher behavior.""" - def test_frozen_teacher_init_freezes_parameters(self, mock_training_cfg): + def test_frozen_teacher_init_freezes_parameters(self, model_with_latent_heads): """FrozenTeacher should freeze all model parameters on init.""" from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher - # Create a model with actual parameters - model = nn.Linear(10, 5) - assert all(p.requires_grad for p in model.parameters()) + # Verify model starts with requires_grad=True + assert all(p.requires_grad for p in model_with_latent_heads.parameters()) - teacher = FrozenTeacher(model, mock_training_cfg) + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) # All parameters should be frozen assert all(not p.requires_grad for p in teacher.teacher_model.parameters()) - def test_frozen_teacher_init_sets_eval_mode(self, mock_training_cfg): + def test_frozen_teacher_init_sets_eval_mode(self, model_with_latent_heads): """FrozenTeacher should set model to eval mode.""" from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher - model = nn.Linear(10, 5) - model.train() - assert model.training + model_with_latent_heads.train() + assert model_with_latent_heads.training - teacher = FrozenTeacher(model, mock_training_cfg) + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) assert not teacher.teacher_model.training - def test_frozen_reset_is_noop(self, mock_training_cfg): + def test_frozen_reset_is_noop(self, model_with_latent_heads): """FrozenTeacher.reset() should not change weights.""" from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher - model = nn.Linear(10, 5) - teacher = FrozenTeacher(model, mock_training_cfg) + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) # Get weights before reset weights_before = { @@ -358,12 +355,11 @@ def test_frozen_reset_is_noop(self, mock_training_cfg): for key in weights_before: assert torch.equal(weights_before[key], weights_after[key]) - def test_frozen_update_is_noop(self, mock_training_cfg): + def test_frozen_update_is_noop(self, model_with_latent_heads): """FrozenTeacher.update_state_post_opt_step() should not change weights.""" from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher - model = nn.Linear(10, 5) - teacher = FrozenTeacher(model, mock_training_cfg) + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) # Get weights before update weights_before = { @@ -379,26 +375,27 @@ def test_frozen_update_is_noop(self, mock_training_cfg): for key in weights_before: assert torch.equal(weights_before[key], weights_after[key]) - def test_frozen_weights_require_no_grad(self, mock_training_cfg): + def test_frozen_weights_require_no_grad(self, model_with_latent_heads): """All FrozenTeacher parameters should have requires_grad=False.""" from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher - model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)) - teacher = FrozenTeacher(model, mock_training_cfg) + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) for name, param in teacher.teacher_model.named_parameters(): assert not param.requires_grad, f"Parameter {name} should have requires_grad=False" - def test_frozen_model_in_eval_mode(self, mock_training_cfg): + def test_frozen_model_in_eval_mode(self): """FrozenTeacher model should always be in eval mode.""" from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher model = nn.Sequential( nn.Linear(10, 10), nn.BatchNorm1d(10), nn.Linear(10, 5) ) + # Add latent_heads to model + model.latent_heads = nn.ModuleDict({"JEPA": nn.Identity()}) model.train() # Start in train mode - teacher = FrozenTeacher(model, mock_training_cfg) + teacher = FrozenTeacher(model, training_cfg=None) # Model should be in eval mode assert not teacher.teacher_model.training From c3e52d033156eec5ebd396c6575acdf9c39b4071 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Fri, 6 Feb 2026 17:59:43 +0100 Subject: [PATCH 24/38] Enable frozen models not trained with SSL Summary of Changes Key insight from your feedback: The frozen teacher may have been pre-trained with any method (forecasting, MAE, etc.) and doesn't need to have SSL latent heads. We should: 1. Use the student's training config to know which SSL losses are needed 2. Add identity heads (LatentPredictionHeadIdentity) to the teacher if they don't exist 3. Use identity postprocessing (JEPATargetProcessing) for all SSL losses Changes Made src/weathergen/train/target_and_aux_ssl_teacher.py: - Added import for LatentPredictionHeadIdentity - Rewrote FrozenTeacher.__init__ to: - Accept training_cfg (the student's config) to determine required SSL heads - Call _get_required_ssl_heads() to extract loss names from config - Call _ensure_identity_heads() to add missing heads to the teacher model - Set up identity postprocessing for all SSL losses - Added _get_required_ssl_heads(): extracts SSL loss names from training config, defaults to {"JEPA"} if none found - Added _ensure_identity_heads(): adds LatentPredictionHeadIdentity for any missing heads - Updated from_pretrained() to pass cf.training_config to constructor tests/test_encoder_teacher.py: - Added model_without_latent_heads fixture (simulates a forecasting-only teacher) - Added 5 new tests: - test_frozen_teacher_adds_identity_heads_when_missing - test_frozen_teacher_uses_training_cfg_for_heads - test_frozen_teacher_defaults_to_jepa_without_config - test_frozen_teacher_preserves_existing_heads - test_frozen_teacher_all_postprocessing_is_identity --- .../train/target_and_aux_ssl_teacher.py | 110 +++++++++++------- tests/test_encoder_teacher.py | 110 ++++++++++++++++++ 2 files changed, 177 insertions(+), 43 deletions(-) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index e49184341..a7c090696 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -14,6 +14,7 @@ import torch +from weathergen.model.engines import LatentPredictionHeadIdentity from weathergen.model.ssl_target_processing import ( DINOTargetProcessing, JEPATargetProcessing, @@ -181,6 +182,10 @@ class FrozenTeacher(EncoderTeacher): This teacher uses a model loaded from a previous training run. The weights are frozen and never updated during training. This is useful for distillation from a pre-trained model as described in arXiv:2509.24317. + + The teacher model may have been pre-trained with any method (forecasting, MAE, etc.) + and doesn't need to have SSL latent heads. Identity heads are added automatically + for any SSL losses the student needs. """ def __init__(self, teacher_model, training_cfg, teacher_model_params=None, **kwargs): @@ -188,17 +193,23 @@ def __init__(self, teacher_model, training_cfg, teacher_model_params=None, **kwa Args: teacher_model: Pre-trained model to use as teacher. - training_cfg: Training configuration (used for parent class, may be None). + training_cfg: Current training configuration containing the student's SSL losses. + Used to determine which identity heads to add to the teacher. teacher_model_params: Model parameters matching the teacher's architecture. If None, will use the student's model_params (may cause dimension mismatch). **kwargs: Additional arguments passed to parent. """ - # Don't call parent __init__ - we set up postprocessing based on actual model heads self.teacher_model = teacher_model self.teacher_model_params = teacher_model_params - # Set up postprocessing based on teacher model's actual latent heads - self.postprocess_targets = self._setup_postprocessing_from_model(teacher_model) + # Get required SSL loss names from current training config + required_heads = self._get_required_ssl_heads(training_cfg) + + # Add identity heads to teacher if it doesn't have them + self._ensure_identity_heads(teacher_model, required_heads) + + # Set up identity postprocessing for all SSL losses + self.postprocess_targets = {name: JEPATargetProcessing() for name in required_heads} # Ensure all parameters are frozen for param in self.teacher_model.parameters(): @@ -207,45 +218,55 @@ def __init__(self, teacher_model, training_cfg, teacher_model_params=None, **kwa # Set to eval mode permanently self.teacher_model.eval() - def _setup_postprocessing_from_model(self, teacher_model): - """Set up postprocessing based on the teacher model's actual latent heads. + def _get_required_ssl_heads(self, training_cfg): + """Extract SSL loss names from training config. + + Args: + training_cfg: Training configuration containing losses specification. - This inspects the model's latent_heads to determine what postprocessing is needed, - rather than relying on training config which may not have the SSL loss structure. + Returns: + Set of SSL loss names (e.g., {"JEPA", "DINO"}). """ - postprocess = {} - - # Get latent head names from the model - if hasattr(teacher_model, "latent_heads") and teacher_model.latent_heads is not None: - for head_name in teacher_model.latent_heads.keys(): - if head_name == "JEPA": - postprocess[head_name] = JEPATargetProcessing() - elif head_name == "DINO": - # DINO requires more config - use identity for now - # Full DINO postprocessing would need center_momentum, temps, etc. - logger.warning( - "DINO postprocessing for FrozenTeacher using identity transform. " - "Full DINO centering not supported for frozen teachers." - ) - postprocess[head_name] = JEPATargetProcessing() - elif head_name == "iBOT": - # iBOT requires more config - use identity for now - logger.warning( - "iBOT postprocessing for FrozenTeacher using identity transform. " - "Full iBOT centering not supported for frozen teachers." - ) - postprocess[head_name] = JEPATargetProcessing() - else: - # Unknown head type - use identity - postprocess[head_name] = JEPATargetProcessing() - - if not postprocess: - raise ValueError( - "FrozenTeacher model has no latent heads. " - "Ensure the teacher model was trained with SSL losses." - ) + if training_cfg is None: + # Default to JEPA if no config provided + return {"JEPA"} + + required_heads = set() + for loss_cfg in training_cfg.losses.values(): + if loss_cfg.type == "LossLatentSSLStudentTeacher": + required_heads.update(loss_cfg.loss_fcts.keys()) + + if not required_heads: + # Default to JEPA if no SSL losses found + required_heads = {"JEPA"} - return postprocess + return required_heads + + def _ensure_identity_heads(self, teacher_model, required_heads): + """Add identity latent heads to teacher model if they don't exist. + + The teacher may have been pre-trained without SSL losses (e.g., forecasting). + We add identity heads so that `get_latent_prediction()` returns the raw + encoder representations for the student's SSL losses. + + Args: + teacher_model: The teacher model to modify. + required_heads: Set of head names that must exist. + """ + import torch.nn as nn + + # Ensure latent_heads exists + if not hasattr(teacher_model, "latent_heads") or teacher_model.latent_heads is None: + teacher_model.latent_heads = nn.ModuleDict() + + # Add missing identity heads + for head_name in required_heads: + if head_name not in teacher_model.latent_heads: + logger.info( + f"FrozenTeacher: Adding identity head '{head_name}' to teacher model " + f"(teacher was likely pre-trained without SSL losses)" + ) + teacher_model.latent_heads[head_name] = LatentPredictionHeadIdentity() @classmethod def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: @@ -301,9 +322,12 @@ def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: teacher_model_params = ModelParams(teacher_config).create(teacher_config) teacher_model_params = teacher_model_params.to(device) - # FrozenTeacher sets up postprocessing by inspecting the model's latent heads, - # so we don't need to pass training_config - return cls(teacher_model, training_cfg=None, teacher_model_params=teacher_model_params) + # Pass current training config so FrozenTeacher knows which SSL heads to add + return cls( + teacher_model, + training_cfg=cf.training_config, + teacher_model_params=teacher_model_params, + ) def _forward_teacher(self, model_params, batch): """Execute forward pass on the frozen teacher model. diff --git a/tests/test_encoder_teacher.py b/tests/test_encoder_teacher.py index 118324b65..311259439 100644 --- a/tests/test_encoder_teacher.py +++ b/tests/test_encoder_teacher.py @@ -144,6 +144,14 @@ def model_with_latent_heads(): return model +@pytest.fixture +def model_without_latent_heads(): + """Create a model WITHOUT latent_heads (like a forecasting-only model).""" + model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)) + # No latent_heads - simulates a model trained without SSL + return model + + # ============================================================================= # Interface Tests - Both EMATeacher and FrozenTeacher must pass these # ============================================================================= @@ -403,6 +411,108 @@ def test_frozen_model_in_eval_mode(self): for module in teacher.teacher_model.modules(): assert not module.training + def test_frozen_teacher_adds_identity_heads_when_missing( + self, model_without_latent_heads, mock_training_cfg + ): + """FrozenTeacher should add identity heads if teacher lacks them.""" + from weathergen.model.engines import LatentPredictionHeadIdentity + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Model has no latent_heads + assert not hasattr(model_without_latent_heads, "latent_heads") + + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=mock_training_cfg) + + # latent_heads should now exist with JEPA + assert hasattr(teacher.teacher_model, "latent_heads") + assert "JEPA" in teacher.teacher_model.latent_heads + assert isinstance( + teacher.teacher_model.latent_heads["JEPA"], LatentPredictionHeadIdentity + ) + + def test_frozen_teacher_uses_training_cfg_for_heads(self, model_without_latent_heads): + """FrozenTeacher should use training_cfg to determine which heads to add.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Config with both JEPA and DINO losses + cfg = OmegaConf.create( + { + "losses": { + "ssl_loss": { + "type": "LossLatentSSLStudentTeacher", + "loss_fcts": { + "JEPA": {"head": "identity"}, + "DINO": {"head": "mlp", "out_dim": 256}, + }, + } + } + } + ) + + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=cfg) + + # Both heads should be added + assert "JEPA" in teacher.teacher_model.latent_heads + assert "DINO" in teacher.teacher_model.latent_heads + # Postprocessing should exist for both + assert "JEPA" in teacher.postprocess_targets + assert "DINO" in teacher.postprocess_targets + + def test_frozen_teacher_defaults_to_jepa_without_config(self, model_without_latent_heads): + """FrozenTeacher should default to JEPA head when no config provided.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=None) + + # Should default to JEPA + assert "JEPA" in teacher.teacher_model.latent_heads + assert "JEPA" in teacher.postprocess_targets + + def test_frozen_teacher_preserves_existing_heads(self, model_with_latent_heads, mock_training_cfg): + """FrozenTeacher should not overwrite existing latent heads.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Get reference to original head + original_head = model_with_latent_heads.latent_heads["JEPA"] + + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=mock_training_cfg) + + # Original head should be preserved (same object) + assert teacher.teacher_model.latent_heads["JEPA"] is original_head + + def test_frozen_teacher_all_postprocessing_is_identity(self, model_without_latent_heads): + """All FrozenTeacher postprocessing should use identity (JEPATargetProcessing).""" + from omegaconf import OmegaConf + + from weathergen.model.ssl_target_processing import JEPATargetProcessing + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Config with multiple SSL losses + cfg = OmegaConf.create( + { + "losses": { + "ssl_loss": { + "type": "LossLatentSSLStudentTeacher", + "loss_fcts": { + "JEPA": {"head": "identity"}, + "DINO": {"head": "mlp"}, + "iBOT": {"head": "mlp"}, + }, + } + } + } + ) + + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=cfg) + + # All postprocessors should be JEPATargetProcessing (identity) + for name, processor in teacher.postprocess_targets.items(): + assert isinstance(processor, JEPATargetProcessing), ( + f"Postprocessor for {name} should be JEPATargetProcessing" + ) + # ============================================================================= # EncoderTeacher Base Class Tests From 211f47735394c49dd2a28f5051b615f740a74e1f Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Fri, 6 Feb 2026 18:23:21 +0100 Subject: [PATCH 25/38] Improve code quality --- .../train/target_and_aux_ssl_teacher.py | 270 ++++++++++++++---- tests/test_encoder_teacher.py | 118 ++++++++ 2 files changed, 339 insertions(+), 49 deletions(-) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index a7c090696..94e4fc6c7 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING import torch +import torch.nn as nn from weathergen.model.engines import LatentPredictionHeadIdentity from weathergen.model.ssl_target_processing import ( @@ -23,7 +24,9 @@ from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, TargetAuxOutput if TYPE_CHECKING: - pass + from omegaconf import DictConfig + + from weathergen.common.config import Config logger = logging.getLogger(__name__) @@ -40,23 +43,46 @@ class EncoderTeacher(TargetAndAuxModuleBase): postprocess_targets: Dict of postprocessing modules for each loss type. """ - def __init__(self, teacher_model, training_cfg, **kwargs): + def __init__(self, teacher_model, training_cfg: DictConfig, **kwargs): """Initialize the EncoderTeacher. Args: teacher_model: The teacher model (can be EMA model wrapper or frozen model). training_cfg: Training configuration containing loss specifications. + Must have `losses` attribute with at least one LossLatentSSLStudentTeacher. **kwargs: Additional arguments passed to postprocessing setup. + + Raises: + ValueError: If training_cfg has no LossLatentSSLStudentTeacher losses. """ self.teacher_model = teacher_model # Parse SSL losses from config to set up target postprocessing + assert hasattr(training_cfg, "losses"), ( + f"EncoderTeacher requires training_cfg with 'losses' attribute, " + f"got {type(training_cfg).__name__}" + ) + losses_cfg = [ v.loss_fcts for k, v in training_cfg.losses.items() if v.type == "LossLatentSSLStudentTeacher" ] + + if not losses_cfg: + raise ValueError( + "EncoderTeacher requires at least one 'LossLatentSSLStudentTeacher' loss " + "in training_config.losses. Found loss types: " + f"{[v.type for v in training_cfg.losses.values()]}" + ) + # TODO: support multiple LossLatentSSLStudentTeacher loss terms + if len(losses_cfg) > 1: + logger.warning( + f"Found {len(losses_cfg)} LossLatentSSLStudentTeacher losses, " + "but only the first one is used for target postprocessing." + ) + self.postprocess_targets = get_target_postprocessing(losses_cfg[0], training_cfg, **kwargs) def _forward_teacher(self, model_params, batch): @@ -76,26 +102,44 @@ def _forward_teacher(self, model_params, batch): """ raise NotImplementedError("Subclasses must implement _forward_teacher()") - def compute(self, istep, batch, model_params, model) -> TargetAuxOutput: + def compute(self, istep: int, batch, model_params, model) -> TargetAuxOutput: """Compute target representations from the teacher model. Args: istep: Training step index. - batch: Input batch. - model_params: Model parameters. + batch: Input batch with get_samples(), get_output_len(), get_output_idxs() methods. + model_params: Model parameters for the forward pass. model: Student model (not used, but part of interface). Returns: TargetAuxOutput containing latent targets and auxiliary outputs. + + Raises: + KeyError: If teacher model doesn't output a required loss type. """ with torch.no_grad(): - outputs = self._forward_teacher(model_params, batch).get_latent_prediction(0) + model_output = self._forward_teacher(model_params, batch) + outputs = model_output.get_latent_prediction(0) + targets = {} for loss_name, target_module in self.postprocess_targets.items(): + if loss_name not in outputs: + available_keys = list(outputs.keys()) if hasattr(outputs, "keys") else "N/A" + raise KeyError( + f"Teacher model output missing key '{loss_name}'. " + f"Available keys: {available_keys}. " + f"Ensure teacher model has latent head for '{loss_name}'." + ) targets[loss_name] = target_module(outputs[loss_name]) - # collect target meta-information for selected samples - aux_outputs = [list(sample.meta_info.values())[0] for sample in batch.get_samples()] + # Collect target meta-information for selected samples + samples = batch.get_samples() + aux_outputs = [] + for sample in samples: + if sample.meta_info: + aux_outputs.append(list(sample.meta_info.values())[0]) + else: + aux_outputs.append(None) targets_out = TargetAuxOutput(batch.get_output_len(), batch.get_output_idxs()) targets_out.latent = targets @@ -128,21 +172,34 @@ class EMATeacher(EncoderTeacher): to generate target representations for SSL training. """ - def __init__(self, model, ema_model, batch_size, training_cfg, **kwargs): + def __init__(self, model, ema_model, batch_size: int, training_cfg: DictConfig, **kwargs): """Initialize the EMATeacher. Args: model: The student model (used for reference, weights copied to EMA). ema_model: The EMA model wrapper that maintains averaged weights. - batch_size: Global batch size for EMA update scheduling. - training_cfg: Training configuration. + Must have reset(), update(), forward_eval() methods. + batch_size: Global batch size for EMA update scheduling. Must be positive. + training_cfg: Training configuration with SSL loss specifications. **kwargs: Additional arguments passed to parent. Note: The teacher model may have a different architecture to the student, e.g. for JEPA. The ema_model handles weight copying appropriately. You cannot assume model.state_dict equals ema_model.state_dict. + + Raises: + ValueError: If batch_size is not positive. + AssertionError: If ema_model lacks required methods. """ + if batch_size <= 0: + raise ValueError(f"batch_size must be positive, got {batch_size}") + + # Validate ema_model interface + assert hasattr(ema_model, "reset"), "ema_model must have reset() method" + assert hasattr(ema_model, "update"), "ema_model must have update() method" + assert hasattr(ema_model, "forward_eval"), "ema_model must have forward_eval() method" + self.ema_model = ema_model self.batch_size = batch_size super().__init__(ema_model, training_cfg, **kwargs) @@ -186,81 +243,125 @@ class FrozenTeacher(EncoderTeacher): The teacher model may have been pre-trained with any method (forecasting, MAE, etc.) and doesn't need to have SSL latent heads. Identity heads are added automatically for any SSL losses the student needs. + + Note: + This class intentionally does NOT call super().__init__() because: + 1. It sets up identity postprocessing (JEPATargetProcessing) for ALL losses, + regardless of what the student config specifies for DINO/iBOT + 2. The parent class would try to parse the teacher's training config for SSL losses, + but the teacher may have been trained without SSL (e.g., forecasting only) + + Warning: + This class modifies the teacher_model in-place by adding latent_heads if missing. """ - def __init__(self, teacher_model, training_cfg, teacher_model_params=None, **kwargs): + def __init__( + self, + teacher_model: nn.Module, + training_cfg: DictConfig | None, + teacher_model_params=None, + **kwargs, + ): """Initialize the FrozenTeacher. Args: - teacher_model: Pre-trained model to use as teacher. + teacher_model: Pre-trained model to use as teacher. Will be modified in-place + to add identity latent heads if they don't exist. training_cfg: Current training configuration containing the student's SSL losses. Used to determine which identity heads to add to the teacher. - teacher_model_params: Model parameters matching the teacher's architecture. - If None, will use the student's model_params (may cause dimension mismatch). - **kwargs: Additional arguments passed to parent. + If None, defaults to adding a JEPA head. + teacher_model_params: Model parameters matching the teacher's architecture + (positional embeddings, q_cells, etc.). If None, will use the student's + model_params which may cause dimension mismatch if architectures differ. + **kwargs: Additional arguments (unused, for interface compatibility). """ + # Note: We intentionally don't call super().__init__() - see class docstring self.teacher_model = teacher_model self.teacher_model_params = teacher_model_params # Get required SSL loss names from current training config required_heads = self._get_required_ssl_heads(training_cfg) + assert len(required_heads) > 0, "No SSL heads required - this should never happen" - # Add identity heads to teacher if it doesn't have them + # Add identity heads to teacher if it doesn't have them (modifies model in-place) self._ensure_identity_heads(teacher_model, required_heads) # Set up identity postprocessing for all SSL losses + # FrozenTeacher always uses identity (JEPATargetProcessing) regardless of loss type self.postprocess_targets = {name: JEPATargetProcessing() for name in required_heads} - # Ensure all parameters are frozen + # Freeze all parameters for param in self.teacher_model.parameters(): param.requires_grad = False - # Set to eval mode permanently + # Set to eval mode permanently (affects BatchNorm, Dropout, etc.) self.teacher_model.eval() - def _get_required_ssl_heads(self, training_cfg): + def _get_required_ssl_heads(self, training_cfg: DictConfig | None) -> set[str]: """Extract SSL loss names from training config. Args: training_cfg: Training configuration containing losses specification. + If None, defaults to {"JEPA"}. Returns: - Set of SSL loss names (e.g., {"JEPA", "DINO"}). + Set of SSL loss names (e.g., {"JEPA", "DINO"}). Never empty. """ if training_cfg is None: - # Default to JEPA if no config provided + logger.debug("FrozenTeacher: No training_cfg provided, defaulting to JEPA head") + return {"JEPA"} + + if not hasattr(training_cfg, "losses"): + logger.warning( + "FrozenTeacher: training_cfg has no 'losses' attribute, defaulting to JEPA head" + ) return {"JEPA"} required_heads = set() - for loss_cfg in training_cfg.losses.values(): + for loss_name, loss_cfg in training_cfg.losses.items(): + if not hasattr(loss_cfg, "type"): + continue if loss_cfg.type == "LossLatentSSLStudentTeacher": - required_heads.update(loss_cfg.loss_fcts.keys()) + if hasattr(loss_cfg, "loss_fcts"): + required_heads.update(loss_cfg.loss_fcts.keys()) + else: + logger.warning( + f"FrozenTeacher: Loss '{loss_name}' has type LossLatentSSLStudentTeacher " + "but no loss_fcts, skipping" + ) if not required_heads: - # Default to JEPA if no SSL losses found - required_heads = {"JEPA"} + logger.debug( + "FrozenTeacher: No LossLatentSSLStudentTeacher losses found in config, " + "defaulting to JEPA head" + ) + return {"JEPA"} + logger.debug(f"FrozenTeacher: Required SSL heads from config: {required_heads}") return required_heads - def _ensure_identity_heads(self, teacher_model, required_heads): + def _ensure_identity_heads(self, teacher_model: nn.Module, required_heads: set[str]) -> None: """Add identity latent heads to teacher model if they don't exist. The teacher may have been pre-trained without SSL losses (e.g., forecasting). We add identity heads so that `get_latent_prediction()` returns the raw - encoder representations for the student's SSL losses. + encoder representations (specifically, patch_tokens from LatentState) for + the student's SSL losses. + + Warning: + This method modifies teacher_model IN-PLACE by adding to its latent_heads. Args: - teacher_model: The teacher model to modify. - required_heads: Set of head names that must exist. + teacher_model: The teacher model to modify. Will have latent_heads added/modified. + required_heads: Set of head names that must exist (e.g., {"JEPA", "DINO"}). """ - import torch.nn as nn - - # Ensure latent_heads exists + # Ensure latent_heads ModuleDict exists if not hasattr(teacher_model, "latent_heads") or teacher_model.latent_heads is None: + logger.info("FrozenTeacher: Teacher model has no latent_heads, creating ModuleDict") teacher_model.latent_heads = nn.ModuleDict() # Add missing identity heads - for head_name in required_heads: + for head_name in sorted(required_heads): # sorted for deterministic logging if head_name not in teacher_model.latent_heads: logger.info( f"FrozenTeacher: Adding identity head '{head_name}' to teacher model " @@ -269,22 +370,32 @@ def _ensure_identity_heads(self, teacher_model, required_heads): teacher_model.latent_heads[head_name] = LatentPredictionHeadIdentity() @classmethod - def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: + def from_pretrained(cls, cf: Config, dataset, device, params: dict) -> FrozenTeacher: """Create a FrozenTeacher from a pre-trained checkpoint. + This factory method: + 1. Loads the teacher's config from the checkpoint + 2. Creates a model with the teacher's architecture + 3. Loads the pre-trained weights + 4. Creates ModelParams matching the teacher's architecture + 5. Returns a FrozenTeacher instance + Args: - cf: Current training configuration. - dataset: Dataset for model creation. - device: Target device. + cf: Current training configuration. Used for: + - model_path: Where to find saved models + - training_config: To determine which SSL heads are needed + dataset: Dataset for model creation (provides input/output dimensions). + device: Target device (e.g., "cuda:0", "cpu"). params: FrozenTeacher parameters from config, including: - - teacher_run_id (required): Run ID of the pre-trained teacher model. + - teacher_run_id (required): 8-character run ID of the pre-trained teacher. - teacher_mini_epoch (optional): Mini-epoch to load. Default -1 (latest). Returns: FrozenTeacher instance with loaded and frozen weights. Raises: - ValueError: If teacher_run_id is not provided. + ValueError: If teacher_run_id is not provided or invalid. + FileNotFoundError: If checkpoint doesn't exist (from load_run_config/load_model). """ # Lazy imports to avoid circular dependency with model_interface from weathergen.common.config import load_run_config, merge_configs @@ -295,8 +406,21 @@ def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: teacher_run_id = params.get("teacher_run_id") teacher_mini_epoch = params.get("teacher_mini_epoch", -1) + # Validate teacher_run_id if teacher_run_id is None: - raise ValueError("FrozenTeacher requires 'teacher_run_id' in config") + raise ValueError( + "FrozenTeacher requires 'teacher_run_id' in config. " + "Example config:\n" + " target_and_aux_calc:\n" + " FrozenTeacher:\n" + " teacher_run_id: 'a1b2c3d4'" + ) + + if not isinstance(teacher_run_id, str) or len(teacher_run_id) == 0: + raise ValueError( + f"teacher_run_id must be a non-empty string, got {type(teacher_run_id).__name__}: " + f"{teacher_run_id!r}" + ) if is_root(): logger.info( @@ -305,23 +429,32 @@ def from_pretrained(cls, cf, dataset, device, params: dict) -> FrozenTeacher: ) # Load teacher's config (contains full architecture) - teacher_config = load_run_config(teacher_run_id, teacher_mini_epoch, cf.get("model_path")) + model_path = cf.get("model_path") + assert model_path is not None, "cf.model_path is required to load FrozenTeacher checkpoint" + + teacher_config = load_run_config(teacher_run_id, teacher_mini_epoch, model_path) # Disable FSDP/DDP for frozen teacher - it's loaded as a simple non-sharded model + # This avoids complications with distributed training for the teacher teacher_config = merge_configs(teacher_config, {"with_ddp": False, "with_fsdp": False}) # Create model with teacher's architecture teacher_model = get_model(teacher_config, "student", dataset, {}) - # Load weights + # Load weights from checkpoint teacher_model = load_model( teacher_config, teacher_model, device, teacher_run_id, teacher_mini_epoch ) # Create model params matching teacher's architecture + # This includes positional embeddings, q_cells, etc. that depend on architecture teacher_model_params = ModelParams(teacher_config).create(teacher_config) teacher_model_params = teacher_model_params.to(device) + if is_root(): + num_params = sum(p.numel() for p in teacher_model.parameters()) + logger.info(f"FrozenTeacher loaded with {num_params:,} parameters") + # Pass current training config so FrozenTeacher knows which SSL heads to add return cls( teacher_model, @@ -350,20 +483,51 @@ def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: pass -def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): +def get_target_postprocessing( + target_losses: dict[str, DictConfig], training_cfg: DictConfig, **kwargs +) -> dict[str, nn.Module]: """Create postprocessing modules for each SSL loss type. + This function creates the appropriate postprocessing module for each SSL loss + based on its configuration. The postprocessing is applied to teacher outputs + before computing the student-teacher loss. + + - JEPA: Identity (no postprocessing) + - DINO: Centering and temperature sharpening + - iBOT: Patch-level centering and temperature sharpening + Args: - target_losses: Dict of loss configurations keyed by loss name. - training_cfg: Training configuration. - **kwargs: Additional arguments (unused). + target_losses: Dict of loss configurations keyed by loss name (e.g., "JEPA", "DINO"). + Each value should have the required config keys for that loss type. + training_cfg: Training configuration (currently unused, reserved for future use). + **kwargs: Additional arguments (currently unused). Returns: - Dict mapping loss names to their postprocessing modules. + Dict mapping loss names to their postprocessing nn.Module instances. + + Raises: + KeyError: If a loss config is missing required keys (e.g., out_dim for DINO). + + Example: + >>> target_losses = {"JEPA": {"head": "identity"}, "DINO": {"out_dim": 256, ...}} + >>> postprocessors = get_target_postprocessing(target_losses, training_cfg) + >>> postprocessors["JEPA"](teacher_output) # Identity transform """ return_dict = {} for loss_name, conf in target_losses.items(): if loss_name == "iBOT": + # Validate required keys + required_keys = [ + "out_dim", + "center_momentum", + "loss_extra_args", + "teacher_temp", + "teacher_style", + ] + missing = [k for k in required_keys if k not in conf] + if missing: + raise KeyError(f"iBOT loss config missing required keys: {missing}") + return_dict[loss_name] = iBOTPatchTargetProcessing( patch_out_dim=conf["out_dim"], center_momentum=conf["center_momentum"], @@ -372,6 +536,12 @@ def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): teacher_style=conf["teacher_style"], ) elif loss_name == "DINO": + # Validate required keys + required_keys = ["out_dim", "center_momentum", "loss_extra_args", "teacher_style"] + missing = [k for k in required_keys if k not in conf] + if missing: + raise KeyError(f"DINO loss config missing required keys: {missing}") + return_dict[loss_name] = DINOTargetProcessing( out_dim=conf["out_dim"], center_momentum=conf["center_momentum"], @@ -381,6 +551,8 @@ def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): elif loss_name == "JEPA": return_dict[loss_name] = JEPATargetProcessing() else: - # We skip losses that are not handled by the EncoderTeacher + # Skip losses that are not handled by the EncoderTeacher + logger.debug(f"get_target_postprocessing: Skipping unknown loss type '{loss_name}'") continue + return return_dict diff --git a/tests/test_encoder_teacher.py b/tests/test_encoder_teacher.py index 311259439..bb53f234e 100644 --- a/tests/test_encoder_teacher.py +++ b/tests/test_encoder_teacher.py @@ -551,3 +551,121 @@ def test_encoder_teacher_has_forward_teacher_method(self): from weathergen.train.target_and_aux_ssl_teacher import EncoderTeacher assert hasattr(EncoderTeacher, "_forward_teacher") + + +# ============================================================================= +# Validation and Error Handling Tests +# ============================================================================= + + +class TestValidationAndErrorHandling: + """Tests for input validation and error handling.""" + + def test_ema_teacher_rejects_zero_batch_size( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """EMATeacher should reject batch_size <= 0.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + with pytest.raises(ValueError, match="batch_size must be positive"): + EMATeacher( + simple_model, mock_ema_model, batch_size=0, training_cfg=mock_training_cfg + ) + + def test_ema_teacher_rejects_negative_batch_size( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """EMATeacher should reject negative batch_size.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + with pytest.raises(ValueError, match="batch_size must be positive"): + EMATeacher( + simple_model, mock_ema_model, batch_size=-5, training_cfg=mock_training_cfg + ) + + def test_encoder_teacher_rejects_config_without_ssl_losses(self, simple_model): + """EncoderTeacher should reject config with no LossLatentSSLStudentTeacher.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + # Config with only physical loss, no SSL + cfg = OmegaConf.create( + { + "losses": { + "physical_loss": { + "type": "LossPhysical", + "weight": 1.0, + } + } + } + ) + + mock_ema = MagicMock() + mock_ema.reset = MagicMock() + mock_ema.update = MagicMock() + mock_ema.forward_eval = MagicMock() + + with pytest.raises(ValueError, match="LossLatentSSLStudentTeacher"): + EMATeacher(simple_model, mock_ema, batch_size=8, training_cfg=cfg) + + def test_frozen_teacher_handles_malformed_config_gracefully(self, model_without_latent_heads): + """FrozenTeacher should handle config without 'losses' attribute.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Config without 'losses' key + cfg = OmegaConf.create({"some_other_key": "value"}) + + # Should not raise, should default to JEPA + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=cfg) + assert "JEPA" in teacher.postprocess_targets + + def test_frozen_teacher_from_pretrained_rejects_none_run_id(self): + """from_pretrained should reject None teacher_run_id with helpful message.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + with pytest.raises(ValueError, match="teacher_run_id"): + FrozenTeacher.from_pretrained( + cf=MagicMock(get=lambda k: "/some/path", training_config=None), + dataset=MagicMock(), + device="cpu", + params={}, # Missing teacher_run_id + ) + + def test_frozen_teacher_from_pretrained_rejects_empty_run_id(self): + """from_pretrained should reject empty string teacher_run_id.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + with pytest.raises(ValueError, match="non-empty string"): + FrozenTeacher.from_pretrained( + cf=MagicMock(get=lambda k: "/some/path", training_config=None), + dataset=MagicMock(), + device="cpu", + params={"teacher_run_id": ""}, + ) + + def test_get_target_postprocessing_validates_dino_config(self): + """get_target_postprocessing should raise KeyError for missing DINO config.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import get_target_postprocessing + + # DINO config missing required keys + incomplete_config = OmegaConf.create({"DINO": {"out_dim": 256}}) # Missing other keys + + with pytest.raises(KeyError, match="DINO loss config missing required keys"): + get_target_postprocessing(incomplete_config, training_cfg=None) + + def test_get_target_postprocessing_validates_ibot_config(self): + """get_target_postprocessing should raise KeyError for missing iBOT config.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import get_target_postprocessing + + # iBOT config missing required keys + incomplete_config = OmegaConf.create({"iBOT": {"out_dim": 256}}) # Missing other keys + + with pytest.raises(KeyError, match="iBOT loss config missing required keys"): + get_target_postprocessing(incomplete_config, training_cfg=None) From 491a69d76e2db3d28de3943fc72cdd003161a0a1 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux <24638638+sophie-xhonneux@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:42:51 +0100 Subject: [PATCH 26/38] Test config --- config/config_jepa.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/config_jepa.yml b/config/config_jepa.yml index 8896791e0..b0ef3262a 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -148,10 +148,10 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 - lr_max: 5e-5 + lr_max: 1e-6 lr_final_decay: 1e-6 lr_final: 0.0 - num_steps_warmup: 512 + num_steps_warmup: 4096 num_steps_cooldown: 512 policy_warmup: "cosine" policy_decay: "constant" @@ -182,7 +182,7 @@ training_config: }, }, target_and_aux_calc: {FrozenTeacher: { - teacher_run_id: "zosrc8ti", # Required + teacher_run_id: "yoqxf234",# "zosrc8ti", # Required teacher_mini_epoch: -1}}, # target_and_aux_calc: { "EMATeacher" : From 08dbf6f1ba77feff015b2efb473aabc4e2bf94f9 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux <24638638+sophie-xhonneux@users.noreply.github.com> Date: Fri, 6 Feb 2026 22:25:36 +0100 Subject: [PATCH 27/38] Update jepa config --- config/config_jepa.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/config_jepa.yml b/config/config_jepa.yml index b0ef3262a..b2c9c6509 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -148,7 +148,7 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 - lr_max: 1e-6 + lr_max: 1e-4 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 4096 @@ -159,8 +159,8 @@ training_config: parallel_scaling_policy: "sqrt" optimizer: - grad_clip: 1.0 - weight_decay: 0.1 + grad_clip: 0.1 + weight_decay: 0.04 log_grad_norms: False adamw : # parameters are scaled by number of DDP workers From 11b2e604d4b79cb774c4aaa167985298ec784e22 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Mon, 16 Feb 2026 15:14:19 +0100 Subject: [PATCH 28/38] Fix model path loading --- src/weathergen/model/model_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 547ba3a90..390dade25 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -21,7 +21,7 @@ ) from torch.distributed.tensor import distribute_tensor -from weathergen.common.config import Config, merge_configs +from weathergen.common.config import Config, get_path_model, merge_configs from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -179,7 +179,7 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch """ - path_run = Path(cf.model_path) / run_id + path_run = get_path_model(run_id=run_id) mini_epoch_id = ( f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" ) From e08b7f87295ca5d32a3c21fd36d0f75a52cb008c Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Mon, 16 Feb 2026 15:20:28 +0000 Subject: [PATCH 29/38] Fix model_path --- config/config_jepa_finetuning.yml | 5 ++-- config/default_config.yml | 39 +++++++++++++------------ src/weathergen/model/model_interface.py | 1 + 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/config/config_jepa_finetuning.yml b/config/config_jepa_finetuning.yml index e9bf055a8..b4cd7aefa 100644 --- a/config/config_jepa_finetuning.yml +++ b/config/config_jepa_finetuning.yml @@ -92,7 +92,8 @@ zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore ##################################### # streams_directory: "./config/streams/era5_1deg/" -streams_directory: "./config/streams/era5_synop_finetuning/" +# streams_directory: "./config/streams/era5_synop_finetuning/" +streams_directory: "./config/streams/era5_nppatms_finetuning/" streams: ??? general: @@ -271,7 +272,7 @@ validation_config: # write samples in normalized model space normalized_samples: False, # output streams to write; default all - streams: ["SurfaceCombined"], + streams: ["NPPATMS"], } # run validation before training starts (mainly for model development) diff --git a/config/default_config.yml b/config/default_config.yml index 78c398e7d..7b7f35c4b 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -12,7 +12,7 @@ embed_unembed_mode: "block" embed_dropout_rate: 0.1 ae_local_dim_embed: 1024 -ae_local_num_blocks: 4 +ae_local_num_blocks: 2 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -37,7 +37,7 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 12 +ae_aggregation_num_blocks: 8 ae_aggregation_num_heads: 32 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True @@ -184,7 +184,7 @@ training_config: optimizer: # Optimizer type: "adamw" (default) or "muon_adamw" (Muon for hidden weights, AdamW for embeddings/heads) type: "muon_adamw" - grad_clip: 0.1 + grad_clip: 0.5 weight_decay: 0.05 log_grad_norms: False adamw : @@ -215,31 +215,32 @@ training_config: target_source_correspondence: {0 : {0 : "subset"} }, }, }, - target_and_aux_calc: {FrozenTeacher: { - teacher_run_id: "yoqxf234", # "zosrc8ti", # Required - teacher_mini_epoch: -1}}, + # target_and_aux_calc: {FrozenTeacher: { + # teacher_run_id: "yoqxf234", # "zosrc8ti", # Required + # teacher_mini_epoch: -1}}, # }, - # target_and_aux_calc: { "EMATeacher" : - # { ema_ramp_up_ratio : null, - # ema_halflife_in_thousands: 1e-1, - # model_param_overrides : { - # training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} - # }, - # } - # } + target_and_aux_calc: { "EMATeacher" : + { ema_ramp_up_ratio : null, + ema_halflife_in_thousands: 1e-0, + model_param_overrides : { + training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} + }, + } + } } } model_input: { "random_easy" : { # masking strategy: "random", "forecast" - masking_strategy: "random", + masking_strategy: "healpix", num_samples: 1, num_steps_input: 1, masking_strategy_config : { - diffusion_rn : True, + diffusion_rn : False, rate : 0.6, - rate_sampling: False + hl_mask: 4, + rate_sampling: True }, }, } @@ -248,7 +249,7 @@ training_config: "random_easy_target" : { masking_strategy: "healpix", num_samples: 1, - masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False }, + masking_strategy_config : { rate : 0.66, hl_mask: 4, rate_sampling: True}, }, } @@ -284,7 +285,7 @@ validation_config: } # run validation before training starts (mainly for model development) - validate_before_training: False + validate_before_training: 8 # test config; full test config is merge of validation and test config # test config is used by default when running inference diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 390dade25..47072b46d 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -29,6 +29,7 @@ MultiSelfAttentionHeadLocal, MultiSelfAttentionHeadVarlen, ) +from weathergen.common.config import get_path_model from weathergen.model.ema import EMAModel from weathergen.model.layers import MLP from weathergen.model.model import Model, ModelParams From b42a7785523bfe1405122e89161ad30b52d73da5 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 6 Feb 2026 11:27:58 +0100 Subject: [PATCH 30/38] Fix inference corner case (#1818) * implemented * remove eval in interface * lint * incoporate requested changes * fix imports * Fix corner case in inference where data window is empty * Fix missing handling of missing load_chkpt argument in config --------- Co-authored-by: moritzhauschulz --- src/weathergen/datasets/tokenizer_masking.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 3d61767f0..6dfe71c89 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -66,6 +66,10 @@ def get_tokens_windows(self, stream_info, data, pad_tokens): tokens = [] for rdata in data: + # skip empty data + if rdata.is_empty(): + continue + # tokenize data idxs_cells, idxs_cells_lens = tok( readerdata_to_torch(rdata), token_size, hl, pad_tokens ) From 8521b8c79b690aa6ca7d907338e136d0e0834a13 Mon Sep 17 00:00:00 2001 From: Till Hauer Date: Fri, 6 Feb 2026 12:09:44 +0100 Subject: [PATCH 31/38] fix latent_loss check in mode handling (#1784) --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 83d436fb9..71a133da9 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -631,7 +631,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): if "masking" in mode: source_select += ["network_input", "target_coords"] target_select += ["target_values"] - if "student_teacher" in mode or mode == "latent_loss" in mode: + if "student_teacher" in mode or "latent_loss" in mode: source_select += ["network_input"] target_select += ["network_input"] # remove duplicates From 57c8518ac41c5228d1da1b81b24b3003bda28798 Mon Sep 17 00:00:00 2001 From: Simon Grasse <161459968+grassesi@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:55:11 +0100 Subject: [PATCH 32/38] Streamline run_train.py so it is suitable to be run both as a script but also as entrypoints. (#1778) * Clean up docstrings, separate cli parsing from running. * remove unused argument stream_dir * separate parser instantiation from adding args * add unified parser with subparsers * implement main function in run_train using subparsers. * update integration tests * remove redundant methods *_from_args (previously used by integration tests) * Move entrypoints to the top of run_train.py * fix typo in small_multi_stream_test.infer_multi_stream * fix formatting * Organize strings into enum. * fix parser --- integration_tests/jepa1_test.py | 41 ++++--- integration_tests/small1_test.py | 40 ++++--- integration_tests/small_multi_stream_test.py | 30 ++++-- src/weathergen/run_train.py | 106 +++++++++---------- src/weathergen/utils/cli.py | 56 ++++++++-- 5 files changed, 174 insertions(+), 99 deletions(-) diff --git a/integration_tests/jepa1_test.py b/integration_tests/jepa1_test.py index f2959f3c9..ce87a877d 100644 --- a/integration_tests/jepa1_test.py +++ b/integration_tests/jepa1_test.py @@ -12,12 +12,11 @@ import os import shutil from pathlib import Path -import omegaconf -import pytest + import numpy as np +import pytest -from weathergen.evaluate.run_evaluation import evaluate_from_config -from weathergen.run_train import inference_from_args, train_with_args +from weathergen.run_train import main from weathergen.utils.metrics import get_train_metrics_path logger = logging.getLogger(__name__) @@ -48,14 +47,14 @@ def setup(test_run_id): @pytest.mark.parametrize("test_run_id", ["test_jepa1_" + commit_hash]) def test_train(setup, test_run_id): logger.info(f"test_train with run_id {test_run_id} {WEATHERGEN_HOME}") - - train_with_args( - [ f"--config={WEATHERGEN_HOME}/integration_tests/jepa1.yaml" ] - + [ + + main( + [ + "train", + f"--config={WEATHERGEN_HOME}/integration_tests/jepa1.yaml", "--run-id", test_run_id, - ], - f"{WEATHERGEN_HOME}/config/streams/streams_test/", + ] ) assert_missing_metrics_file(test_run_id) @@ -85,12 +84,26 @@ def assert_missing_metrics_file(run_id): def assert_nans_in_metrics_file(run_id): """Test that there are no NaNs in the metrics file.""" metrics = load_metrics(run_id) - loss_values_train = np.array([entry.get('LossLatentSSLStudentTeacher.loss_avg') for entry in metrics if entry.get("stage") == 'train']) - loss_values_val = np.array([entry.get('LossLatentSSLStudentTeacher.loss_avg') for entry in metrics if entry.get("stage") == 'val']) + loss_values_train = np.array( + [ + entry.get('LossLatentSSLStudentTeacher.loss_avg') + for entry in metrics if entry.get("stage") == 'train' + ] + ) + loss_values_val = np.array( + [ + entry.get('LossLatentSSLStudentTeacher.loss_avg') + for entry in metrics if entry.get("stage") == 'val' + ] + ) #remove nans if applicable - loss_values_train = np.array([float(value) if value != 'nan' else np.nan for value in loss_values_train]) - loss_values_val = np.array([float(value) if value != 'nan' else np.nan for value in loss_values_val]) + loss_values_train = np.array( + [float(value) if value != 'nan' else np.nan for value in loss_values_train] + ) + loss_values_val = np.array( + [float(value) if value != 'nan' else np.nan for value in loss_values_val] + ) assert not np.isnan(loss_values_train).any(), ( "NaN values found in training loss metrics!" diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py index d3c6e4024..b4845d157 100644 --- a/integration_tests/small1_test.py +++ b/integration_tests/small1_test.py @@ -15,9 +15,9 @@ import omegaconf import pytest -from weathergen.evaluate.run_evaluation import evaluate_from_config -from weathergen.run_train import inference_from_args, train_with_args +from weathergen.evaluate.run_evaluation import evaluate_from_config +from weathergen.run_train import main from weathergen.utils.metrics import get_train_metrics_path logger = logging.getLogger(__name__) @@ -49,13 +49,13 @@ def setup(test_run_id): def test_train(setup, test_run_id): logger.info(f"test_train with run_id {test_run_id} {WEATHERGEN_HOME}") - train_with_args( - f"--config={WEATHERGEN_HOME}/integration_tests/small1.yaml".split() - + [ + main( + [ + "inference", + f"--config={WEATHERGEN_HOME}/integration_tests/small1.yaml", "--run-id", test_run_id, - ], - f"{WEATHERGEN_HOME}/config/streams/streams_test/", + ] ) infer_with_missing(test_run_id) @@ -68,9 +68,16 @@ def test_train(setup, test_run_id): def infer(run_id): logger.info("run inference") - inference_from_args( - ["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] - + [ + main( + [ + "-start", + "2022-10-10", + "-end", + "2022-10-11", + "--samples", + "10", + "--mini-epoch", + "0", "--from-run-id", run_id, "--run-id", @@ -83,9 +90,16 @@ def infer(run_id): def infer_with_missing(run_id): logger.info("run inference") - inference_from_args( - ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] - + [ + main( + [ + "-start", + "2021-10-10", + "-end", + "2022-10-11", + "--samples", + "10", + "--mini-epoch", + "0", "--from-run-id", run_id, "--run-id", diff --git a/integration_tests/small_multi_stream_test.py b/integration_tests/small_multi_stream_test.py index 36cf12f5e..211581e78 100644 --- a/integration_tests/small_multi_stream_test.py +++ b/integration_tests/small_multi_stream_test.py @@ -23,9 +23,9 @@ import omegaconf import pytest -from weathergen.evaluate.run_evaluation import evaluate_from_config -from weathergen.run_train import inference_from_args, train_with_args +from weathergen.evaluate.run_evaluation import evaluate_from_config +from weathergen.run_train import main from weathergen.utils.metrics import get_train_metrics_path logger = logging.getLogger(__name__) @@ -58,15 +58,15 @@ def test_train_multi_stream(setup, test_run_id): """Test training with multiple streams including gridded and observation data.""" logger.info(f"test_train_multi_stream with run_id {test_run_id} {WEATHERGEN_HOME}") - train_with_args( - f"--base-config={WEATHERGEN_HOME}/integration_tests/small_multi_stream.yaml".split() - + [ + main( + [ + "train", + f"--base-config={WEATHERGEN_HOME}/integration_tests/small_multi_stream.yaml", "--run-id", test_run_id, - ], - f"{WEATHERGEN_HOME}/integration_tests/streams_multi/", + ] ) - + infer_multi_stream(test_run_id) # evaluate_multi_stream_results(test_run_id) assert_metrics_file_exists(test_run_id) @@ -78,9 +78,17 @@ def test_train_multi_stream(setup, test_run_id): def infer_multi_stream(run_id): """Run inference for multi-stream model.""" logger.info("run multi-stream inference") - inference_from_args( - ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] - + [ + main( + [ + "inference", + "-start", + "2021-10-10", + "-end", + "2022-10-11", + "--samples", + "10", + "--mini-epoch", + "0", "--from-run-id", run_id, "--run-id", diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 65551fccb..e91501274 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -27,21 +27,63 @@ logger = logging.getLogger(__name__) +def train() -> None: + """Entry point for calling the training code from the command line.""" + main([cli.Stage.train] + sys.argv[1:]) + + +def train_continue() -> None: + """Entry point for calling train_continue from the command line.""" + main([cli.Stage.train_continue] + sys.argv[1:]) + + def inference(): - # By default, arguments from the command line are read. - inference_from_args(sys.argv[1:]) + """Entry point for calling the inference code from the command line.""" + main([cli.Stage.inference] + sys.argv[1:]) -def inference_from_args(argl: list[str]): +def main(argl: list[str]): + try: + argl = _fix_argl(argl) + except ValueError as e: + logger.error(str(e)) + + parser = cli.get_main_parser() + args = parser.parse_args(argl) + match args.stage: + case cli.Stage.train: + run_train(args) + case cli.Stage.train_continue: + run_continue(args) + case cli.Stage.inference: + run_inference(args) + case _: + logger.error("No stage was found.") + + +def _fix_argl(argl): # TODO remove this fix after grace period + """Ensure `stage` positional argument is in arglist.""" + if argl[0] not in cli.Stage: + try: + stage = os.environ.get("WEATHERGEN_STAGE") + except KeyError as e: + msg = ( + "`stage` postional argument and environment variable 'WEATHERGEN_STAGE' missing.", + "Provide either one or the other.", + ) + raise ValueError(msg) from e + + argl = [stage] + argl + + return argl + + +def run_inference(args): """ Inference function for WeatherGenerator model. - Entry point for calling the inference code from the command line. - When running integration tests, the arguments are directly provided. + Note: Additional configuration for inference (`test_config`) is set in the function. """ - parser = cli.get_inference_parser() - args = parser.parse_args(argl) - inference_overwrite = { "test_config": dict( shuffle=False, @@ -84,24 +126,12 @@ def inference_from_args(argl: list[str]): pdb.post_mortem(tb) -#################################################################################################### -def train_continue() -> None: +def run_continue(args): """ Function to continue training for WeatherGenerator model. - Entry point for calling train_continue from the command line. - Configurations are set in the function body. - Args: - from_run_id (str): Run/model id of pretrained WeatherGenerator model to - continue training. Defaults to None. Note: All model configurations are set in the function body. """ - train_continue_from_args(sys.argv[1:]) - - -def train_continue_from_args(argl: list[str]): - parser = cli.get_continue_parser() - args = parser.parse_args(argl) cli_overwrite = config.from_cli_arglist(args.options) cf = config.load_merge_configs( @@ -135,26 +165,12 @@ def train_continue_from_args(argl: list[str]): pdb.post_mortem(tb) -#################################################################################################### -def train() -> None: +def run_train(args): """ Training function for WeatherGenerator model. - Entry point for calling the training code from the command line. - Configurations are set in the function body. - Args: - run_id (str, optional): Run/model id of pretrained WeatherGenerator model to - continue training. Defaults to None. Note: All model configurations are set in the function body. """ - train_with_args(sys.argv[1:], None) - - -def train_with_args(argl: list[str], stream_dir: str | None): - """ - Training function for WeatherGenerator model.""" - parser = cli.get_train_parser() - args = parser.parse_args(argl) cli_overwrite = config.from_cli_arglist(args.options) @@ -191,20 +207,4 @@ def train_with_args(argl: list[str], stream_dir: str | None): if __name__ == "__main__": - try: - stage = os.environ.get("WEATHERGEN_STAGE") - except KeyError as e: - msg = "missing environment variable 'WEATHERGEN_STAGE'" - raise ValueError(msg) from e - - if stage == "train": - # Entry point for slurm script. - # Check whether --from-run-id passed as argument. - if any("--from-run-id" in arg for arg in sys.argv): - train_continue() - else: - train() - elif stage == "inference": - inference() - else: - logger.error("No stage was found.") + main(sys.argv[1:]) diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py index bc98ba11d..1c7cba6a8 100644 --- a/src/weathergen/utils/cli.py +++ b/src/weathergen/utils/cli.py @@ -1,19 +1,65 @@ import argparse +import enum from pathlib import Path import pandas as pd +class Stage(enum.StrEnum): + train = enum.auto() + train_continue = enum.auto() + inference = enum.auto() + + +def get_main_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(allow_abbrev=False) + subparsers = parser.add_subparsers(dest="stage") + + train_parser = subparsers.add_parser( + Stage.train, + help="Train a WeatherGenerator configuration from the ground up.", + ) + _add_train_args(train_parser) + continue_parser = subparsers.add_parser( + Stage.train_continue, + help="Resume training from a pretrained WeatherGenerator configuration.", + ) + _add_continue_args(continue_parser) + inference_parser = subparsers.add_parser( + Stage.inference, + help="Run infernce on a trained WeatherGenerator configuration", + ) + _add_inference_args(inference_parser) + + return parser + + def get_train_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(allow_abbrev=False) - _add_general_arguments(parser) + _add_train_args(parser) return parser def get_continue_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(allow_abbrev=False) + _add_continue_args(parser) + + return parser + + +def get_inference_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(allow_abbrev=False) + _add_inference_args(parser) + + return parser + + +def _add_train_args(parser: argparse.ArgumentParser): + _add_general_arguments(parser) + +def _add_continue_args(parser: argparse.ArgumentParser): _add_general_arguments(parser) _add_model_loading_params(parser) @@ -26,12 +72,8 @@ def get_continue_parser() -> argparse.ArgumentParser: ), ) - return parser - - -def get_inference_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(allow_abbrev=False) +def _add_inference_args(parser: argparse.ArgumentParser): _add_model_loading_params(parser) _add_general_arguments(parser) @@ -64,8 +106,6 @@ def get_inference_parser() -> argparse.ArgumentParser: help="Output streams during inference.", ) - return parser - def _format_date(date: str) -> str: try: From 424e188479ce355f7434a7b50dc7b68e2f902e13 Mon Sep 17 00:00:00 2001 From: Simon Grasse <161459968+grassesi@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:07:47 +0100 Subject: [PATCH 33/38] Sgrasse/develop/435 unify dataset access (#1757) * Implement best effort backward compatibility. * use new `data_pathes` option to look up training data. * fix integration tests * linting * correct spelling in config.py * correct spelling in multi_stream_data_sampler.py * fix typo "data_path_anmoi" -> "data_path_anemoi" in config.py * Update test_config.py * Add suggested comment. --- integration_tests/small_multi_stream.yaml | 1 + .../common/src/weathergen/common/config.py | 22 ++++++++++++++++ .../src/weathergen/readers_extra/registry.py | 26 +++++-------------- .../datasets/multi_stream_data_sampler.py | 22 +++++++--------- tests/test_config.py | 3 +-- 5 files changed, 40 insertions(+), 34 deletions(-) diff --git a/integration_tests/small_multi_stream.yaml b/integration_tests/small_multi_stream.yaml index a3edcc69a..35a1e2767 100644 --- a/integration_tests/small_multi_stream.yaml +++ b/integration_tests/small_multi_stream.yaml @@ -186,6 +186,7 @@ training_config: time_step: 06:00:00 num_steps: 2 policy: "fixed" + offset: 1 # validation config; full validation config is merge of training and validation config diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 794f829c6..f0243a717 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -301,6 +301,26 @@ def _apply_fixes(config: Config) -> Config: eventually removed. """ config = _check_logging(config) + config = _check_datasets(config) + return config + + +def _check_datasets(config: Config) -> Config: + """ + Collect dataset paths under legacy keys. + """ + config = config.copy() + if config.get("data_paths") is None: # TODO remove this for next version + legacy_keys = [ + "data_path_anemoi", + "data_path_obs", + "data_path_eobs", + "data_path_fesom", + "data_path_icon", + ] + paths = [config.get(key) for key in legacy_keys] + config.data_paths = [path for path in paths if path is not None] + return config @@ -526,6 +546,8 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig: if "secrets" in private_cf: del private_cf["secrets"] + private_cf = _check_datasets(private_cf) # TODO: remove temp backward compatibility fix + assert isinstance(private_cf, DictConfig) return private_cf diff --git a/packages/readers_extra/src/weathergen/readers_extra/registry.py b/packages/readers_extra/src/weathergen/readers_extra/registry.py index 957a5a350..39953a25e 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/registry.py +++ b/packages/readers_extra/src/weathergen/readers_extra/registry.py @@ -1,36 +1,24 @@ -from collections.abc import Callable -from dataclasses import dataclass - -from weathergen.common.config import Config - - -@dataclass -class ReaderEntry: - data_path: str | None - constructor: Callable - - -def get_extra_reader(name: str, cf: Config) -> object | None: - """Get an extra reader by name.""" +def get_extra_reader(stream_type: str) -> object | None: + """Get an extra reader by stream_type name.""" # Uses lazy imports to avoid circular dependencies and to not load all the readers at start. # There is no sanity check on them, so they may fail at runtime during imports - match name: + match stream_type: case "iconart": from weathergen.readers_extra.data_reader_iconart import DataReaderIconArt - return ReaderEntry(cf.data_path_icon, DataReaderIconArt) + return DataReaderIconArt case "eobs": from weathergen.readers_extra.data_reader_eobs import DataReaderEObs - return ReaderEntry(cf.data_path_eobs, DataReaderEObs) + return DataReaderEObs case "iconesm": from weathergen.readers_extra.data_reader_icon_esm import DataReaderIconEsm - return ReaderEntry(cf.data_path_icon_esm, DataReaderIconEsm) + return DataReaderIconEsm case "cams": from weathergen.readers_extra.data_reader_cams import DataReaderCams - return ReaderEntry(cf.data_path_cams, DataReaderCams) + return DataReaderCams case _: return None diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 71a133da9..86049d389 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -151,40 +151,36 @@ def __init__( match stream_info["type"]: case "obs": dataset = DataReaderObs - datapath = cf.data_path_obs - # kwargs["end"] = end_date_padded # TODO: implement the padding case "anemoi": dataset = DataReaderAnemoi - datapath = cf.data_path_anemoi case "fesom": dataset = DataReaderFesom - datapath = cf.data_path_fesom case type_name: - reader_entry = get_extra_reader(type_name, cf) - if reader_entry is not None: - dataset = reader_entry.constructor - datapath = reader_entry.data_path - else: + dataset = get_extra_reader(type_name) + if dataset is None: msg = f"Unsupported stream type {stream_info['type']}" f"for stream name '{stream_info['name']}'." raise ValueError(msg) - datapath = pathlib.Path(datapath) fname = pathlib.Path(fname) # dont check if file exists since zarr stores might be directories if fname.exists(): # check if fname is a valid path to allow for simple overwriting filename = fname else: - filename = pathlib.Path(datapath) / fname + filenames = [pathlib.Path(path) / fname for path in cf.data_paths] - if not filename.exists(): # see above + if not any(filename.exists() for filename in filenames): # see above msg = ( f"Did not find input data for {stream_info['type']} " - f"stream '{stream_info['name']}': {filename}." + f"stream '{stream_info['name']}': {filenames}." ) raise FileNotFoundError(msg) + # The same dataset can exist on different locations in the filesystem, + # so we need to choose here. + filename = filenames[0] + ds_type = stream_info["type"] if is_root(): logger.info( diff --git a/tests/test_config.py b/tests/test_config.py index c5fa2e5f3..e04390341 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,8 +9,7 @@ TEST_RUN_ID = "test123" SECRET_COMPONENT = "53CR3T" DUMMY_PRIVATE_CONF = { - "data_path_anemoi": "/path/to/anmoi/data", - "data_path_obs": "/path/to/observation/data", + "data_paths": ["/path/to/anmoi/data", "/path/to/observation/data"] "secrets": { "my_big_secret": { "my_secret_id": f"{SECRET_COMPONENT}01234", From 141315b9017bd620048cbdcce815d7b36749f2cc Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 13 Feb 2026 22:24:26 +0100 Subject: [PATCH 34/38] Fix plot_train (#1831) * Fixed most parts of plot_train. Currently missing: handling of stage_configs when these are derived from an earlier stage. * Removed outdated or unsupported options * Fixed final problems with consolidated training/validation config. Required to move Stage to a more appropriate place * Removed old, unused code --- src/weathergen/datasets/masking.py | 2 +- .../datasets/multi_stream_data_sampler.py | 5 +- .../loss_modules/loss_module_physical.py | 2 +- src/weathergen/train/trainer.py | 26 ++-- src/weathergen/train/utils.py | 30 ++++- src/weathergen/utils/metrics.py | 2 +- src/weathergen/utils/plot_training.py | 59 +++++++- src/weathergen/utils/train_logger.py | 127 +++++------------- src/weathergen/utils/utils.py | 2 +- 9 files changed, 142 insertions(+), 113 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index aa3a61f44..f84111541 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from weathergen.datasets.batch import SampleMetaData -from weathergen.utils.train_logger import Stage +from weathergen.train.utils import Stage from weathergen.utils.utils import is_stream_diagnostic, is_stream_forcing logger = logging.getLogger(__name__) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 86049d389..094032db4 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -31,9 +31,8 @@ get_tokens_lens, ) from weathergen.readers_extra.registry import get_extra_reader -from weathergen.train.utils import get_batch_size_from_config +from weathergen.train.utils import TRAIN, Stage, get_batch_size_from_config from weathergen.utils.distributed import is_root -from weathergen.utils.train_logger import TRAIN, Stage type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs type StreamName = str @@ -529,7 +528,7 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s # source data: iterate overall input steps input_data = [] - for idx in range(base_idx - num_steps_input_max, base_idx + 1): + for idx in range(base_idx - num_steps_input_max + 1, base_idx + 1): # TODO: check that we are not out of bounds when we go back in time rdata = collect_datasources(stream_ds, idx, "source", self.rng) diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 64e104827..61be2848b 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -18,7 +18,7 @@ import weathergen.train.loss_modules.loss_functions as loss_fns from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues -from weathergen.utils.train_logger import TRAIN, VAL, Stage +from weathergen.train.utils import TRAIN, VAL, Stage _logger = logging.getLogger(__name__) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 6c9855cf8..6fbdd0487 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -21,7 +21,7 @@ from torch.distributed.tensor import DTensor import weathergen.common.config as config -from weathergen.common.config import Config, merge_configs +from weathergen.common.config import Config from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler from weathergen.model.ema import EMAModel from weathergen.model.model_interface import ( @@ -36,18 +36,25 @@ from weathergen.train.target_and_aux_ssl_teacher import EMATeacher from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( + TRAIN, + VAL, + Stage, + cfg_keys_to_filter, extract_batch_metadata, filter_config_by_enabled, + get_active_stage_config, get_batch_size_from_config, get_target_idxs_from_cfg, ) from weathergen.utils.distributed import is_root -from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger, prepare_losses_for_logging +from weathergen.utils.train_logger import TrainLogger, prepare_losses_for_logging from weathergen.utils.utils import get_dtype from weathergen.utils.validation_io import write_output logger = logging.getLogger(__name__) +# cfg_keys_to_filter = ["losses", "model_input", "target_input"] + class Trainer(TrainerBase): def __init__(self, train_log_freq: Config): @@ -103,22 +110,21 @@ def init(self, cf: Config, devices): self.freeze_modules = cf.get("freeze_modules", "") - # keys to filter for enabled/disabled - keys_to_filter = ["losses", "model_input", "target_input"] - # get training config and remove disabled options (e.g. because of overrides) self.training_cfg = cf.get("training_config") - self.training_cfg = filter_config_by_enabled(self.training_cfg, keys_to_filter) + self.training_cfg = filter_config_by_enabled(self.training_cfg, cfg_keys_to_filter) assert len(self.training_cfg.model_input.keys()) != 0, ( "You probably have no loss term enabled" ) # validation and test configs are training configs, updated by specified keys - self.validation_cfg = merge_configs(self.training_cfg, cf.get("validation_config", {})) - self.validation_cfg = filter_config_by_enabled(self.validation_cfg, keys_to_filter) + self.validation_cfg = get_active_stage_config( + self.training_cfg, cf.get("validation_config", {}), cfg_keys_to_filter + ) # test cfg is derived from validation cfg with specified keys overwritten - self.test_cfg = merge_configs(self.validation_cfg, cf.get("test_config", {})) - self.test_cfg = filter_config_by_enabled(self.test_cfg, keys_to_filter) + self.test_cfg = get_active_stage_config( + self.validation_cfg, cf.get("test_config", {}), cfg_keys_to_filter + ) # batch sizes self.batch_size_per_gpu = get_batch_size_from_config(self.training_cfg) diff --git a/src/weathergen/train/utils.py b/src/weathergen/train/utils.py index b3ddba5b0..81c8d0ae9 100644 --- a/src/weathergen/train/utils.py +++ b/src/weathergen/train/utils.py @@ -9,11 +9,23 @@ import copy import json +from typing import Literal import torch +from omegaconf import OmegaConf from weathergen.common import config -from weathergen.common.config import Config +from weathergen.common.config import Config, merge_configs + +# Run stages +Stage = Literal["train", "val", "test"] +TRAIN: Stage = "train" +VAL: Stage = "val" +TEST: Stage = "test" + +# keys to filter using enabled: True/False +cfg_keys_to_filter = ["losses", "model_input", "target_input"] + # TODO: remove this definition, it should directly using common. get_run_id = config.get_run_id @@ -149,7 +161,21 @@ def get_target_idxs_from_cfg(cfg, loss_name) -> list[int] | None: return target_idxs -def filter_config_by_enabled(cfg, keys): +def get_active_stage_config( + base_config: dict | OmegaConf, merge_config: dict | OmegaConf, keys_to_filter: list[str] +) -> dict | OmegaConf: + """ + Combine a stage config with its predecessor and filter by enabled: False to obtain the + final config that is used + """ + + result_cfg = merge_configs(base_config, merge_config) + result_cfg = filter_config_by_enabled(result_cfg, keys_to_filter) + + return result_cfg + + +def filter_config_by_enabled(cfg: dict | OmegaConf, keys: list[str]): """ Filtered disabled entries from config """ diff --git a/src/weathergen/utils/metrics.py b/src/weathergen/utils/metrics.py index aedb48739..22e8745de 100644 --- a/src/weathergen/utils/metrics.py +++ b/src/weathergen/utils/metrics.py @@ -61,4 +61,4 @@ def get_train_metrics_path(base_path: Path, run_id: str) -> Path: if (base_path / run_id / "metrics.json").exists(): return base_path / run_id / "metrics.json" else: - return base_path / run_id / f"{run_id}_train_metrics.json" + return base_path / f"{run_id}_train_metrics.json" diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 35bfafe3e..bd54a564c 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -9,8 +9,10 @@ import argparse import logging +import pdb import subprocess import sys +import traceback from pathlib import Path import matplotlib.pyplot as plt @@ -167,7 +169,7 @@ def get_stream_names(run_id: str, model_path: Path | None = "./model"): List of stream names """ # return col names from training (should be identical to validation) - cf = config.load_run_config(run_id, -1, model_path=model_path) + cf = config.load_run_config(run_id, None, model_path=model_path) return [si["name"].replace(",", "").replace("/", "_").replace(" ", "_") for si in cf.streams] @@ -316,6 +318,49 @@ def plot_utilization( plt.close() +def plot_loss_avg(plot_dir: Path, runs_ids, runs_data, x_scale_log=False): + prop_cycle = plt.rcParams["axes.prop_cycle"] + colors = prop_cycle.by_key()["color"] + ["r", "g", "b", "k", "y", "m"] + + # # legend = plt.legend(legend_str, loc="upper right" if not x_scale_log else "lower left") + # for line in legend.get_lines(): + # line.set(alpha=1.0) + _fig = plt.figure(figsize=(10, 7), dpi=300) + + legend_str = [] + for i_run, (run_id, run_data) in enumerate(zip(runs_ids, runs_data, strict=False)): + x_vals = np.array(run_data.train["num_samples"]) + y_vals = np.array(run_data.train["loss_avg_mean"]) + plt.plot( + x_vals, + y_vals, + color=colors[i_run % len(colors)], + ) + legend_str += [run_id + " : " + runs_ids[run_id][1]] + # ("R" if runs_active[j] else "X") + # + " : " + # run_id + ", " + col + " : " + runs_ids[run_id][1] + # ] + + plt.legend(legend_str) + plt.grid(True, which="both", ls="-") + plt.yscale("log") + # cap at 1.0 in case of divergence of run (through normalziation, max should be around 1.0) + # plt.ylim([0.95 * min_val, (None if max_val < 2.0 else min(1.1, 1.025 * max_val))]) + if x_scale_log: + plt.xscale("log") + plt.title("average loss") + plt.ylabel("loss") + plt.xlabel("step") + plt.tight_layout() + rstr = "".join([f"{r}_" for r in runs_ids]) + + plt_fname = plot_dir / f"{rstr}avg.png" + _logger.info(f"Saving avg plot to '{plt_fname}'") + plt.savefig(plt_fname) + plt.close() + + #################################################################################################### def plot_loss_per_stream( modes: list[str], @@ -357,7 +402,7 @@ def plot_loss_per_stream( """ if errs is None: - errs = ["loss_mse"] + errs = ["mse"] modes = [modes] if type(modes) is not list else modes # repeat colors when train and val is plotted simultaneously @@ -688,6 +733,9 @@ def plot_train(args=None): # plot learning rate plot_lr(runs_ids, runs_data, runs_active, plot_dir=out_dir) + # plot average loss + plot_loss_avg(out_dir, runs_ids, runs_data) + # # plot performance # plot_utilization(runs_ids, runs_data, runs_active, plot_dir=out_dir) @@ -746,4 +794,9 @@ def plot_train(args=None): if __name__ == "__main__": args = sys.argv[1:] # get CLI args - plot_train(args) + try: + plot_train(args) + except Exception: + extype, value, tb = sys.exc_info() + traceback.print_exc() + pdb.post_mortem(tb) diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 2812134a0..f2313baa4 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -16,14 +16,15 @@ from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import Literal import numpy as np import polars as pl import torch import weathergen.common.config as config -from weathergen.train.utils import flatten_dict + +# from weathergen.train.trainer import cfg_keys_to_filter +from weathergen.train.utils import Stage, cfg_keys_to_filter, flatten_dict, get_active_stage_config from weathergen.utils.distributed import ddp_average from weathergen.utils.metrics import get_train_metrics_path, read_metrics_file @@ -35,13 +36,8 @@ _logger = logging.getLogger(__name__) -Stage = Literal["train", "val"] RunId = str -# All the stages currently implemented: -TRAIN: Stage = "train" -VAL: Stage = "val" - @dataclass class Metrics: @@ -91,7 +87,7 @@ def log_metrics(self, stage: Stage, metrics: dict[str, float]) -> None: # but we can probably do better and rely for example on the logging module. metrics_path = get_train_metrics_path( - base_path=config.get_path_run(self.cf).parent, run_id=self.cf.general.run_id + base_path=config.get_path_run(self.cf), run_id=self.cf.general.run_id ) with open(metrics_path, "ab") as f: s = json.dumps(clean_metrics) + "\n" @@ -131,10 +127,11 @@ def add_logs( ####################################### @staticmethod - def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: + def read(run_id: str, model_path: str = None, mini_epoch: int | None = None) -> Metrics: """ Read data for run_id """ + # Load config from given model_path if provided, otherwise use path from private config if model_path: cf = config.load_run_config(run_id=run_id, mini_epoch=mini_epoch, model_path=model_path) @@ -148,28 +145,15 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: result_dir = result_dir_base / run_id fname_log_train = result_dir / f"{run_id}_train_log.txt" fname_log_val = result_dir / f"{run_id}_val_log.txt" - fname_perf_val = result_dir / f"{run_id}_perf_log.txt" # training # define cols for training - cols_train = ["dtime", "samples", "mse", "lr"] - cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"] - for si in cf.streams: - for lf in cf.loss_fcts: - cols1 += [_key_loss(si["name"], lf[0])] - cols_train += [ - si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0] - ] - with_stddev = [("stats" in lf) for lf in cf.loss_fcts] - if with_stddev: - for si in cf.streams: - cols1 += [_key_stddev(si["name"])] - cols_train += [ - si["name"].replace(",", "").replace("/", "_").replace(" ", "_") - + ", " - + "stddev" - ] + training_cfg = get_active_stage_config(cf.training_config, {}, cfg_keys_to_filter) + cols1, cols_train = get_loss_terms_per_stream(cf.streams, training_cfg) + cols_train += ["dtime", "samples", "mse", "lr"] + cols1 += [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"] + # read training log data try: with open(fname_log_train, "rb") as f: @@ -211,23 +195,13 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: # validation # define cols for validation + validation_cfg = get_active_stage_config( + training_cfg, cf.get("validation_config", {}), cfg_keys_to_filter + ) + cols2, cols_val = get_loss_terms_per_stream(cf.streams, validation_cfg) cols_val = ["dtime", "samples"] cols2 = [_weathergen_timestamp, "num_samples"] - for si in cf.streams: - for lf in cf.loss_fcts_val: - cols_val += [ - si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0] - ] - cols2 += [_key_loss(si["name"], lf[0])] - with_stddev = [("stats" in lf) for lf in cf.loss_fcts_val] - if with_stddev: - for si in cf.streams: - cols2 += [_key_stddev(si["name"])] - cols_val += [ - si["name"].replace(",", "").replace("/", "_").replace(" ", "_") - + ", " - + "stddev" - ] + # read validation log data try: with open(fname_log_val, "rb") as f: @@ -266,54 +240,7 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: log_val = np.array([]) metrics_val_df = read_metrics(cf, run_id, "val", cols2, result_dir_base) - # performance - # define cols for performance monitoring - cols_perf = ["GPU", "memory"] - # read perf log data - try: - with open(fname_perf_val, "rb") as f: - log_perf = np.loadtxt(f, delimiter=",") - log_perf = log_perf.reshape((log_perf.shape[0] // len(cols_perf), len(cols_perf))) - except ( - TypeError, - AttributeError, - IndexError, - ZeroDivisionError, - ValueError, - ) as e: - _logger.warning( - ( - f"Warning: no validation data loaded for run_id={run_id}", - "Data loading or reshaping failed — " - "possible format, dimension, or logic issue.", - f"Due to specific error: {e}", - ) - ) - except (FileNotFoundError, PermissionError, OSError) as e: - _logger.error( - ( - f"Error: no validation data loaded for run_id={run_id}", - "File system error occurred while handling the log file.", - f"Due to specific error: {e}", - ) - ) - except Exception: - _logger.error( - ( - f"Error: no validation data loaded for run_id={run_id}", - f"Due to exception with trace:\n{traceback.format_exc()}", - ) - ) - log_perf = np.array([]) - metrics_system_df = read_metrics( - cf, - run_id, - None, - [_weathergen_timestamp, _performance_gpu, _performance_memory], - result_dir_base, - ) - - return Metrics(run_id, "train", log_train_df, metrics_val_df, metrics_system_df) + return Metrics(run_id, "train", log_train_df, metrics_val_df, None) def read_metrics( @@ -391,9 +318,27 @@ def clean_name(s: str) -> str: return "".join(c for c in s if c.isalnum() or c == "_") +def get_loss_terms_per_stream(streams, stage_config): + """ + Extract per stream loss terms + """ + cols, cols_stage = [], [] + for si in streams: + for _, loss_config in stage_config.get("losses", {}).items(): + if loss_config.get("type", "LossPhysical") == "LossPhysical": + for lname, _ in loss_config.loss_fcts.items(): + cols += [_key_loss(si["name"], lname)] + cols_stage += [_clean_stream_name(si["name"]) + lname] + return cols, cols_stage + + +def _clean_stream_name(stream_name: str) -> str: + return stream_name.replace(",", "").replace("/", "_").replace(" ", "_") + ", " + + def _key_loss(st_name: str, lf_name: str) -> str: st_name = clean_name(st_name) - return f"stream.{st_name}.loss_{lf_name}.loss_avg" + return f"LossPhysical.{st_name}.{lf_name}.avg" def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str: diff --git a/src/weathergen/utils/utils.py b/src/weathergen/utils/utils.py index aee807341..291ab1521 100644 --- a/src/weathergen/utils/utils.py +++ b/src/weathergen/utils/utils.py @@ -10,7 +10,7 @@ import torch -from weathergen.utils.train_logger import TRAIN, Stage +from weathergen.train.utils import TRAIN, Stage def get_dtype(value: str) -> torch.dtype: From b5b4a44b5e435029de2e2ba903b295220dd3ef9d Mon Sep 17 00:00:00 2001 From: jesicapinon Date: Mon, 16 Feb 2026 11:39:31 +0100 Subject: [PATCH 35/38] nse_metric (#1833) * nse_metric * length --------- Co-authored-by: Jesica Pinyon Rodriguez --- .../src/weathergen/evaluate/scores/score.py | 29 +++++++++++++++++++ src/weathergen/utils/cli.py | 2 +- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index ab322ab28..c5d4cf8d0 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -190,6 +190,7 @@ def __init__( "grad_amplitude": self.calc_spatial_variability, "psnr": self.calc_psnr, "seeps": self.calc_seeps, + "nse": self.calc_nse, } self.prob_metrics_dict = { "ssr": self.calc_ssr, @@ -1199,6 +1200,34 @@ def seeps(ground_truth, prediction, thr_light, thr_heavy, seeps_weights): return seeps_values + def calc_nse(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: + """ + Calculate Nash–Sutcliffe_model_efficiency_coefficient (NSE) + of forecast data vs reference data + Metrics broadly used in hydrology + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + Returns + ------- + xr.DataArray + Nash–Sutcliffe_model_efficiency_coefficient (NSE) + + """ + + obs_mean = gt.mean(dim=self._agg_dims) + + num = ((gt - p) ** 2).sum(dim=self._agg_dims) + + den = ((gt - obs_mean) ** 2).sum(dim=self._agg_dims) + + nse = 1 - num / den + + return nse + ### Probablistic scores def calc_spread(self, p: xr.DataArray, **kwargs) -> xr.DataArray: diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py index 1c7cba6a8..2bd9fe2a2 100644 --- a/src/weathergen/utils/cli.py +++ b/src/weathergen/utils/cli.py @@ -14,7 +14,7 @@ class Stage(enum.StrEnum): def get_main_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(allow_abbrev=False) subparsers = parser.add_subparsers(dest="stage") - + train_parser = subparsers.add_parser( Stage.train, help="Train a WeatherGenerator configuration from the ground up.", From f0d4a067ee58f6fb7a4fbb1dcfce53e4be563bce Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Tue, 17 Feb 2026 09:06:07 +0000 Subject: [PATCH 36/38] Random encoder fientuning on NPPATMS --- config/config_jepa_finetuning.yml | 2 +- .../streams/era5_nppatms_finetuning/era5.yml | 40 +++++++++++++++++++ .../era5_nppatms_finetuning/npp_atms.yml | 33 +++++++++++++++ .../datasets/multi_stream_data_sampler.py | 4 +- 4 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 config/streams/era5_nppatms_finetuning/era5.yml create mode 100644 config/streams/era5_nppatms_finetuning/npp_atms.yml diff --git a/config/config_jepa_finetuning.yml b/config/config_jepa_finetuning.yml index b4cd7aefa..92090d5df 100644 --- a/config/config_jepa_finetuning.yml +++ b/config/config_jepa_finetuning.yml @@ -140,7 +140,7 @@ training_config: samples_per_mini_epoch: 4096 shuffle: True - start_date: 1979-01-01T00:00 + start_date: 2012-01-01T00:00 end_date: 2022-12-31T00:00 time_window_step: 06:00:00 diff --git a/config/streams/era5_nppatms_finetuning/era5.yml b/config/streams/era5_nppatms_finetuning/era5.yml new file mode 100644 index 000000000..45d7ddf9c --- /dev/null +++ b/config/streams/era5_nppatms_finetuning/era5.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + forcing: True + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 + diff --git a/config/streams/era5_nppatms_finetuning/npp_atms.yml b/config/streams/era5_nppatms_finetuning/npp_atms.yml new file mode 100644 index 000000000..159cad38e --- /dev/null +++ b/config/streams/era5_nppatms_finetuning/npp_atms.yml @@ -0,0 +1,33 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +NPPATMS : + type : obs + stream_id : 1 + filenames : ['observations-ea-ofb-0001-2012-2023-npp-atms-radiances-v2.zarr'] + loss_weight : 1.0 + diagnostic: True + token_size : 32 + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 128 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 128 + target_readout : + num_layers : 1 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 094032db4..5afc72d69 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -533,7 +533,7 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s rdata = collect_datasources(stream_ds, idx, "source", self.rng) - if rdata.is_empty() and self._stage == TRAIN: + if rdata.is_empty(): # and self._stage == TRAIN: # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor time_win = self.time_window_handler.window(idx) @@ -555,7 +555,7 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s rdata = collect_datasources(stream_ds, step_forecast_dt, "target", self.rng) - if rdata.is_empty() and self._stage == TRAIN: + if rdata.is_empty(): # and self._stage == TRAIN: # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor time_win = self.time_window_handler.window(timestep_idx) From 66394b56a698b31a464207119dac4363a6c4ce2a Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux <24638638+sophie-xhonneux@users.noreply.github.com> Date: Tue, 17 Feb 2026 18:35:57 +0100 Subject: [PATCH 37/38] Update configs to fix leak --- config/config_jepa_finetuning.yml | 2 +- config/default_config.yml | 20 +++++------ .../era5_nppatms_finetuning/npp_atms.yml | 33 ------------------- 3 files changed, 11 insertions(+), 44 deletions(-) delete mode 100644 config/streams/era5_nppatms_finetuning/npp_atms.yml diff --git a/config/config_jepa_finetuning.yml b/config/config_jepa_finetuning.yml index 92090d5df..b12d9f98f 100644 --- a/config/config_jepa_finetuning.yml +++ b/config/config_jepa_finetuning.yml @@ -141,7 +141,7 @@ training_config: shuffle: True start_date: 2012-01-01T00:00 - end_date: 2022-12-31T00:00 + end_date: 2021-12-31T00:00 time_window_step: 06:00:00 time_window_len: 06:00:00 diff --git a/config/default_config.yml b/config/default_config.yml index 7b7f35c4b..968ca79f8 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -25,8 +25,8 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 2048 -ae_global_num_blocks: 0 +ae_global_dim_embed: 1024 +ae_global_num_blocks: 2 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -65,7 +65,7 @@ fe_impute_latent_noise_std: 0.0 # 1e-4 forecast_att_dense_rate: 1.0 with_step_conditioning: True # False -healpix_level: 5 +healpix_level: 4 with_mixed_precision: True with_flash_attention: True @@ -162,7 +162,7 @@ training_config: shuffle: True start_date: 1979-01-01T00:00 - end_date: 2022-12-31T00:00 + end_date: 2021-12-31T00:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -209,8 +209,8 @@ training_config: weight: 1.0, loss_fcts : { "JEPA": { - 'weight': 4, "loss_extra_args": {}, "out_dim": 2048, "head": transformer, - "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768, + 'weight': 4, "loss_extra_args": {}, "out_dim": 1024, "head": transformer, + "num_blocks": 12, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 512, "dropout_rate": 0.1, target_source_correspondence: {0 : {0 : "subset"} }, }, @@ -239,7 +239,7 @@ training_config: masking_strategy_config : { diffusion_rn : False, rate : 0.6, - hl_mask: 4, + hl_mask: 2, rate_sampling: True }, }, @@ -265,8 +265,8 @@ validation_config: samples_per_mini_epoch: 256 shuffle: False - start_date: 2023-10-01T00:00 - end_date: 2023-12-31T00:00 + start_date: 2022-10-01T00:00 + end_date: 2022-12-31T00:00 # whether to track the exponential moving average of weights for validation validate_with_ema: @@ -312,7 +312,7 @@ wgtags: # issue number. # Expected values are lowercase strings with no spaces, just underscores: # Examples: "rollout_ablation_grid" - exp: null + exp: jepa # *** Experiment-specific tags *** # All extra tags (including lists, dictionaries, etc.) are treated # as strings by mlflow, so treat all extra tags as simple string key: value pairs. diff --git a/config/streams/era5_nppatms_finetuning/npp_atms.yml b/config/streams/era5_nppatms_finetuning/npp_atms.yml deleted file mode 100644 index 159cad38e..000000000 --- a/config/streams/era5_nppatms_finetuning/npp_atms.yml +++ /dev/null @@ -1,33 +0,0 @@ -# (C) Copyright 2024 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -NPPATMS : - type : obs - stream_id : 1 - filenames : ['observations-ea-ofb-0001-2012-2023-npp-atms-radiances-v2.zarr'] - loss_weight : 1.0 - diagnostic: True - token_size : 32 - max_num_targets: -1 - embed : - net : transformer - num_tokens : 1 - num_heads : 2 - dim_embed : 128 - num_blocks : 2 - embed_target_coords : - net : linear - dim_embed : 128 - target_readout : - num_layers : 1 - num_heads : 4 - pred_head : - ens_size : 1 - num_layers : 1 - From 0884ae6a570acb3a76d67b4d177fb8c0335c8cf5 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux <24638638+sophie-xhonneux@users.noreply.github.com> Date: Wed, 18 Feb 2026 14:25:50 +0100 Subject: [PATCH 38/38] Plotting config --- config/default_config.yml | 30 ++++++--- config/eval_nppatms.yml | 67 +++++++++++++++++++ .../era5_nppatms_finetuning/nppatms.yml | 28 ++++++++ 3 files changed, 116 insertions(+), 9 deletions(-) create mode 100644 config/eval_nppatms.yml create mode 100644 config/streams/era5_nppatms_finetuning/nppatms.yml diff --git a/config/default_config.yml b/config/default_config.yml index 968ca79f8..1562b1844 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,7 +11,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 1024 +ae_local_dim_embed: 512 ae_local_num_blocks: 2 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -25,9 +25,9 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 1024 +ae_global_dim_embed: 512 ae_global_num_blocks: 2 -ae_global_num_heads: 32 +ae_global_num_heads: 16 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True # TODO: switching to < 1 triggers triton-related issues. @@ -38,7 +38,7 @@ ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False ae_aggregation_num_blocks: 8 -ae_aggregation_num_heads: 32 +ae_aggregation_num_heads: 16 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True ae_aggregation_att_dense_rate: 1.0 @@ -132,7 +132,7 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["student_teacher"] + training_mode: ["masking", "student_teacher"] # Collapse monitoring for SSL training (JEPA/DINO/iBOT) # Detects representation collapse via various metrics @@ -171,7 +171,7 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 - lr_max: 1e-4 + lr_max: 1e-5 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 4096 @@ -203,14 +203,26 @@ training_config: weight_decay: 0.05 losses : { + # "physical": { + # enabled: False, + # type: LossPhysical, + # weight: 0.1, + # loss_fcts: { + # "mse": { + # weight: 1.0, + # target_source_correspondence: { 0 : { 0 : "subset"} }, + # }, + # }, + # target_and_aux_calc: "Physical", + # }, "student-teacher": { enabled: True, type: LossLatentSSLStudentTeacher, weight: 1.0, loss_fcts : { "JEPA": { - 'weight': 4, "loss_extra_args": {}, "out_dim": 1024, "head": transformer, - "num_blocks": 12, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 512, + 'weight': 4, "loss_extra_args": {}, "out_dim": 512, "head": transformer, + "num_blocks": 12, "num_heads": 16, "with_qk_lnorm": True, "intermediate_dim": 256, "dropout_rate": 0.1, target_source_correspondence: {0 : {0 : "subset"} }, }, @@ -249,7 +261,7 @@ training_config: "random_easy_target" : { masking_strategy: "healpix", num_samples: 1, - masking_strategy_config : { rate : 0.66, hl_mask: 4, rate_sampling: True}, + masking_strategy_config : { rate : 0.66, hl_mask: 3, rate_sampling: True}, }, } diff --git a/config/eval_nppatms.yml b/config/eval_nppatms.yml new file mode 100644 index 000000000..75e279688 --- /dev/null +++ b/config/eval_nppatms.yml @@ -0,0 +1,67 @@ +#optional: if commented out all is taken care of by the default settings +# NB. global options apply to all run_ids +#global_plotting_options: +# region: ["belgium", "global"] +# image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. +# dpi_val : 300 +# fps: 2 +# ERA5: +# marker_size: 2 +# scale_marker_size: 1 +# marker: "o" +# # alpha: 0.5 +# 2t: +# vmin: 250 +# vmax: 300 +# 10u: +# vmin: -40 +# vmax: 40 + +evaluation: + metrics : ["rmse", "mae"] + regions: ["global"] + summary_plots : true + ratio_plots : false + heat_maps : false + summary_dir: "./plots/" + plot_ensemble: "members" #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: false + score_cards: false + bar_plots: false + num_processes: 0 #options: int, "auto", 0 means no parallelism (default) + # baseline: "ar40mckx" + + +default_streams: + NPPATMS: + channels: ["obsvalue_rawbt_1", "obsvalue_rawbt_2", "obsvalue_rawbt_3", "obsvalue_rawbt_4", "obsvalue_rawbt_5", "obsvalue_rawbt_6", "obsvalue_rawbt_10", "obsvalue_rawbt_20"] #, "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + evaluation: + forecast_step: "all" + sample: [1, 2, 3, 4, 5, 6, 7] + ensemble: "all" #supported: "all", "mean", [0,1,2] + plotting: + sample: [1, 2, 3, 4, 5, 6, 7] + forecast_step: "all" #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + ensemble: "all" #supported: "all", "mean", [0,1,2] + plot_maps: true + plot_target: true + plot_histograms: true + plot_animations: false + +run_ids : + # diagnostic NPPATMS model, spm6zeor, continued training of pretrained model k8smwg67 + v9ntzbh2: + label: "diagnostic NPPATMS model spm6zeor, cont. k8smwg67" + results_base_dir : "./results/v9ntzbh2/" + # prognostic NPPATMS model + #i74cu321: + # label: "pretrained model i74cu321" + # results_base_dir : "./results/i74cu321/" + # here below we have one without --options test_config.output.normalized_samples=False + # us3wofcj: + # label: "pretrained model us3wofcj" + # results_base_dir : "./results/us3wofcj/" + #NEW: if "streams" is not specified, the default streams are used diff --git a/config/streams/era5_nppatms_finetuning/nppatms.yml b/config/streams/era5_nppatms_finetuning/nppatms.yml new file mode 100644 index 000000000..a1ff0552b --- /dev/null +++ b/config/streams/era5_nppatms_finetuning/nppatms.yml @@ -0,0 +1,28 @@ +# obs_types +# 0 : polar orbiting satellites +# 1 : geostationay satellites +# 2 : conventional observations + +NPPATMS : + type : obs + stream_id : 1 + filenames : ['observations-ea-ofb-0001-2012-2023-npp-atms-radiances-v2.zarr'] + loss_weight : 1.0 + token_size : 32 + diagnostic: True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 128 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 128 + target_readout : + num_layers : 1 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1