From 71d2cce021d86e2763e595e91b2c3243abdc3a6e Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 19:53:44 +0100 Subject: [PATCH 01/10] Add collapse monitoring --- config/config_jepa.yml | 26 +- src/weathergen/train/collapse_monitor.py | 356 +++++++++++++++ .../train/target_and_aux_ssl_teacher.py | 20 + src/weathergen/train/trainer.py | 96 ++++ tests/test_collapse_monitor.py | 410 ++++++++++++++++++ 5 files changed, 907 insertions(+), 1 deletion(-) create mode 100644 src/weathergen/train/collapse_monitor.py create mode 100644 tests/test_collapse_monitor.py diff --git a/config/config_jepa.yml b/config/config_jepa.yml index fc27da8c9..f10f16445 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -130,10 +130,34 @@ data_loading : # config for training training_config: - + # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["student_teacher"] + # Collapse monitoring for SSL training (JEPA/DINO/iBOT) + # Detects representation collapse via various metrics + collapse_monitoring: + enabled: true + compute_frequency: 100 # batches between metric computations + log_frequency: 100 # batches between metric logging + metrics: + effective_rank: + enabled: true + tensor_source: "both" # "student", "teacher", or "both" + sample_size: 2048 # max samples for SVD (0 = no sampling) + singular_values: + enabled: true + top_k: 10 + tensor_source: "both" + sample_size: 2048 + dimension_variance: + enabled: true + tensor_source: "both" # cheap to compute, good early indicator + prototype_entropy: + enabled: true # only applies to DINO + ema_beta: + enabled: true + num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py new file mode 100644 index 000000000..b739908e9 --- /dev/null +++ b/src/weathergen/train/collapse_monitor.py @@ -0,0 +1,356 @@ +# (C) Copyright 2025 WeatherGenerator contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Collapse monitoring metrics for SSL training (JEPA/DINO). + +This module implements metrics to detect representation collapse during self-supervised learning: +- Effective Rank (RankMe): Entropy of normalized singular values +- Singular Value Spectrum: Top-k singular values and concentration ratio +- Per-Dimension Variance: Min/mean/max variance across embedding dimensions +- Prototype Entropy: Normalized entropy of DINO prototype assignments +- EMA Beta: Current teacher momentum value + +References: +- RankMe (ICML 2023): https://arxiv.org/abs/2210.02885 +- C-JEPA (NeurIPS 2024): https://arxiv.org/abs/2410.19560 +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from typing import Any + +import torch + +logger = logging.getLogger(__name__) + + +class CollapseMonitor: + """ + Monitor for detecting representation collapse during SSL training. + + Computes and caches various collapse indicators that can be logged + at configurable intervals to minimize computational overhead. + """ + + def __init__(self, config: dict[str, Any], device: torch.device) -> None: + """ + Initialize the collapse monitor. + + Args: + config: Configuration dictionary with collapse_monitoring settings. + device: Device to use for computations. 
+ """ + self.device = device + self.enabled = config.get("enabled", False) + self.compute_frequency = config.get("compute_frequency", 100) + self.log_frequency = config.get("log_frequency", 100) + + # Metric configurations + metrics_config = config.get("metrics", {}) + + self.effective_rank_config = metrics_config.get("effective_rank", {}) + self.singular_values_config = metrics_config.get("singular_values", {}) + self.dimension_variance_config = metrics_config.get("dimension_variance", {}) + self.prototype_entropy_config = metrics_config.get("prototype_entropy", {}) + self.ema_beta_config = metrics_config.get("ema_beta", {}) + + # Cache for accumulating metrics between log intervals + self._metrics_cache: dict[str, list[float]] = defaultdict(list) + + def should_compute(self, step: int) -> bool: + """Check if metrics should be computed at this step.""" + return self.enabled and step % self.compute_frequency == 0 + + def should_log(self, step: int) -> bool: + """Check if metrics should be logged at this step.""" + return self.enabled and step % self.log_frequency == 0 + + def compute_metrics( + self, + student_latent: torch.Tensor | None = None, + teacher_latent: torch.Tensor | None = None, + prototype_probs: torch.Tensor | None = None, + ema_beta: float | None = None, + loss_type: str | None = None, + ) -> dict[str, float]: + """ + Compute all enabled collapse monitoring metrics. + + Args: + student_latent: Student model latent representations [B, N, D] or [B, D]. + teacher_latent: Teacher model latent representations [B, N, D] or [B, D]. + prototype_probs: Post-softmax prototype assignment probabilities [B, K] (DINO only). + ema_beta: Current EMA momentum value. + loss_type: Type of SSL loss ("JEPA" or "DINO"). + + Returns: + Dictionary of computed metrics. 
+ """ + if not self.enabled: + return {} + + metrics: dict[str, float] = {} + + # Determine which tensors to monitor based on config + tensors_to_monitor: dict[str, torch.Tensor | None] = {} + + effective_rank_source = self.effective_rank_config.get("tensor_source", "both") + sv_source = self.singular_values_config.get("tensor_source", "both") + var_source = self.dimension_variance_config.get("tensor_source", "both") + + # Build tensor dict based on what's requested + if effective_rank_source in ("student", "both") or sv_source in ( + "student", + "both", + ) or var_source in ("student", "both"): + tensors_to_monitor["student"] = student_latent + + if effective_rank_source in ("teacher", "both") or sv_source in ( + "teacher", + "both", + ) or var_source in ("teacher", "both"): + tensors_to_monitor["teacher"] = teacher_latent + + # Compute effective rank + if self.effective_rank_config.get("enabled", True): + sample_size = self.effective_rank_config.get("sample_size", 2048) + for name, tensor in tensors_to_monitor.items(): + if tensor is not None: + source = self.effective_rank_config.get("tensor_source", "both") + if source == "both" or source == name: + eff_rank = self._compute_effective_rank(tensor, sample_size) + metrics[f"collapse.{name}.effective_rank"] = eff_rank + + # Compute singular value spectrum + if self.singular_values_config.get("enabled", True): + top_k = self.singular_values_config.get("top_k", 10) + sample_size = self.singular_values_config.get("sample_size", 2048) + for name, tensor in tensors_to_monitor.items(): + if tensor is not None: + source = self.singular_values_config.get("tensor_source", "both") + if source == "both" or source == name: + sv_metrics = self._compute_singular_values(tensor, top_k, sample_size) + for key, value in sv_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value + + # Compute per-dimension variance + if self.dimension_variance_config.get("enabled", True): + for name, tensor in tensors_to_monitor.items(): + if 
tensor is not None: + source = self.dimension_variance_config.get("tensor_source", "both") + if source == "both" or source == name: + var_metrics = self._compute_dimension_variance(tensor) + for key, value in var_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value + + # Compute prototype entropy (DINO only) + if ( + self.prototype_entropy_config.get("enabled", True) + and prototype_probs is not None + and loss_type == "DINO" + ): + entropy = self._compute_prototype_entropy(prototype_probs) + metrics["collapse.dino.prototype_entropy"] = entropy + + # Log EMA beta + if self.ema_beta_config.get("enabled", True) and ema_beta is not None: + metrics["collapse.ema_beta"] = ema_beta + + # Cache metrics for averaging + for key, value in metrics.items(): + self._metrics_cache[key].append(value) + + return metrics + + def get_cached_metrics(self) -> dict[str, float]: + """ + Get averaged cached metrics and clear the cache. + + Returns: + Dictionary of averaged metrics since last call. + """ + averaged_metrics: dict[str, float] = {} + for key, values in self._metrics_cache.items(): + if values: + averaged_metrics[key] = sum(values) / len(values) + + self._metrics_cache.clear() + return averaged_metrics + + def _flatten_to_samples(self, z: torch.Tensor) -> torch.Tensor: + """ + Flatten patch dimension into sample dimension. + + Treats [B, N, D] as [B*N, D] where each patch is an independent sample. + This is consistent with C-JEPA/VICReg approach. + + Args: + z: Tensor of shape [B, N, D] or [B, D]. + + Returns: + Tensor of shape [B*N, D] or [B, D]. + """ + if z.ndim == 3: + return z.reshape(-1, z.shape[-1]) + return z + + def _sample_rows(self, z: torch.Tensor, sample_size: int) -> torch.Tensor: + """ + Randomly sample rows to reduce SVD computation cost. + + Args: + z: Tensor of shape [N, D]. + sample_size: Maximum number of samples (0 = no sampling). + + Returns: + Sampled tensor of shape [min(N, sample_size), D]. 
+ """ + if sample_size <= 0 or z.shape[0] <= sample_size: + return z + + indices = torch.randperm(z.shape[0], device=z.device)[:sample_size] + return z[indices] + + def _compute_effective_rank(self, z: torch.Tensor, sample_size: int = 2048) -> float: + """ + Compute effective rank via entropy of normalized singular values (RankMe). + + The effective rank measures how many dimensions are actually being used + in the representation. A low effective rank indicates collapse. + + Args: + z: Latent representations [B, N, D] or [B, D]. + sample_size: Maximum samples for SVD computation. + + Returns: + Effective rank (exp of entropy of normalized singular values). + """ + z = self._flatten_to_samples(z.detach()) + z = self._sample_rows(z, sample_size) + + # Center the data + z_centered = z - z.mean(dim=0, keepdim=True) + + # Compute SVD + try: + _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) + except RuntimeError: + # SVD can fail on degenerate matrices + logger.warning("SVD failed in effective rank computation") + return 0.0 + + # Normalize singular values to get a probability distribution + s_normalized = s / (s.sum() + 1e-8) + + # Compute entropy + entropy = -torch.sum(s_normalized * torch.log(s_normalized + 1e-8)) + + # Effective rank is exp(entropy) + effective_rank = torch.exp(entropy) + + return effective_rank.item() + + def _compute_singular_values( + self, z: torch.Tensor, top_k: int = 10, sample_size: int = 2048 + ) -> dict[str, float]: + """ + Compute top-k singular values and concentration ratio. + + The concentration ratio (top SV / sum of all SVs) indicates how much + variance is captured by the largest singular value. High concentration + suggests dimensional collapse. + + Args: + z: Latent representations [B, N, D] or [B, D]. + top_k: Number of top singular values to return. + sample_size: Maximum samples for SVD computation. + + Returns: + Dictionary with top-k singular values and concentration ratio. 
+ """ + z = self._flatten_to_samples(z.detach()) + z = self._sample_rows(z, sample_size) + + # Center the data + z_centered = z - z.mean(dim=0, keepdim=True) + + # Compute SVD + try: + _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) + except RuntimeError: + logger.warning("SVD failed in singular value computation") + return {} + + metrics: dict[str, float] = {} + + # Top-k singular values + for i in range(min(top_k, len(s))): + metrics[f"singular_value_{i}"] = s[i].item() + + # Concentration ratio (top SV / sum) + s_sum = s.sum() + 1e-8 + metrics["sv_concentration"] = (s[0] / s_sum).item() + + return metrics + + def _compute_dimension_variance(self, z: torch.Tensor) -> dict[str, float]: + """ + Compute per-dimension variance statistics. + + Low minimum variance indicates "dead" dimensions that are not being used. + Large variance ratio (max/min) suggests imbalanced dimension usage. + + Args: + z: Latent representations [B, N, D] or [B, D]. + + Returns: + Dictionary with var_min, var_mean, var_max. + """ + z = self._flatten_to_samples(z.detach()) + + # Compute variance along sample dimension + var_per_dim = z.var(dim=0) + + return { + "var_min": var_per_dim.min().item(), + "var_mean": var_per_dim.mean().item(), + "var_max": var_per_dim.max().item(), + } + + def _compute_prototype_entropy(self, probs: torch.Tensor) -> float: + """ + Compute normalized entropy of DINO prototype assignments. + + Low entropy indicates collapse to few prototypes. Entropy is normalized + to [0, 1] range where 1 means uniform distribution. + + Args: + probs: Post-softmax prototype assignment probabilities [B, K]. + + Returns: + Normalized entropy in [0, 1]. 
+ """ + probs = probs.detach() + + # Average across batch to get prototype usage distribution + avg_probs = probs.mean(dim=0) + + # Compute entropy + entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-8)) + + # Normalize by maximum possible entropy (uniform distribution) + num_prototypes = probs.shape[1] + max_entropy = torch.log(torch.tensor(float(num_prototypes), device=probs.device)) + + normalized_entropy = entropy / (max_entropy + 1e-8) + + return normalized_entropy.item() diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index cfb252f86..99b1f2860 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -77,6 +77,26 @@ def to_device(self, device) -> EMATeacher: module.to(device) return self + def get_current_beta(self, cur_step: int) -> float: + """ + Get current EMA beta value for monitoring. + + The beta value determines how much the teacher model is updated towards + the student model at each step. Higher beta means slower teacher updates. + + Args: + cur_step: Current training step (typically istep * batch_size). + + Returns: + Current EMA beta value. 
+ """ + halflife_steps = self.ema_model.halflife_steps + rampup_ratio = self.ema_model.rampup_ratio + if rampup_ratio is not None: + halflife_steps = min(halflife_steps, cur_step / 1e3 * rampup_ratio) + beta = 0.5 ** (self.batch_size / max(halflife_steps * 1e3, 1e-6)) + return beta + def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): return_dict = {} diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index e949dc1cc..163ce53a8 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -29,8 +29,10 @@ init_model_and_shard, ) from weathergen.model.utils import apply_fct_to_blocks, set_to_eval +from weathergen.train.collapse_monitor import CollapseMonitor from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler +from weathergen.train.target_and_aux_ssl_teacher import EMATeacher from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( extract_batch_metadata, @@ -74,6 +76,7 @@ def __init__(self, train_log_freq: Config): self.batch_size_per_gpu = -1 self.batch_size_validation_per_gpu = -1 self.batch_size_test_per_gpu = -1 + self.collapse_monitor: CollapseMonitor | None = None def get_batch_size_total(self, batch_size_per_gpu) -> int: """ @@ -146,6 +149,10 @@ def init(self, cf: Config, devices): self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) + # Initialize collapse monitor for SSL training + collapse_config = self.training_cfg.get("collapse_monitoring", {}) + self.collapse_monitor = CollapseMonitor(collapse_config, None) # device set later in run() + def get_target_aux_calculators(self, mode_cfg): """ Get target_aux_calculators for given mode_cfg @@ -227,6 +234,9 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): device_type = torch.accelerator.current_accelerator() self.device = torch.device(f"{device_type}:{cf.local_rank}") + # Update collapse 
monitor device + self.collapse_monitor.device = self.device + # create data loaders self.dataset = MultiStreamDataSampler(cf, self.training_cfg, stage=TRAIN) self.dataset_val = MultiStreamDataSampler(cf, self.validation_cfg, stage=VAL) @@ -501,9 +511,16 @@ def train(self, mini_epoch): if self.validate_with_ema: self.ema_model.update(self.cf.general.istep * batch_size_total, batch_size_total) + # Compute collapse monitoring metrics + if self.collapse_monitor.should_compute(self.cf.general.istep): + self._compute_collapse_metrics(preds, targets_and_auxs) + self._log_terminal(bidx, mini_epoch, TRAIN) if bidx % self.train_log_freq.metrics == 0: self._log(TRAIN) + # Log collapse metrics + if self.collapse_monitor.should_log(self.cf.general.istep): + self._log_collapse_metrics(TRAIN) # save model checkpoint (with designation _latest) if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: @@ -775,3 +792,82 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): logger.info("\n") self.t_start = time.time() + + def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: + """ + Extract latent tensors from predictions and targets, then compute collapse metrics. + + This method extracts the student and teacher latent representations from the + SSL training outputs and passes them to the collapse monitor. 
+ """ + # Get student latents from predictions (first forecast step) + student_latent = None + teacher_latent = None + prototype_probs = None + ema_beta = None + loss_type = None + + # Find SSL loss type and extract latents + for _loss_name, target_aux in targets_and_auxs.items(): + # Check if this is an EMATeacher-based loss + if hasattr(target_aux, "latent") and target_aux.latent: + # Get the first timestep's latent dict + target_latent_dict = target_aux.latent[0] if target_aux.latent else {} + + # Determine the SSL loss type (JEPA, DINO, iBOT) + for ssl_type in ["JEPA", "DINO", "iBOT"]: + if ssl_type in target_latent_dict: + loss_type = ssl_type + # Get teacher latent + teacher_latent_data = target_latent_dict[ssl_type] + if isinstance(teacher_latent_data, list) and len(teacher_latent_data) > 0: + teacher_latent = teacher_latent_data[0] + elif isinstance(teacher_latent_data, dict): + # Handle LatentState or dict + teacher_latent = teacher_latent_data.get( + "latent", teacher_latent_data + ) + else: + teacher_latent = teacher_latent_data + break + + # Get student latents from predictions + if preds.latent and len(preds.latent) > 0: + pred_latent_dict = preds.latent[0] + for ssl_type in ["JEPA", "DINO", "iBOT"]: + if ssl_type in pred_latent_dict: + student_latent_data = pred_latent_dict[ssl_type] + if isinstance(student_latent_data, list) and len(student_latent_data) > 0: + student_latent = student_latent_data[0] + elif isinstance(student_latent_data, dict): + student_latent = student_latent_data.get("latent", student_latent_data) + else: + student_latent = student_latent_data + loss_type = ssl_type + break + + # Get EMA beta from target_and_aux_calculators + for _calc_name, calculator in self.target_and_aux_calculators.items(): + if isinstance(calculator, EMATeacher): + batch_size_total = self.get_batch_size_total(self.batch_size_per_gpu) + step = batch_size_total * self.cf.general.istep + ema_beta = calculator.get_current_beta(step) + break + + # Ensure tensors 
are properly formatted + if student_latent is not None and isinstance(student_latent, torch.Tensor): + self.collapse_monitor.compute_metrics( + student_latent=student_latent, + teacher_latent=teacher_latent if isinstance(teacher_latent, torch.Tensor) else None, + prototype_probs=prototype_probs, + ema_beta=ema_beta, + loss_type=loss_type, + ) + + def _log_collapse_metrics(self, stage: Stage) -> None: + """ + Log cached collapse monitoring metrics. + """ + metrics = self.collapse_monitor.get_cached_metrics() + if metrics and is_root(): + self.train_logger.log_metrics(stage, metrics) diff --git a/tests/test_collapse_monitor.py b/tests/test_collapse_monitor.py new file mode 100644 index 000000000..5656205f9 --- /dev/null +++ b/tests/test_collapse_monitor.py @@ -0,0 +1,410 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +"""Unit tests for collapse monitoring metrics.""" + +import pytest +import torch + +from weathergen.train.collapse_monitor import CollapseMonitor + + +@pytest.fixture +def default_config(): + """Default enabled config for collapse monitoring.""" + return { + "enabled": True, + "compute_frequency": 100, + "log_frequency": 100, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "both", + "sample_size": 2048, + }, + "singular_values": { + "enabled": True, + "top_k": 10, + "tensor_source": "both", + "sample_size": 2048, + }, + "dimension_variance": { + "enabled": True, + "tensor_source": "both", + }, + "prototype_entropy": { + "enabled": True, + }, + "ema_beta": { + "enabled": True, + }, + }, + } + + +@pytest.fixture +def monitor(default_config): + """Create a collapse monitor with default config.""" + device = torch.device("cpu") + return CollapseMonitor(default_config, device) + + +class TestCollapseMonitorInitialization: + """Test CollapseMonitor initialization.""" + + def test_disabled_monitor(self): + """Test that disabled monitor doesn't compute metrics.""" + config = {"enabled": False} + monitor = CollapseMonitor(config, torch.device("cpu")) + assert not monitor.enabled + assert not monitor.should_compute(100) + assert not monitor.should_log(100) + + def test_enabled_monitor(self, default_config): + """Test that enabled monitor computes at correct intervals.""" + monitor = CollapseMonitor(default_config, torch.device("cpu")) + assert monitor.enabled + assert monitor.should_compute(0) + assert monitor.should_compute(100) + assert not monitor.should_compute(50) + + def test_frequency_settings(self): + """Test custom frequency settings.""" + config = { + "enabled": True, + "compute_frequency": 50, + "log_frequency": 200, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + assert monitor.should_compute(50) + assert monitor.should_compute(100) # 100 is a multiple of 50 + assert not monitor.should_compute(75) # 75 is not a 
multiple of 50 + assert monitor.should_log(200) + assert not monitor.should_log(100) + + +class TestEffectiveRank: + """Test effective rank computation.""" + + def test_full_rank_matrix(self, monitor): + """Full rank random matrix should have effective rank close to min(N, D).""" + torch.manual_seed(42) + # Create a full-rank matrix with orthogonal columns + dim = 64 + num_samples = 128 + z = torch.randn(num_samples, dim) + # Make it more orthogonal via QR decomposition + q, _ = torch.linalg.qr(z.T) + z = q.T # Now z is [dim, dim] with orthogonal rows + z = torch.cat([z, torch.randn(num_samples - dim, dim)], dim=0) + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # For a full-rank matrix, effective rank should be significant portion of D + assert eff_rank > dim * 0.3, f"Expected effective rank > {dim * 0.3}, got {eff_rank}" + + def test_low_rank_matrix(self, monitor): + """Low rank matrix should have effective rank close to actual rank.""" + torch.manual_seed(42) + # Create a rank-5 matrix + actual_rank = 5 + num_samples, dim = 128, 64 + u_mat = torch.randn(num_samples, actual_rank) + v_mat = torch.randn(actual_rank, dim) + z = u_mat @ v_mat + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Effective rank should be close to actual rank + assert eff_rank < actual_rank * 2, ( + f"Expected effective rank < {actual_rank * 2}, got {eff_rank}" + ) + assert eff_rank > actual_rank * 0.5, ( + f"Expected effective rank > {actual_rank * 0.5}, got {eff_rank}" + ) + + def test_collapsed_matrix(self, monitor): + """Completely collapsed matrix should have effective rank ~1.""" + num_samples, dim = 128, 64 + # All rows are the same (rank 1) + row = torch.randn(1, dim) + z = row.expand(num_samples, dim).clone() + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Effective rank should be very close to 1 + assert eff_rank < 2, f"Expected effective rank < 2, got {eff_rank}" + + def test_3d_tensor_flattening(self, monitor): + 
"""Test that [B, N, D] tensors are properly flattened.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + z = torch.randn(batch_size, num_patches, dim) + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Should compute without error and return reasonable value + assert 1 <= eff_rank <= dim + + +class TestSingularValues: + """Test singular value spectrum computation.""" + + def test_top_k_singular_values(self, monitor): + """Test that top-k singular values are correctly computed.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + + sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + + # Check that we got top-5 singular values + assert "singular_value_0" in sv_metrics + assert "singular_value_4" in sv_metrics + assert "singular_value_5" not in sv_metrics + + # Singular values should be in descending order + for i in range(4): + assert sv_metrics[f"singular_value_{i}"] >= sv_metrics[f"singular_value_{i + 1}"] + + def test_concentration_ratio(self, monitor): + """Test singular value concentration ratio.""" + torch.manual_seed(42) + # Create a rank-1 matrix where first SV dominates + num_samples, dim = 128, 64 + # Use outer product to create a truly rank-1 dominated matrix + u_vec = torch.randn(num_samples, 1) + v_vec = torch.randn(1, dim) + z = u_vec @ v_vec * 10 + torch.randn(num_samples, dim) * 0.01 # Strong rank-1 component + + sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + + # Concentration should be high when one SV dominates + assert "sv_concentration" in sv_metrics + assert sv_metrics["sv_concentration"] > 0.8 # First SV dominates strongly + + def test_uniform_singular_values(self, monitor): + """Test with approximately uniform singular values.""" + torch.manual_seed(42) + # Create orthogonal matrix with equal singular values + dim = 64 + q, _ = torch.linalg.qr(torch.randn(dim, dim)) + z = q * 10 # Scale uniformly + + sv_metrics = 
monitor._compute_singular_values(z, top_k=5, sample_size=0) + + # Concentration should be low (close to 1/D) + assert sv_metrics["sv_concentration"] < 0.1 + + +class TestDimensionVariance: + """Test per-dimension variance computation.""" + + def test_random_matrix_balanced_variance(self, monitor): + """Random matrix should have balanced variance across dimensions.""" + torch.manual_seed(42) + num_samples, dim = 1024, 64 + z = torch.randn(num_samples, dim) + + var_metrics = monitor._compute_dimension_variance(z) + + # All variances should be close to 1 for standard normal + assert abs(var_metrics["var_mean"] - 1.0) < 0.2 + # Variance ratio should be small for random matrix + var_ratio = var_metrics["var_max"] / (var_metrics["var_min"] + 1e-8) + assert var_ratio < 5 # Balanced dimensions + + def test_dead_dimensions(self, monitor): + """Test detection of dead (zero-variance) dimensions.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + # Kill some dimensions (set to constant) + z[:, :10] = 0.5 + + var_metrics = monitor._compute_dimension_variance(z) + + # Minimum variance should be very close to 0 (dead dimensions) + assert var_metrics["var_min"] < 1e-6 + + def test_imbalanced_dimensions(self, monitor): + """Test with highly imbalanced dimension variances.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + # Scale some dimensions much more than others + z[:, 0] *= 100 + z[:, 1:10] *= 0.01 + + var_metrics = monitor._compute_dimension_variance(z) + + # Large variance ratio indicates imbalance + var_ratio = var_metrics["var_max"] / (var_metrics["var_min"] + 1e-8) + assert var_ratio > 1000 + + +class TestPrototypeEntropy: + """Test DINO prototype entropy computation.""" + + def test_uniform_prototype_distribution(self, monitor): + """Uniform prototype distribution should have entropy ~1.""" + batch_size, num_prototypes = 128, 64 + # Uniform distribution + probs = torch.ones(batch_size, 
num_prototypes) / num_prototypes + + entropy = monitor._compute_prototype_entropy(probs) + + # Normalized entropy should be close to 1 + assert abs(entropy - 1.0) < 0.01 + + def test_single_prototype_collapse(self, monitor): + """Collapse to single prototype should have entropy ~0.""" + batch_size, num_prototypes = 128, 64 + # All mass on first prototype + probs = torch.zeros(batch_size, num_prototypes) + probs[:, 0] = 1.0 + + entropy = monitor._compute_prototype_entropy(probs) + + # Normalized entropy should be close to 0 + assert entropy < 0.01 + + def test_partial_collapse(self, monitor): + """Partial collapse should have intermediate entropy.""" + batch_size, num_prototypes = 128, 64 + # Only 4 prototypes used uniformly (much stronger collapse) + probs = torch.zeros(batch_size, num_prototypes) + probs[:, :4] = 0.25 # Only 4 out of 64 prototypes + + entropy = monitor._compute_prototype_entropy(probs) + + # Entropy should be between 0 and 1 (log(4)/log(64) ≈ 0.33) + assert 0.2 < entropy < 0.5 + + +class TestMetricsCaching: + """Test metrics caching and averaging.""" + + def test_cache_accumulation(self, monitor): + """Test that metrics are properly cached.""" + torch.manual_seed(42) + z1 = torch.randn(64, 32) + z2 = torch.randn(64, 32) + + # Compute metrics twice + monitor.compute_metrics(student_latent=z1) + monitor.compute_metrics(student_latent=z2) + + # Cache should contain averaged values + cached = monitor.get_cached_metrics() + assert "collapse.student.effective_rank" in cached + + def test_cache_clear(self, monitor): + """Test that cache is cleared after get_cached_metrics.""" + torch.manual_seed(42) + z = torch.randn(64, 32) + + monitor.compute_metrics(student_latent=z) + _ = monitor.get_cached_metrics() + + # Second call should return empty + cached = monitor.get_cached_metrics() + assert len(cached) == 0 + + +class TestIntegration: + """Integration tests with both student and teacher tensors.""" + + def test_full_metrics_computation(self, monitor): + 
"""Test computing all metrics with both student and teacher.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + student = torch.randn(batch_size, num_patches, dim) + teacher = torch.randn(batch_size, num_patches, dim) + + metrics = monitor.compute_metrics( + student_latent=student, + teacher_latent=teacher, + ema_beta=0.999, + loss_type="JEPA", + ) + + # Check that both student and teacher metrics are computed + assert "collapse.student.effective_rank" in metrics + assert "collapse.teacher.effective_rank" in metrics + assert "collapse.student.var_min" in metrics + assert "collapse.teacher.var_min" in metrics + assert "collapse.ema_beta" in metrics + assert metrics["collapse.ema_beta"] == 0.999 + + def test_dino_prototype_entropy(self, monitor): + """Test DINO prototype entropy computation.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + num_prototypes = 128 + student = torch.randn(batch_size, num_patches, dim) + probs = torch.softmax(torch.randn(batch_size, num_prototypes), dim=-1) + + metrics = monitor.compute_metrics( + student_latent=student, + prototype_probs=probs, + loss_type="DINO", + ) + + assert "collapse.dino.prototype_entropy" in metrics + assert 0 <= metrics["collapse.dino.prototype_entropy"] <= 1 + + def test_disabled_metrics(self): + """Test that disabled metrics are not computed.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": {"enabled": False}, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": True, "tensor_source": "student"}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + z = torch.randn(64, 32) + metrics = monitor.compute_metrics(student_latent=z) + + # Only dimension variance should be computed + assert "collapse.student.var_min" in metrics + assert "collapse.student.effective_rank" 
not in metrics + assert "collapse.student.singular_value_0" not in metrics + + +class TestSampling: + """Test row sampling for SVD computations.""" + + def test_sampling_reduces_computation(self, monitor): + """Test that sampling works for large tensors.""" + torch.manual_seed(42) + num_samples, dim = 10000, 64 + z = torch.randn(num_samples, dim) + + # With sampling + eff_rank_sampled = monitor._compute_effective_rank(z, sample_size=1024) + # Without sampling + eff_rank_full = monitor._compute_effective_rank(z, sample_size=0) + + # Results should be in same ballpark + assert abs(eff_rank_sampled - eff_rank_full) < eff_rank_full * 0.3 + + def test_no_sampling_when_small(self, monitor): + """Test that small tensors aren't sampled.""" + torch.manual_seed(42) + num_samples, dim = 100, 64 + z = torch.randn(num_samples, dim) + + # Sample size larger than N + sampled = monitor._sample_rows(z, sample_size=1024) + assert sampled.shape[0] == num_samples # No sampling occurred From 1d296117ca92d398e4e3b4785cfd9236a27c1ad0 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 20:14:17 +0100 Subject: [PATCH 02/10] Fix bug --- src/weathergen/train/trainer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 163ce53a8..7514860a0 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -811,8 +811,14 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: for _loss_name, target_aux in targets_and_auxs.items(): # Check if this is an EMATeacher-based loss if hasattr(target_aux, "latent") and target_aux.latent: - # Get the first timestep's latent dict - target_latent_dict = target_aux.latent[0] if target_aux.latent else {} + # Handle both cases: + # 1. latent is a list[dict] (as per TargetAuxOutput dataclass) + # 2. 
latent is a dict (as set directly by EMATeacher) + if isinstance(target_aux.latent, list): + target_latent_dict = target_aux.latent[0] if target_aux.latent else {} + else: + # EMATeacher sets latent directly as a dict + target_latent_dict = target_aux.latent # Determine the SSL loss type (JEPA, DINO, iBOT) for ssl_type in ["JEPA", "DINO", "iBOT"]: From bc92ae7ed53422180d7c2278b2a9a854551e4080 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 20:25:03 +0100 Subject: [PATCH 03/10] Fix SVD computation failing --- src/weathergen/train/collapse_monitor.py | 34 +++++++++++++++++++++--- src/weathergen/train/trainer.py | 18 +++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py index b739908e9..3b89d0dec 100644 --- a/src/weathergen/train/collapse_monitor.py +++ b/src/weathergen/train/collapse_monitor.py @@ -199,6 +199,10 @@ def _flatten_to_samples(self, z: torch.Tensor) -> torch.Tensor: Returns: Tensor of shape [B*N, D] or [B, D]. 
""" + # Convert to float32 for SVD compatibility (bfloat16/float16 can fail) + if z.dtype in (torch.bfloat16, torch.float16): + z = z.float() + if z.ndim == 3: return z.reshape(-1, z.shape[-1]) return z @@ -237,15 +241,26 @@ def _compute_effective_rank(self, z: torch.Tensor, sample_size: int = 2048) -> f z = self._flatten_to_samples(z.detach()) z = self._sample_rows(z, sample_size) + # Validate tensor before SVD + if z.numel() == 0: + logger.warning("Empty tensor in effective rank computation") + return 0.0 + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for effective rank computation") + return 0.0 + if z.shape[0] < 2 or z.shape[1] < 2: + logger.warning(f"Tensor too small for SVD: shape={z.shape}") + return 0.0 + # Center the data z_centered = z - z.mean(dim=0, keepdim=True) # Compute SVD try: _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) - except RuntimeError: + except RuntimeError as e: # SVD can fail on degenerate matrices - logger.warning("SVD failed in effective rank computation") + logger.warning(f"SVD failed in effective rank computation: {e}, shape={z.shape}") return 0.0 # Normalize singular values to get a probability distribution @@ -280,14 +295,25 @@ def _compute_singular_values( z = self._flatten_to_samples(z.detach()) z = self._sample_rows(z, sample_size) + # Validate tensor before SVD + if z.numel() == 0: + logger.warning("Empty tensor in singular value computation") + return {} + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for singular value computation") + return {} + if z.shape[0] < 2 or z.shape[1] < 2: + logger.warning(f"Tensor too small for SVD: shape={z.shape}") + return {} + # Center the data z_centered = z - z.mean(dim=0, keepdim=True) # Compute SVD try: _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) - except RuntimeError: - logger.warning("SVD failed in singular value computation") + except RuntimeError as e: + 
logger.warning(f"SVD failed in singular value computation: {e}, shape={z.shape}") return {} metrics: dict[str, float] = {} diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 7514860a0..af89e3598 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -860,6 +860,19 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: ema_beta = calculator.get_current_beta(step) break + # Debug logging for tensor extraction + if student_latent is not None: + shape = student_latent.shape if isinstance(student_latent, torch.Tensor) else "N/A" + logger.debug(f"Collapse monitor - student: type={type(student_latent)}, shape={shape}") + else: + logger.debug("Collapse monitor - student_latent is None") + + if teacher_latent is not None: + shape = teacher_latent.shape if isinstance(teacher_latent, torch.Tensor) else "N/A" + logger.debug(f"Collapse monitor - teacher: type={type(teacher_latent)}, shape={shape}") + else: + logger.debug("Collapse monitor - teacher_latent is None") + # Ensure tensors are properly formatted if student_latent is not None and isinstance(student_latent, torch.Tensor): self.collapse_monitor.compute_metrics( @@ -869,6 +882,11 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: ema_beta=ema_beta, loss_type=loss_type, ) + else: + logger.debug( + f"Collapse monitor - skipping compute_metrics: " + f"student_latent is {'None' if student_latent is None else type(student_latent)}" + ) def _log_collapse_metrics(self, stage: Stage) -> None: """ From 7693c1903fd3f33e533d7422ae6e0a578b247161 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 21:15:15 +0100 Subject: [PATCH 04/10] Reduce variables logged --- config/config_jepa.yml | 1 - src/weathergen/train/collapse_monitor.py | 17 ++++---- tests/test_collapse_monitor.py | 49 ++++++++++++++---------- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/config/config_jepa.yml b/config/config_jepa.yml 
index f10f16445..464cc7e60 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -147,7 +147,6 @@ training_config: sample_size: 2048 # max samples for SVD (0 = no sampling) singular_values: enabled: true - top_k: 10 tensor_source: "both" sample_size: 2048 dimension_variance: diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py index 3b89d0dec..d77060ad8 100644 --- a/src/weathergen/train/collapse_monitor.py +++ b/src/weathergen/train/collapse_monitor.py @@ -132,13 +132,12 @@ def compute_metrics( # Compute singular value spectrum if self.singular_values_config.get("enabled", True): - top_k = self.singular_values_config.get("top_k", 10) sample_size = self.singular_values_config.get("sample_size", 2048) for name, tensor in tensors_to_monitor.items(): if tensor is not None: source = self.singular_values_config.get("tensor_source", "both") if source == "both" or source == name: - sv_metrics = self._compute_singular_values(tensor, top_k, sample_size) + sv_metrics = self._compute_singular_values(tensor, sample_size) for key, value in sv_metrics.items(): metrics[f"collapse.{name}.{key}"] = value @@ -275,10 +274,10 @@ def _compute_effective_rank(self, z: torch.Tensor, sample_size: int = 2048) -> f return effective_rank.item() def _compute_singular_values( - self, z: torch.Tensor, top_k: int = 10, sample_size: int = 2048 + self, z: torch.Tensor, sample_size: int = 2048 ) -> dict[str, float]: """ - Compute top-k singular values and concentration ratio. + Compute singular value statistics and concentration ratio. The concentration ratio (top SV / sum of all SVs) indicates how much variance is captured by the largest singular value. High concentration @@ -286,11 +285,10 @@ def _compute_singular_values( Args: z: Latent representations [B, N, D] or [B, D]. - top_k: Number of top singular values to return. sample_size: Maximum samples for SVD computation. 
Returns: - Dictionary with top-k singular values and concentration ratio. + Dictionary with sv_min, sv_max, sv_mean, and sv_concentration. """ z = self._flatten_to_samples(z.detach()) z = self._sample_rows(z, sample_size) @@ -318,9 +316,10 @@ def _compute_singular_values( metrics: dict[str, float] = {} - # Top-k singular values - for i in range(min(top_k, len(s))): - metrics[f"singular_value_{i}"] = s[i].item() + # Singular value statistics + metrics["sv_min"] = s.min().item() + metrics["sv_max"] = s.max().item() + metrics["sv_mean"] = s.mean().item() # Concentration ratio (top SV / sum) s_sum = s.sum() + 1e-8 diff --git a/tests/test_collapse_monitor.py b/tests/test_collapse_monitor.py index 5656205f9..6a6f2ed8c 100644 --- a/tests/test_collapse_monitor.py +++ b/tests/test_collapse_monitor.py @@ -30,7 +30,6 @@ def default_config(): }, "singular_values": { "enabled": True, - "top_k": 10, "tensor_source": "both", "sample_size": 2048, }, @@ -152,22 +151,23 @@ def test_3d_tensor_flattening(self, monitor): class TestSingularValues: """Test singular value spectrum computation.""" - def test_top_k_singular_values(self, monitor): - """Test that top-k singular values are correctly computed.""" + def test_singular_value_statistics(self, monitor): + """Test that singular value statistics are correctly computed.""" torch.manual_seed(42) num_samples, dim = 128, 64 z = torch.randn(num_samples, dim) - sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + sv_metrics = monitor._compute_singular_values(z, sample_size=0) - # Check that we got top-5 singular values - assert "singular_value_0" in sv_metrics - assert "singular_value_4" in sv_metrics - assert "singular_value_5" not in sv_metrics + # Check that we got min, max, mean statistics + assert "sv_min" in sv_metrics + assert "sv_max" in sv_metrics + assert "sv_mean" in sv_metrics + assert "sv_concentration" in sv_metrics - # Singular values should be in descending order - for i in range(4): - assert 
sv_metrics[f"singular_value_{i}"] >= sv_metrics[f"singular_value_{i + 1}"] + # Max should be >= mean >= min + assert sv_metrics["sv_max"] >= sv_metrics["sv_mean"] + assert sv_metrics["sv_mean"] >= sv_metrics["sv_min"] def test_concentration_ratio(self, monitor): """Test singular value concentration ratio.""" @@ -179,24 +179,31 @@ def test_concentration_ratio(self, monitor): v_vec = torch.randn(1, dim) z = u_vec @ v_vec * 10 + torch.randn(num_samples, dim) * 0.01 # Strong rank-1 component - sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + sv_metrics = monitor._compute_singular_values(z, sample_size=0) # Concentration should be high when one SV dominates assert "sv_concentration" in sv_metrics assert sv_metrics["sv_concentration"] > 0.8 # First SV dominates strongly + # Max should be much larger than min for rank-1 dominated matrix + assert sv_metrics["sv_max"] > sv_metrics["sv_min"] * 10 + def test_uniform_singular_values(self, monitor): - """Test with approximately uniform singular values.""" + """Test with random matrix (spread singular values).""" torch.manual_seed(42) - # Create orthogonal matrix with equal singular values - dim = 64 - q, _ = torch.linalg.qr(torch.randn(dim, dim)) - z = q * 10 # Scale uniformly + # Random matrix will have spread singular values + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + + sv_metrics = monitor._compute_singular_values(z, sample_size=0) - sv_metrics = monitor._compute_singular_values(z, top_k=5, sample_size=0) + # Concentration should be relatively low for random matrix + assert sv_metrics["sv_concentration"] < 0.2 - # Concentration should be low (close to 1/D) - assert sv_metrics["sv_concentration"] < 0.1 + # All statistics should be positive + assert sv_metrics["sv_min"] > 0 + assert sv_metrics["sv_max"] > 0 + assert sv_metrics["sv_mean"] > 0 class TestDimensionVariance: @@ -379,7 +386,7 @@ def test_disabled_metrics(self): # Only dimension variance should be computed assert 
"collapse.student.var_min" in metrics assert "collapse.student.effective_rank" not in metrics - assert "collapse.student.singular_value_0" not in metrics + assert "collapse.student.sv_max" not in metrics class TestSampling: From 7f8de00c84696f9c476fe441a59de693039c5c08 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 22:03:09 +0100 Subject: [PATCH 05/10] Fix EMA beta value computation --- src/weathergen/model/ema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index 141947863..08f367116 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -63,8 +63,8 @@ def update(self, cur_step, batch_size): # determine correct interpolation params halflife_steps = self.halflife_steps if self.rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / 1e3 * self.rampup_ratio) - beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6)) + halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) + beta = 0.5 ** (batch_size / max(halflife_steps, 1e-6)) for name, p_ema in self.ema_model.named_parameters(): p_src = self.src_params.get(name, None) From c3eb019adfdfcf5e79d48b0b00710f367fe10caa Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Wed, 4 Feb 2026 22:06:59 +0100 Subject: [PATCH 06/10] Refactor get_current_beta to ema.py --- src/weathergen/model/ema.py | 24 +++++++++++++++---- .../train/target_and_aux_ssl_teacher.py | 19 +-------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index 08f367116..f126e0d83 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -55,16 +55,32 @@ def requires_grad_(self, flag: bool): for p in self.ema_model.parameters(): p.requires_grad = flag + def get_current_beta(self, cur_step: int) -> float: + """ + Get current EMA beta value for monitoring. 
+ + The beta value determines how much the teacher model is updated towards + the student model at each step. Higher beta means slower teacher updates. + + Args: + cur_step: Current training step (typically istep * batch_size). + + Returns: + Current EMA beta value. + """ + halflife_steps = self.ema_model.halflife_steps + if self.rampup_ratio is not None: + halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) + beta = 0.5 ** (self.batch_size / max(halflife_steps, 1e-6)) + return beta + @torch.no_grad() def update(self, cur_step, batch_size): # ensure model remains sharded if self.is_model_sharded: self.ema_model.reshard() # determine correct interpolation params - halflife_steps = self.halflife_steps - if self.rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) - beta = 0.5 ** (batch_size / max(halflife_steps, 1e-6)) + beta = self.get_current_beta(cur_step) for name, p_ema in self.ema_model.named_parameters(): p_src = self.src_params.get(name, None) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 99b1f2860..05213931a 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -78,24 +78,7 @@ def to_device(self, device) -> EMATeacher: return self def get_current_beta(self, cur_step: int) -> float: - """ - Get current EMA beta value for monitoring. - - The beta value determines how much the teacher model is updated towards - the student model at each step. Higher beta means slower teacher updates. - - Args: - cur_step: Current training step (typically istep * batch_size). - - Returns: - Current EMA beta value. 
- """ - halflife_steps = self.ema_model.halflife_steps - rampup_ratio = self.ema_model.rampup_ratio - if rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / 1e3 * rampup_ratio) - beta = 0.5 ** (self.batch_size / max(halflife_steps * 1e3, 1e-6)) - return beta + return self.ema_model.get_current_beta(cur_step) def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): From 59a0a8972c1bc933ab94c2691cf3b3cf3f684bab Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Wed, 4 Feb 2026 21:33:29 +0000 Subject: [PATCH 07/10] Sensible default for ema in jepa --- config/config_jepa.yml | 4 ++-- src/weathergen/model/ema.py | 4 +++- src/weathergen/train/target_and_aux_ssl_teacher.py | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/config/config_jepa.yml b/config/config_jepa.yml index 464cc7e60..f90a3a88f 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -205,8 +205,8 @@ training_config: }, }, target_and_aux_calc: { "EMATeacher" : - { ema_ramp_up_ratio : 0.09, - ema_halflife_in_thousands: 1e-3, + { ema_ramp_up_ratio : null, + ema_halflife_in_thousands: 1e-1, model_param_overrides : { training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} }, diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index f126e0d83..b42d756d4 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -30,6 +30,7 @@ def __init__( self.rampup_ratio = rampup_ratio self.ema_model = empty_model self.is_model_sharded = is_model_sharded + self.batch_size = 1 # Build a name → param map once self.src_params = dict(self.original_model.named_parameters()) @@ -68,7 +69,7 @@ def get_current_beta(self, cur_step: int) -> float: Returns: Current EMA beta value. 
""" - halflife_steps = self.ema_model.halflife_steps + halflife_steps = self.halflife_steps if self.rampup_ratio is not None: halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) beta = 0.5 ** (self.batch_size / max(halflife_steps, 1e-6)) @@ -80,6 +81,7 @@ def update(self, cur_step, batch_size): if self.is_model_sharded: self.ema_model.reshard() # determine correct interpolation params + self.batch_size = batch_size beta = self.get_current_beta(cur_step) for name, p_ema in self.ema_model.named_parameters(): diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 05213931a..76994221c 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -78,7 +78,8 @@ def to_device(self, device) -> EMATeacher: return self def get_current_beta(self, cur_step: int) -> float: - return self.ema_model.get_current_beta(cur_step) + beta = self.ema_model.get_current_beta(cur_step) + return beta def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): From ebbbf33c262d0660205c8733fb91a738d8996a9f Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Thu, 5 Feb 2026 22:05:47 +0100 Subject: [PATCH 08/10] Allow collapse monitoring for forecasting --- config/config_jepa.yml | 6 + src/weathergen/train/collapse_monitor.py | 319 +++++++++++++--- src/weathergen/train/trainer.py | 4 +- tests/test_collapse_monitor.py | 461 +++++++++++++++++++++++ 4 files changed, 739 insertions(+), 51 deletions(-) diff --git a/config/config_jepa.yml b/config/config_jepa.yml index f90a3a88f..9b5495897 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -136,6 +136,7 @@ training_config: # Collapse monitoring for SSL training (JEPA/DINO/iBOT) # Detects representation collapse via various metrics + # For forecasting, supports sequences of latents with per-step and aggregate metrics collapse_monitoring: enabled: true compute_frequency: 
100 # batches between metric computations @@ -145,13 +146,18 @@ training_config: enabled: true tensor_source: "both" # "student", "teacher", or "both" sample_size: 2048 # max samples for SVD (0 = no sampling) + # For forecasting sequences: "all" (per-step + aggregates), + # "aggregate_only" (mean/min/max/degradation), "per_step_only" + forecast_aggregation: "all" singular_values: enabled: true tensor_source: "both" sample_size: 2048 + forecast_aggregation: "all" dimension_variance: enabled: true tensor_source: "both" # cheap to compute, good early indicator + forecast_aggregation: "all" prototype_entropy: enabled: true # only applies to DINO ema_beta: diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py index d77060ad8..2ab2f1eb7 100644 --- a/src/weathergen/train/collapse_monitor.py +++ b/src/weathergen/train/collapse_monitor.py @@ -17,6 +17,11 @@ - Prototype Entropy: Normalized entropy of DINO prototype assignments - EMA Beta: Current teacher momentum value +For forecasting, the monitor supports sequences of latents (one per time step) and computes: +- Per-step metrics (e.g., effective_rank.step_0, step_1, ...) 
+- Aggregate metrics (mean, min across steps) +- Degradation ratio (final step / initial step) + References: - RankMe (ICML 2023): https://arxiv.org/abs/2210.02885 - C-JEPA (NeurIPS 2024): https://arxiv.org/abs/2410.19560 @@ -26,12 +31,19 @@ import logging from collections import defaultdict +from collections.abc import Callable from typing import Any import torch logger = logging.getLogger(__name__) +# Valid values for tensor_source config option +VALID_TENSOR_SOURCES = frozenset({"student", "teacher", "both"}) + +# Valid values for forecast_aggregation config option +VALID_FORECAST_AGGREGATIONS = frozenset({"all", "aggregate_only", "per_step_only"}) + class CollapseMonitor: """ @@ -39,6 +51,8 @@ class CollapseMonitor: Computes and caches various collapse indicators that can be logged at configurable intervals to minimize computational overhead. + + Supports both single latent tensors and sequences of latents for forecasting. """ def __init__(self, config: dict[str, Any], device: torch.device) -> None: @@ -47,13 +61,22 @@ def __init__(self, config: dict[str, Any], device: torch.device) -> None: Args: config: Configuration dictionary with collapse_monitoring settings. - device: Device to use for computations. + device: Device to use for computations (currently unused, tensors + are processed on their original device). + + Raises: + ValueError: If config contains invalid values. 
""" - self.device = device self.enabled = config.get("enabled", False) self.compute_frequency = config.get("compute_frequency", 100) self.log_frequency = config.get("log_frequency", 100) + # Validate frequencies + if self.compute_frequency <= 0: + raise ValueError(f"compute_frequency must be positive, got {self.compute_frequency}") + if self.log_frequency <= 0: + raise ValueError(f"log_frequency must be positive, got {self.log_frequency}") + # Metric configurations metrics_config = config.get("metrics", {}) @@ -63,8 +86,38 @@ def __init__(self, config: dict[str, Any], device: torch.device) -> None: self.prototype_entropy_config = metrics_config.get("prototype_entropy", {}) self.ema_beta_config = metrics_config.get("ema_beta", {}) + # Validate tensor_source values + self._validate_tensor_source(self.effective_rank_config, "effective_rank") + self._validate_tensor_source(self.singular_values_config, "singular_values") + self._validate_tensor_source(self.dimension_variance_config, "dimension_variance") + + # Validate forecast_aggregation values + self._validate_forecast_aggregation(self.effective_rank_config, "effective_rank") + self._validate_forecast_aggregation(self.singular_values_config, "singular_values") + self._validate_forecast_aggregation(self.dimension_variance_config, "dimension_variance") + # Cache for accumulating metrics between log intervals - self._metrics_cache: dict[str, list[float]] = defaultdict(list) + self._metrics_cache: defaultdict[str, list[float]] = defaultdict(list) + + def _validate_tensor_source(self, metric_config: dict[str, Any], metric_name: str) -> None: + """Validate tensor_source config value.""" + source = metric_config.get("tensor_source", "both") + if source not in VALID_TENSOR_SOURCES: + raise ValueError( + f"Invalid tensor_source '{source}' for {metric_name}. 
" + f"Must be one of: {sorted(VALID_TENSOR_SOURCES)}" + ) + + def _validate_forecast_aggregation( + self, metric_config: dict[str, Any], metric_name: str + ) -> None: + """Validate forecast_aggregation config value.""" + aggregation = metric_config.get("forecast_aggregation", "all") + if aggregation not in VALID_FORECAST_AGGREGATIONS: + raise ValueError( + f"Invalid forecast_aggregation '{aggregation}' for {metric_name}. " + f"Must be one of: {sorted(VALID_FORECAST_AGGREGATIONS)}" + ) def should_compute(self, step: int) -> bool: """Check if metrics should be computed at this step.""" @@ -74,10 +127,140 @@ def should_log(self, step: int) -> bool: """Check if metrics should be logged at this step.""" return self.enabled and step % self.log_frequency == 0 + def _get_tensors_to_monitor( + self, + student_latent: torch.Tensor | list[torch.Tensor] | None, + teacher_latent: torch.Tensor | list[torch.Tensor] | None, + metric_config: dict[str, Any], + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + """ + Get tensors to monitor based on metric config's tensor_source. + + Args: + student_latent: Student latent(s). + teacher_latent: Teacher latent(s). + metric_config: Config dict with tensor_source key. + + Returns: + Dict mapping "student"/"teacher" to their tensors (if requested). + """ + source = metric_config.get("tensor_source", "both") + result: dict[str, torch.Tensor | list[torch.Tensor] | None] = {} + + if source in ("student", "both"): + result["student"] = student_latent + if source in ("teacher", "both"): + result["teacher"] = teacher_latent + + return result + + def _compute_sequence_metrics( + self, + latents: list[torch.Tensor], + compute_fn: Callable[..., float], + metric_name: str, + aggregation: str, + **kwargs: Any, + ) -> dict[str, float]: + """ + Compute metrics for a sequence of latents (forecasting). + + Args: + latents: List of latent tensors, one per time step. + compute_fn: Function to compute metric for a single tensor. 
+ metric_name: Base name for the metric (e.g., "effective_rank"). + aggregation: One of "all", "aggregate_only", "per_step_only". + **kwargs: Additional arguments passed to compute_fn. + + Returns: + Dictionary of metrics with per-step and/or aggregate values. + """ + metrics: dict[str, float] = {} + + if not latents: + return metrics + + # Compute per-step metrics + step_values: list[float] = [] + for step_idx, latent in enumerate(latents): + value = compute_fn(latent, **kwargs) + step_values.append(value) + + if aggregation in ("all", "per_step_only"): + metrics[f"{metric_name}.step_{step_idx}"] = value + + # Compute aggregate metrics + if aggregation in ("all", "aggregate_only") and step_values: + # Filter out invalid values (0.0 indicates computation failure) + valid_values = [v for v in step_values if v > 0] + + if valid_values: + metrics[f"{metric_name}.mean"] = sum(valid_values) / len(valid_values) + metrics[f"{metric_name}.min"] = min(valid_values) + metrics[f"{metric_name}.max"] = max(valid_values) + + # Degradation: ratio of last step to first step + # Values > 1 mean rank increased, < 1 means degradation + if step_values[0] > 0 and step_values[-1] > 0: + metrics[f"{metric_name}.degradation"] = step_values[-1] / step_values[0] + + return metrics + + def _compute_sequence_dict_metrics( + self, + latents: list[torch.Tensor], + compute_fn: Callable[..., dict[str, float]], + base_prefix: str, + aggregation: str, + **kwargs: Any, + ) -> dict[str, float]: + """ + Compute dict-returning metrics for a sequence of latents. + + For metrics like singular_values that return multiple values per tensor. + + Args: + latents: List of latent tensors. + compute_fn: Function returning dict of metrics for single tensor. + base_prefix: Prefix for metric names (e.g., "collapse.student"). + aggregation: One of "all", "aggregate_only", "per_step_only". + **kwargs: Additional arguments passed to compute_fn. + + Returns: + Dictionary of metrics. 
+ """ + metrics: dict[str, float] = {} + + if not latents: + return metrics + + # Collect per-step values for each sub-metric + step_metrics: dict[str, list[float]] = defaultdict(list) + + for step_idx, latent in enumerate(latents): + step_result = compute_fn(latent, **kwargs) + + for key, value in step_result.items(): + step_metrics[key].append(value) + + if aggregation in ("all", "per_step_only"): + metrics[f"{base_prefix}.{key}.step_{step_idx}"] = value + + # Compute aggregates for each sub-metric + if aggregation in ("all", "aggregate_only"): + for key, values in step_metrics.items(): + valid_values = [v for v in values if v > 0 or key.startswith("var_")] + if valid_values: + metrics[f"{base_prefix}.{key}.mean"] = sum(valid_values) / len(valid_values) + metrics[f"{base_prefix}.{key}.min"] = min(valid_values) + metrics[f"{base_prefix}.{key}.max"] = max(valid_values) + + return metrics + def compute_metrics( self, - student_latent: torch.Tensor | None = None, - teacher_latent: torch.Tensor | None = None, + student_latent: torch.Tensor | list[torch.Tensor] | None = None, + teacher_latent: torch.Tensor | list[torch.Tensor] | None = None, prototype_probs: torch.Tensor | None = None, ema_beta: float | None = None, loss_type: str | None = None, @@ -85,9 +268,14 @@ def compute_metrics( """ Compute all enabled collapse monitoring metrics. + Supports both single tensors and sequences of tensors (for forecasting). + For sequences, computes per-step metrics and aggregates based on config. + Args: - student_latent: Student model latent representations [B, N, D] or [B, D]. - teacher_latent: Teacher model latent representations [B, N, D] or [B, D]. + student_latent: Student model latent representations. + Single tensor [B, N, D] or [B, D], or list of tensors for forecasting. + teacher_latent: Teacher model latent representations. + Single tensor [B, N, D] or [B, D], or list of tensors for forecasting. 
prototype_probs: Post-softmax prototype assignment probabilities [B, K] (DINO only). ema_beta: Current EMA momentum value. loss_type: Type of SSL loss ("JEPA" or "DINO"). @@ -100,56 +288,80 @@ def compute_metrics( metrics: dict[str, float] = {} - # Determine which tensors to monitor based on config - tensors_to_monitor: dict[str, torch.Tensor | None] = {} - - effective_rank_source = self.effective_rank_config.get("tensor_source", "both") - sv_source = self.singular_values_config.get("tensor_source", "both") - var_source = self.dimension_variance_config.get("tensor_source", "both") - - # Build tensor dict based on what's requested - if effective_rank_source in ("student", "both") or sv_source in ( - "student", - "both", - ) or var_source in ("student", "both"): - tensors_to_monitor["student"] = student_latent - - if effective_rank_source in ("teacher", "both") or sv_source in ( - "teacher", - "both", - ) or var_source in ("teacher", "both"): - tensors_to_monitor["teacher"] = teacher_latent - # Compute effective rank if self.effective_rank_config.get("enabled", True): sample_size = self.effective_rank_config.get("sample_size", 2048) - for name, tensor in tensors_to_monitor.items(): - if tensor is not None: - source = self.effective_rank_config.get("tensor_source", "both") - if source == "both" or source == name: - eff_rank = self._compute_effective_rank(tensor, sample_size) - metrics[f"collapse.{name}.effective_rank"] = eff_rank + aggregation = self.effective_rank_config.get("forecast_aggregation", "all") + tensors = self._get_tensors_to_monitor( + student_latent, teacher_latent, self.effective_rank_config + ) + + for name, tensor in tensors.items(): + if tensor is None: + continue + + if isinstance(tensor, list): + seq_metrics = self._compute_sequence_metrics( + tensor, + self._compute_effective_rank, + f"collapse.{name}.effective_rank", + aggregation, + sample_size=sample_size, + ) + metrics.update(seq_metrics) + else: + eff_rank = 
self._compute_effective_rank(tensor, sample_size) + metrics[f"collapse.{name}.effective_rank"] = eff_rank # Compute singular value spectrum if self.singular_values_config.get("enabled", True): sample_size = self.singular_values_config.get("sample_size", 2048) - for name, tensor in tensors_to_monitor.items(): - if tensor is not None: - source = self.singular_values_config.get("tensor_source", "both") - if source == "both" or source == name: - sv_metrics = self._compute_singular_values(tensor, sample_size) - for key, value in sv_metrics.items(): - metrics[f"collapse.{name}.{key}"] = value + aggregation = self.singular_values_config.get("forecast_aggregation", "all") + tensors = self._get_tensors_to_monitor( + student_latent, teacher_latent, self.singular_values_config + ) + + for name, tensor in tensors.items(): + if tensor is None: + continue + + if isinstance(tensor, list): + seq_metrics = self._compute_sequence_dict_metrics( + tensor, + self._compute_singular_values, + f"collapse.{name}", + aggregation, + sample_size=sample_size, + ) + metrics.update(seq_metrics) + else: + sv_metrics = self._compute_singular_values(tensor, sample_size) + for key, value in sv_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value # Compute per-dimension variance if self.dimension_variance_config.get("enabled", True): - for name, tensor in tensors_to_monitor.items(): - if tensor is not None: - source = self.dimension_variance_config.get("tensor_source", "both") - if source == "both" or source == name: - var_metrics = self._compute_dimension_variance(tensor) - for key, value in var_metrics.items(): - metrics[f"collapse.{name}.{key}"] = value + aggregation = self.dimension_variance_config.get("forecast_aggregation", "all") + tensors = self._get_tensors_to_monitor( + student_latent, teacher_latent, self.dimension_variance_config + ) + + for name, tensor in tensors.items(): + if tensor is None: + continue + + if isinstance(tensor, list): + seq_metrics = 
self._compute_sequence_dict_metrics( + tensor, + self._compute_dimension_variance, + f"collapse.{name}", + aggregation, + ) + metrics.update(seq_metrics) + else: + var_metrics = self._compute_dimension_variance(tensor) + for key, value in var_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value # Compute prototype entropy (DINO only) if ( @@ -338,10 +550,21 @@ def _compute_dimension_variance(self, z: torch.Tensor) -> dict[str, float]: z: Latent representations [B, N, D] or [B, D]. Returns: - Dictionary with var_min, var_mean, var_max. + Dictionary with var_min, var_mean, var_max. Empty dict if tensor is invalid. """ z = self._flatten_to_samples(z.detach()) + # Validate tensor + if z.numel() == 0: + logger.warning("Empty tensor in dimension variance computation") + return {} + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for dimension variance computation") + return {} + if z.shape[0] < 2: + logger.warning(f"Need at least 2 samples to compute variance: shape={z.shape}") + return {} + # Compute variance along sample dimension var_per_dim = z.var(dim=0) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index af89e3598..150fa06ab 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -830,9 +830,7 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: teacher_latent = teacher_latent_data[0] elif isinstance(teacher_latent_data, dict): # Handle LatentState or dict - teacher_latent = teacher_latent_data.get( - "latent", teacher_latent_data - ) + teacher_latent = teacher_latent_data.get("latent", teacher_latent_data) else: teacher_latent = teacher_latent_data break diff --git a/tests/test_collapse_monitor.py b/tests/test_collapse_monitor.py index 6a6f2ed8c..e5fc75985 100644 --- a/tests/test_collapse_monitor.py +++ b/tests/test_collapse_monitor.py @@ -415,3 +415,464 @@ def test_no_sampling_when_small(self, monitor): # Sample size larger than N 
sampled = monitor._sample_rows(z, sample_size=1024) assert sampled.shape[0] == num_samples # No sampling occurred + + +class TestConfigValidation: + """Test configuration validation.""" + + def test_invalid_compute_frequency(self): + """Test that non-positive compute_frequency raises error.""" + config = {"enabled": True, "compute_frequency": 0} + with pytest.raises(ValueError, match="compute_frequency must be positive"): + CollapseMonitor(config, torch.device("cpu")) + + config = {"enabled": True, "compute_frequency": -1} + with pytest.raises(ValueError, match="compute_frequency must be positive"): + CollapseMonitor(config, torch.device("cpu")) + + def test_invalid_log_frequency(self): + """Test that non-positive log_frequency raises error.""" + config = {"enabled": True, "log_frequency": 0} + with pytest.raises(ValueError, match="log_frequency must be positive"): + CollapseMonitor(config, torch.device("cpu")) + + def test_invalid_tensor_source(self): + """Test that invalid tensor_source raises error.""" + config = { + "enabled": True, + "metrics": { + "effective_rank": {"tensor_source": "invalid_source"}, + }, + } + with pytest.raises(ValueError, match="Invalid tensor_source"): + CollapseMonitor(config, torch.device("cpu")) + + def test_invalid_forecast_aggregation(self): + """Test that invalid forecast_aggregation raises error.""" + config = { + "enabled": True, + "metrics": { + "effective_rank": {"forecast_aggregation": "invalid_agg"}, + }, + } + with pytest.raises(ValueError, match="Invalid forecast_aggregation"): + CollapseMonitor(config, torch.device("cpu")) + + def test_valid_tensor_sources(self): + """Test that valid tensor_source values are accepted.""" + for source in ["student", "teacher", "both"]: + config = { + "enabled": True, + "metrics": { + "effective_rank": {"tensor_source": source}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + assert monitor is not None + + def test_valid_forecast_aggregations(self): + """Test that valid 
forecast_aggregation values are accepted.""" + for agg in ["all", "aggregate_only", "per_step_only"]: + config = { + "enabled": True, + "metrics": { + "effective_rank": {"forecast_aggregation": agg}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + assert monitor is not None + + +class TestDimensionVarianceValidation: + """Test validation in _compute_dimension_variance.""" + + def test_empty_tensor(self, monitor): + """Test that empty tensor returns empty dict.""" + z = torch.empty(0, 64) + result = monitor._compute_dimension_variance(z) + assert result == {} + + def test_nan_tensor(self, monitor): + """Test that tensor with NaN returns empty dict.""" + z = torch.randn(64, 32) + z[0, 0] = float("nan") + result = monitor._compute_dimension_variance(z) + assert result == {} + + def test_inf_tensor(self, monitor): + """Test that tensor with Inf returns empty dict.""" + z = torch.randn(64, 32) + z[0, 0] = float("inf") + result = monitor._compute_dimension_variance(z) + assert result == {} + + def test_single_sample(self, monitor): + """Test that single sample returns empty dict (can't compute variance).""" + z = torch.randn(1, 32) + result = monitor._compute_dimension_variance(z) + assert result == {} + + def test_valid_tensor(self, monitor): + """Test that valid tensor returns metrics.""" + torch.manual_seed(42) + z = torch.randn(64, 32) + result = monitor._compute_dimension_variance(z) + assert "var_min" in result + assert "var_mean" in result + assert "var_max" in result + + +class TestForecastingSequences: + """Test forecasting with sequences of latents.""" + + @pytest.fixture + def forecast_config(self): + """Config for forecasting tests.""" + return { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "student", + "sample_size": 0, + "forecast_aggregation": "all", + }, + "singular_values": { + "enabled": True, + "tensor_source": "student", + 
"sample_size": 0, + "forecast_aggregation": "all", + }, + "dimension_variance": { + "enabled": True, + "tensor_source": "student", + "forecast_aggregation": "all", + }, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + + @pytest.fixture + def forecast_monitor(self, forecast_config): + """Create a monitor configured for forecasting.""" + return CollapseMonitor(forecast_config, torch.device("cpu")) + + def test_sequence_per_step_metrics(self, forecast_monitor): + """Test that per-step metrics are computed for sequences.""" + torch.manual_seed(42) + # Create 3 time steps of latents + latents = [torch.randn(32, 64) for _ in range(3)] + + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Check per-step effective rank metrics + assert "collapse.student.effective_rank.step_0" in metrics + assert "collapse.student.effective_rank.step_1" in metrics + assert "collapse.student.effective_rank.step_2" in metrics + + # Check per-step variance metrics + assert "collapse.student.var_min.step_0" in metrics + assert "collapse.student.var_min.step_1" in metrics + assert "collapse.student.var_min.step_2" in metrics + + def test_sequence_aggregate_metrics(self, forecast_monitor): + """Test that aggregate metrics are computed for sequences.""" + torch.manual_seed(42) + latents = [torch.randn(32, 64) for _ in range(3)] + + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Check aggregate effective rank metrics + assert "collapse.student.effective_rank.mean" in metrics + assert "collapse.student.effective_rank.min" in metrics + assert "collapse.student.effective_rank.max" in metrics + assert "collapse.student.effective_rank.degradation" in metrics + + # Verify aggregates are consistent + step_values = [ + metrics["collapse.student.effective_rank.step_0"], + metrics["collapse.student.effective_rank.step_1"], + metrics["collapse.student.effective_rank.step_2"], + ] + assert 
metrics["collapse.student.effective_rank.min"] == min(step_values) + assert metrics["collapse.student.effective_rank.max"] == max(step_values) + assert abs(metrics["collapse.student.effective_rank.mean"] - sum(step_values) / 3) < 1e-6 + + def test_degradation_metric(self, forecast_monitor): + """Test degradation metric (final/initial ratio).""" + torch.manual_seed(42) + # Create latents with controlled rank degradation + dim = 64 + # Step 0: Full rank random matrix + step_0 = torch.randn(128, dim) + # Step 1: Lower rank (rank ~32) + u1 = torch.randn(128, 32) + v1 = torch.randn(32, dim) + step_1 = u1 @ v1 + # Step 2: Even lower rank (rank ~8) + u2 = torch.randn(128, 8) + v2 = torch.randn(8, dim) + step_2 = u2 @ v2 + + latents = [step_0, step_1, step_2] + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Degradation should be < 1 since rank decreases + degradation = metrics["collapse.student.effective_rank.degradation"] + assert degradation < 1.0, f"Expected degradation < 1.0, got {degradation}" + + # Verify degradation is step_2 / step_0 + expected = ( + metrics["collapse.student.effective_rank.step_2"] + / metrics["collapse.student.effective_rank.step_0"] + ) + assert abs(degradation - expected) < 1e-6 + + def test_aggregate_only_mode(self): + """Test forecast_aggregation='aggregate_only' mode.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "student", + "forecast_aggregation": "aggregate_only", + }, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": False}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + latents = [torch.randn(32, 64) for _ in range(3)] + metrics = monitor.compute_metrics(student_latent=latents) + + # Should NOT have per-step metrics + assert 
"collapse.student.effective_rank.step_0" not in metrics + + # Should have aggregate metrics + assert "collapse.student.effective_rank.mean" in metrics + assert "collapse.student.effective_rank.min" in metrics + + def test_per_step_only_mode(self): + """Test forecast_aggregation='per_step_only' mode.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "student", + "forecast_aggregation": "per_step_only", + }, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": False}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + latents = [torch.randn(32, 64) for _ in range(3)] + metrics = monitor.compute_metrics(student_latent=latents) + + # Should have per-step metrics + assert "collapse.student.effective_rank.step_0" in metrics + assert "collapse.student.effective_rank.step_1" in metrics + + # Should NOT have aggregate metrics + assert "collapse.student.effective_rank.mean" not in metrics + assert "collapse.student.effective_rank.degradation" not in metrics + + def test_empty_sequence(self, forecast_monitor): + """Test that empty sequence returns no metrics.""" + metrics = forecast_monitor.compute_metrics(student_latent=[]) + + # No effective rank metrics should be present + assert not any(key.startswith("collapse.student.effective_rank") for key in metrics) + + def test_single_step_sequence(self, forecast_monitor): + """Test sequence with single step (no degradation possible).""" + torch.manual_seed(42) + latents = [torch.randn(32, 64)] + + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Should have step_0 + assert "collapse.student.effective_rank.step_0" in metrics + + # Should have aggregates (single value) + assert "collapse.student.effective_rank.mean" in metrics + + # Degradation should be 1.0 
(same step) + assert metrics["collapse.student.effective_rank.degradation"] == 1.0 + + def test_mixed_single_and_sequence(self, forecast_monitor): + """Test with single tensor for student and sequence for teacher.""" + # Need to enable teacher monitoring + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "both", + "forecast_aggregation": "all", + }, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": False}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + student = torch.randn(32, 64) # Single tensor + teacher = [torch.randn(32, 64) for _ in range(3)] # Sequence + + metrics = monitor.compute_metrics(student_latent=student, teacher_latent=teacher) + + # Student should have single metric + assert "collapse.student.effective_rank" in metrics + assert "collapse.student.effective_rank.step_0" not in metrics + + # Teacher should have sequence metrics + assert "collapse.teacher.effective_rank.step_0" in metrics + assert "collapse.teacher.effective_rank.mean" in metrics + + def test_3d_tensor_sequence(self, forecast_monitor): + """Test sequence of 3D tensors [B, N, D].""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + latents = [torch.randn(batch_size, num_patches, dim) for _ in range(3)] + + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Should flatten each tensor and compute metrics + assert "collapse.student.effective_rank.step_0" in metrics + assert "collapse.student.effective_rank.mean" in metrics + + # Values should be reasonable (between 1 and dim) + for i in range(3): + value = metrics[f"collapse.student.effective_rank.step_{i}"] + assert 1 <= value <= dim + + +class TestSequenceSingularValues: + """Test singular value metrics for sequences.""" + + @pytest.fixture + def 
sv_monitor(self): + """Monitor configured for singular value tests.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": {"enabled": False}, + "singular_values": { + "enabled": True, + "tensor_source": "student", + "sample_size": 0, + "forecast_aggregation": "all", + }, + "dimension_variance": {"enabled": False}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + return CollapseMonitor(config, torch.device("cpu")) + + def test_sv_sequence_metrics(self, sv_monitor): + """Test singular value metrics for sequences.""" + torch.manual_seed(42) + latents = [torch.randn(64, 32) for _ in range(3)] + + metrics = sv_monitor.compute_metrics(student_latent=latents) + + # Per-step metrics + assert "collapse.student.sv_min.step_0" in metrics + assert "collapse.student.sv_max.step_0" in metrics + assert "collapse.student.sv_concentration.step_0" in metrics + + # Aggregate metrics + assert "collapse.student.sv_min.mean" in metrics + assert "collapse.student.sv_max.mean" in metrics + assert "collapse.student.sv_concentration.mean" in metrics + + +class TestSequenceVariance: + """Test dimension variance metrics for sequences.""" + + @pytest.fixture + def var_monitor(self): + """Monitor configured for variance tests.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": {"enabled": False}, + "singular_values": {"enabled": False}, + "dimension_variance": { + "enabled": True, + "tensor_source": "student", + "forecast_aggregation": "all", + }, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + return CollapseMonitor(config, torch.device("cpu")) + + def test_variance_sequence_metrics(self, var_monitor): + """Test variance metrics for sequences.""" + torch.manual_seed(42) + latents = [torch.randn(64, 32) for _ in range(3)] + + metrics = var_monitor.compute_metrics(student_latent=latents) + 
+ # Per-step metrics + assert "collapse.student.var_min.step_0" in metrics + assert "collapse.student.var_mean.step_0" in metrics + assert "collapse.student.var_max.step_0" in metrics + + # Aggregate metrics + assert "collapse.student.var_min.mean" in metrics + assert "collapse.student.var_mean.mean" in metrics + assert "collapse.student.var_max.mean" in metrics + + def test_variance_detects_collapse_over_time(self, var_monitor): + """Test that variance metrics can detect collapse over forecast steps.""" + torch.manual_seed(42) + dim = 32 + + # Step 0: Normal variance + step_0 = torch.randn(128, dim) + + # Step 1: Some dimensions start dying + step_1 = torch.randn(128, dim) + step_1[:, :8] *= 0.1 # Reduce variance in 8 dims + + # Step 2: More dimensions dead + step_2 = torch.randn(128, dim) + step_2[:, :16] *= 0.01 # Almost dead in 16 dims + + latents = [step_0, step_1, step_2] + metrics = var_monitor.compute_metrics(student_latent=latents) + + # var_min should decrease over steps (more dead dimensions) + assert metrics["collapse.student.var_min.step_0"] > metrics["collapse.student.var_min.step_1"] + assert metrics["collapse.student.var_min.step_1"] > metrics["collapse.student.var_min.step_2"] From 97f973463b0cb99337462b201e4e7da8846b06dc Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Thu, 5 Feb 2026 22:36:32 +0100 Subject: [PATCH 09/10] Fix no collapse monitoring for forecasting --- src/weathergen/train/trainer.py | 141 +++++++++++++++++++++----------- 1 file changed, 93 insertions(+), 48 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 150fa06ab..2c8b82431 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -24,6 +24,7 @@ from weathergen.common.config import Config, merge_configs from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler from weathergen.model.ema import EMAModel +from weathergen.model.engines import LatentState from 
weathergen.model.model_interface import ( get_target_aux_calculator, init_model_and_shard, @@ -798,58 +799,85 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: Extract latent tensors from predictions and targets, then compute collapse metrics. This method extracts the student and teacher latent representations from the - SSL training outputs and passes them to the collapse monitor. + model outputs. It supports two modes: + + 1. SSL training (JEPA/DINO/iBOT): Extracts latents from SSL-specific keys + 2. Forecasting: Extracts latents from 'latent_state' at each forecast step + + For forecasting, a list of latent tensors is passed to enable per-step metrics. """ - # Get student latents from predictions (first forecast step) - student_latent = None - teacher_latent = None + student_latent: torch.Tensor | list[torch.Tensor] | None = None + teacher_latent: torch.Tensor | list[torch.Tensor] | None = None prototype_probs = None ema_beta = None loss_type = None - # Find SSL loss type and extract latents + # Helper to extract tensor from various latent formats + def extract_latent_tensor( + latent_data: torch.Tensor | LatentState | list | dict | None, + ) -> torch.Tensor | None: + """Extract tensor from various latent data formats.""" + if latent_data is None: + return None + if isinstance(latent_data, torch.Tensor): + return latent_data + if isinstance(latent_data, LatentState): + # Use patch_tokens as the primary latent representation + return latent_data.patch_tokens + if isinstance(latent_data, list) and len(latent_data) > 0: + return extract_latent_tensor(latent_data[0]) + if isinstance(latent_data, dict): + # Try common keys + for key in ["latent", "patch_tokens"]: + if key in latent_data: + return extract_latent_tensor(latent_data[key]) + return None + + # Find SSL loss type and extract teacher latents for _loss_name, target_aux in targets_and_auxs.items(): - # Check if this is an EMATeacher-based loss - if hasattr(target_aux, "latent") and 
target_aux.latent: - # Handle both cases: - # 1. latent is a list[dict] (as per TargetAuxOutput dataclass) - # 2. latent is a dict (as set directly by EMATeacher) - if isinstance(target_aux.latent, list): - target_latent_dict = target_aux.latent[0] if target_aux.latent else {} - else: - # EMATeacher sets latent directly as a dict - target_latent_dict = target_aux.latent - - # Determine the SSL loss type (JEPA, DINO, iBOT) - for ssl_type in ["JEPA", "DINO", "iBOT"]: - if ssl_type in target_latent_dict: - loss_type = ssl_type - # Get teacher latent - teacher_latent_data = target_latent_dict[ssl_type] - if isinstance(teacher_latent_data, list) and len(teacher_latent_data) > 0: - teacher_latent = teacher_latent_data[0] - elif isinstance(teacher_latent_data, dict): - # Handle LatentState or dict - teacher_latent = teacher_latent_data.get("latent", teacher_latent_data) - else: - teacher_latent = teacher_latent_data - break + if not hasattr(target_aux, "latent") or not target_aux.latent: + continue + + # Handle both list[dict] and dict formats + if isinstance(target_aux.latent, list): + target_latent_dict = target_aux.latent[0] if target_aux.latent else {} + else: + target_latent_dict = target_aux.latent + + # Try SSL-specific keys first + for ssl_type in ["JEPA", "DINO", "iBOT"]: + if ssl_type in target_latent_dict: + loss_type = ssl_type + teacher_latent = extract_latent_tensor(target_latent_dict[ssl_type]) + break - # Get student latents from predictions + # Extract student latents from predictions if preds.latent and len(preds.latent) > 0: + # First, try SSL-specific keys (JEPA/DINO/iBOT) from first step pred_latent_dict = preds.latent[0] for ssl_type in ["JEPA", "DINO", "iBOT"]: if ssl_type in pred_latent_dict: - student_latent_data = pred_latent_dict[ssl_type] - if isinstance(student_latent_data, list) and len(student_latent_data) > 0: - student_latent = student_latent_data[0] - elif isinstance(student_latent_data, dict): - student_latent = 
student_latent_data.get("latent", student_latent_data) - else: - student_latent = student_latent_data + student_latent = extract_latent_tensor(pred_latent_dict[ssl_type]) loss_type = ssl_type break + # If no SSL keys found, extract from latent_state for all forecast steps + if student_latent is None: + student_latents_list: list[torch.Tensor] = [] + for step_latent_dict in preds.latent: + if "latent_state" in step_latent_dict: + step_tensor = extract_latent_tensor(step_latent_dict["latent_state"]) + if step_tensor is not None: + student_latents_list.append(step_tensor) + + # Use list if multiple steps, single tensor otherwise + if len(student_latents_list) > 1: + student_latent = student_latents_list + n_steps = len(student_latents_list) + logger.debug(f"Collapse monitor - forecasting mode: {n_steps} steps") + elif len(student_latents_list) == 1: + student_latent = student_latents_list[0] + # Get EMA beta from target_and_aux_calculators for _calc_name, calculator in self.target_and_aux_calculators.items(): if isinstance(calculator, EMATeacher): @@ -858,24 +886,41 @@ def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: ema_beta = calculator.get_current_beta(step) break - # Debug logging for tensor extraction + # Debug logging if student_latent is not None: - shape = student_latent.shape if isinstance(student_latent, torch.Tensor) else "N/A" - logger.debug(f"Collapse monitor - student: type={type(student_latent)}, shape={shape}") + if isinstance(student_latent, list): + shapes = [t.shape for t in student_latent] + logger.debug(f"Collapse monitor - student (list): {len(shapes)} steps") + else: + logger.debug(f"Collapse monitor - student: shape={student_latent.shape}") else: logger.debug("Collapse monitor - student_latent is None") if teacher_latent is not None: - shape = teacher_latent.shape if isinstance(teacher_latent, torch.Tensor) else "N/A" - logger.debug(f"Collapse monitor - teacher: type={type(teacher_latent)}, shape={shape}") - else: - 
logger.debug("Collapse monitor - teacher_latent is None") + if isinstance(teacher_latent, list): + logger.debug(f"Collapse monitor - teacher (list): {len(teacher_latent)} steps") + else: + shape = teacher_latent.shape if isinstance(teacher_latent, torch.Tensor) else "N/A" + logger.debug(f"Collapse monitor - teacher: shape={shape}") + + # Compute metrics if we have valid student latent + has_valid_latent = student_latent is not None and ( + isinstance(student_latent, torch.Tensor) + or (isinstance(student_latent, list) and len(student_latent) > 0) + ) + + if has_valid_latent: + # Prepare teacher latent (must match student format if provided) + teacher_for_metrics = None + if teacher_latent is not None: + is_valid_tensor = isinstance(teacher_latent, torch.Tensor) + is_valid_list = isinstance(teacher_latent, list) and len(teacher_latent) > 0 + if is_valid_tensor or is_valid_list: + teacher_for_metrics = teacher_latent - # Ensure tensors are properly formatted - if student_latent is not None and isinstance(student_latent, torch.Tensor): self.collapse_monitor.compute_metrics( student_latent=student_latent, - teacher_latent=teacher_latent if isinstance(teacher_latent, torch.Tensor) else None, + teacher_latent=teacher_for_metrics, prototype_probs=prototype_probs, ema_beta=ema_beta, loss_type=loss_type, From 0111e75d6e8b2de5c3a52fd1264551af33e32e50 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Thu, 5 Feb 2026 22:43:16 +0100 Subject: [PATCH 10/10] Try to fix forecasting --- config/default_config.yml | 31 ++++++++++++++++++++++++++++++- src/weathergen/train/trainer.py | 11 +++++++++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 4bff1abfd..0d1dc5230 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -134,10 +134,39 @@ data_loading : # config for training training_config: - + # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking"] + # 
Collapse monitoring for detecting representation collapse + # Works with SSL training (JEPA/DINO) and forecasting modes + # For forecasting, monitors latent_state.patch_tokens (or z_pre_norm as fallback) at each forecast step + collapse_monitoring: + enabled: true + compute_frequency: 100 # batches between metric computations + log_frequency: 100 # batches between metric logging + metrics: + effective_rank: + enabled: true + tensor_source: "student" # "student", "teacher", or "both" + sample_size: 2048 # max samples for SVD (0 = no sampling) + # For forecasting sequences: "all" (per-step + aggregates), + # "aggregate_only" (mean/min/max/degradation), "per_step_only" + forecast_aggregation: "all" + singular_values: + enabled: true + tensor_source: "student" + sample_size: 2048 + forecast_aggregation: "all" + dimension_variance: + enabled: true + tensor_source: "student" + forecast_aggregation: "all" + prototype_entropy: + enabled: false # only relevant for DINO + ema_beta: + enabled: false # only relevant for SSL with EMA teacher + num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 2c8b82431..fda4eab60 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -823,12 +823,19 @@ def extract_latent_tensor( return latent_data if isinstance(latent_data, LatentState): # Use patch_tokens as the primary latent representation - return latent_data.patch_tokens + # For forecast steps > 0, patch_tokens is None, so fall back to z_pre_norm + if latent_data.patch_tokens is not None: + return latent_data.patch_tokens + if latent_data.z_pre_norm is not None: + # z_pre_norm includes register/class tokens; it is returned as-is here, + # i.e. those tokens are NOT stripped to match the patch_tokens layout + return latent_data.z_pre_norm + return None if isinstance(latent_data, list) and len(latent_data) > 0: return extract_latent_tensor(latent_data[0]) if isinstance(latent_data, dict): # Try common keys -
for key in ["latent", "patch_tokens"]: + for key in ["latent", "patch_tokens", "z_pre_norm"]: if key in latent_data: return extract_latent_tensor(latent_data[key]) return None